|
| 1 | +import os |
| 2 | +import re |
| 3 | +import sys |
| 4 | + |
| 5 | +from mathruler.grader import extract_boxed_content, grade_answer |
| 6 | + |
| 7 | +from areal.api.cli_args import GRPOConfig, load_expr_config |
| 8 | +from areal.dataset import get_custom_dataset |
| 9 | +from areal.experimental.trainer import PPOTrainer |
| 10 | +from areal.utils.hf_utils import load_hf_processor_and_tokenizer |
| 11 | +from areal.utils.stats_logger import StatsLogger |
| 12 | +from areal.workflow.vision_rlvr import VisionRLVRWorkflow |
| 13 | + |
| 14 | + |
# Compiled once at import time instead of on every reward call; matches a
# completion of the form "<think>...</think> ... \boxed{...}" (DOTALL so the
# reasoning may span multiple lines).
_FORMAT_PATTERN = re.compile(r"<think>.*</think>.*\\boxed\{.*\}.*", re.DOTALL)


def format_reward(predict_str: str) -> float:
    """Return 1.0 if *predict_str* follows the expected think/boxed format.

    The whole string must match ``<think>...</think>`` followed (anywhere
    later) by a ``\\boxed{...}`` answer; otherwise the reward is 0.0.
    """
    return 1.0 if _FORMAT_PATTERN.fullmatch(predict_str) else 0.0
| 19 | + |
| 20 | + |
def acc_reward(predict_str: str, ground_truth: str) -> float:
    """Return 1.0 when the boxed answer in *predict_str* grades as correct.

    Extracts the ``\\boxed{...}`` content via mathruler and compares it to
    *ground_truth* with ``grade_answer``; any mismatch yields 0.0.
    """
    boxed = extract_boxed_content(predict_str)
    if grade_answer(boxed, ground_truth):
        return 1.0
    return 0.0
| 24 | + |
| 25 | + |
def geometry3k_reward_fn(
    prompt, completions, prompt_ids, completion_ids, answer, **kwargs
):
    """Composite reward for Geometry3K: 0.9 * accuracy + 0.1 * format.

    Accuracy dominates; the small format term nudges the model toward the
    ``<think>...</think> \\boxed{...}`` output convention. The unused
    parameters keep the signature compatible with the workflow's reward API.
    """
    format_weight = 0.1
    fmt = format_reward(completions)
    acc = acc_reward(completions, answer)
    return (1.0 - format_weight) * acc + format_weight * fmt
| 34 | + |
| 35 | + |
def main(args):
    """Run GRPO training on Geometry3K with a vision RLVR workflow.

    Loads the experiment config from CLI *args*, builds train/test datasets,
    then trains inside a ``PPOTrainer`` context with separate rollout
    workflows for training and evaluation (eval samples at temperature 0.6).
    """
    config, _ = load_expr_config(args, GRPOConfig)
    processor, tokenizer = load_hf_processor_and_tokenizer(config.tokenizer_path)

    # Build both dataset splits with identical tokenizer/processor wiring.
    datasets = {}
    for split, ds_config in (
        ("train", config.train_dataset),
        ("test", config.valid_dataset),
    ):
        datasets[split] = get_custom_dataset(
            split=split,
            dataset_config=ds_config,
            tokenizer=tokenizer,
            processor=processor,
        )

    with PPOTrainer(
        config,
        train_dataset=datasets["train"],
        valid_dataset=datasets["test"],
    ) as trainer:
        # Both workflows dump generated rollouts under the stats-logger path.
        log_root = StatsLogger.get_log_path(config.stats_logger)
        workflow = VisionRLVRWorkflow(
            reward_fn=geometry3k_reward_fn,
            gconfig=config.gconfig,
            tokenizer=trainer.tokenizer,
            processor=trainer.processor,
            enable_thinking=False,
            dump_dir=os.path.join(log_root, "generated"),
        )
        eval_workflow = VisionRLVRWorkflow(
            reward_fn=geometry3k_reward_fn,
            # Fixed eval temperature for more stable evaluation sampling.
            gconfig=config.gconfig.new(temperature=0.6),
            tokenizer=trainer.tokenizer,
            processor=trainer.processor,
            enable_thinking=False,
            rollout_stat_scope="eval-rollout",
            dump_dir=os.path.join(log_root, "generated-eval"),
        )
        trainer.train(workflow, eval_workflow)


if __name__ == "__main__":
    main(sys.argv[1:])
0 commit comments