Created
May 14, 2025 08:28
-
-
Save Raibows/abb3c77245a682ee5567d7d93075baa2 to your computer and use it in GitHub Desktop.
rule-based math extraction and verify
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| import json | |
| from dataclasses import dataclass | |
| from multiprocessing import Pool | |
| from collections import defaultdict | |
| from datetime import datetime | |
| from math_verify import math_metric, LatexExtractionConfig, ExprExtractionConfig, StringExtractionConfig, parse, verify | |
| from tqdm import tqdm | |
| from simpleArgParser import parse_args | |
| # Before running this script, install the following dependencies: | |
| """ | |
| pip install math-verify[antlr4_13_2] simpleArgParser tqdm | |
| """ | |
| def tools_json_load(path) -> dict | list: | |
| with open(path, 'r') as file: | |
| return json.load(file) | |
def tools_json_dump(obj, path):
    """Write ``obj`` to ``path`` as JSON indented by 4 spaces.

    Raises:
        TypeError: if ``obj`` contains a value that ``json`` cannot serialize.
            In that case ``path`` is left untouched.
    """
    # Serialize before opening the file: opening in 'w' mode truncates it, so
    # a serialization error would otherwise destroy any existing content.
    text = json.dumps(obj, indent=4)
    with open(path, 'w', encoding='utf-8') as file:
        file.write(text)
def tools_get_time() -> str:
    """Return the current US/Eastern wall-clock time as ``yy-mm-dd-HH_MM_SS``."""
    # Standard-library zoneinfo replaces the third-party pytz import, so this
    # helper needs no extra installed package. The import stays local to the
    # function, as the original pytz import was.
    from zoneinfo import ZoneInfo

    eastern = ZoneInfo("US/Eastern")
    return datetime.now(eastern).strftime("%y-%m-%d-%H_%M_%S")
def tools_elapsed_time(previous_time_str: str) -> str:
    """Describe how much time has passed since ``previous_time_str``.

    ``previous_time_str`` must use the ``%y-%m-%d-%H_%M_%S`` layout produced
    by ``tools_get_time``. The current time is obtained from
    ``tools_get_time`` and parsed with the same layout, so both ends are
    compared as naive datetimes at one-second resolution.

    Returns:
        A string of the form ``"D days, H hours, M minutes, S seconds"``.
    """
    stamp_format = "%y-%m-%d-%H_%M_%S"
    started = datetime.strptime(previous_time_str, stamp_format)
    now = datetime.strptime(tools_get_time(), stamp_format)
    elapsed = now - started
    # timedelta.seconds is only the sub-day remainder; whole days are
    # reported separately through timedelta.days.
    hours, leftover = divmod(elapsed.seconds, 3600)
    minutes, seconds = divmod(leftover, 60)
    return f"{elapsed.days} days, {hours} hours, {minutes} minutes, {seconds} seconds"
| @dataclass | |
| class Args: | |
| # the input results folder | |
| input: str | |
| # task name or note, e.g., gsm8k / aime25 / math and etc. | |
| task: str | |
| output: str | None = None | |
| limit: int | None = None | |
| save: bool = True | |
| def post_process(self): | |
| if self.limit is None: | |
| self.limit = int(1e18) | |
def split_thinking(response: str) -> tuple[str, str]:
    """Split a model response into ``(thinking, conclusion)``.

    The text before the first ``</think>`` tag is the thinking part and
    everything after that tag is the conclusion. If the tag occurs more than
    once, the later occurrences stay inside the conclusion, so no text is
    dropped.

    A response without any ``</think>`` tag (a normal, non-reasoning LLM) is
    treated entirely as conclusion, with an empty thinking part.
    """
    thinking, marker, conclusion = response.partition('</think>')
    if not marker:
        # No tag found: the whole response is the conclusion.
        return "", response
    return thinking, conclusion
def init_worker():
    """Create the ``math_metric`` scorer and publish it as global ``verify_func``.

    Both the gold and the predicted side are extracted with a LaTeX config and
    a plain-expression config; per-pair results are aggregated with ``max``
    and the metric is built with ``precision=6``.
    """
    global verify_func
    metric_options = {
        "gold_extraction_target": (LatexExtractionConfig(), ExprExtractionConfig()),
        "pred_extraction_target": (LatexExtractionConfig(), ExprExtractionConfig()),
        "aggregation_function": max,
        "precision": 6,
    }
    verify_func = math_metric(**metric_options)
def process_item(args_tuple: tuple) -> tuple[str, list[dict], int, int]:
    """Grade every generated response of one item against its gold solution.

    Args:
        args_tuple: ``(uuid, item)``. ``item['original']['solution']`` is the
            gold text and ``item['generated']`` is an iterable of raw model
            responses.

    Returns:
        ``(uuid, verify_extracted, correct, total)`` where
        ``verify_extracted`` holds one ``{'grade', 'gold', 'generated'}``
        record per response, ``correct`` counts responses graded 1 and
        ``total`` counts all responses.

    Relies on the module-level ``verify_func`` set up by ``init_worker``.
    """
    uuid, item = args_tuple
    raw_gold = item['original']['solution']
    choice_letters = frozenset('abcdef')
    verify_extracted = []
    correct = total = 0
    for raw_generated in item['generated']:
        # Reset per-response state on every iteration. Otherwise a response
        # that fails would report the previous response's grade/prediction,
        # or hit an unbound ``gold_ans`` when the very first response fails.
        grade = 0
        gold_ans = pred_ans = None
        try:
            _, gen_answer = split_thinking(raw_generated)
            grade, extracted = verify_func([raw_gold], [gen_answer])
            if extracted is not None:
                gold_ans = extracted[0][0]
                pred_ans = extracted[1][0]
            # Single-letter predictions (a-f, any case) get a second chance:
            # re-parse the gold answer as a plain string and compare again.
            if grade != 1 and pred_ans and pred_ans.lower() in choice_letters:
                gold_ans = parse(raw_gold, extraction_config=[StringExtractionConfig()])[0]
                grade = int(verify(gold_ans, pred_ans))
            if grade == 1:
                correct += 1
        except Exception:
            # Best-effort grading: a response that cannot be parsed or
            # verified is deliberately not fatal. It was not counted as
            # correct, so record it with grade 0 to keep the per-response
            # record consistent with the counters.
            grade = 0
        finally:
            total += 1
            verify_extracted.append({
                'grade': grade,
                'gold': gold_ans,
                'generated': pred_ans,
            })
    return uuid, verify_extracted, correct, total
def process_answers(args: Args, data: dict) -> tuple[dict, str]:
    """Grade the items in ``data`` in parallel and summarise the accuracy.

    Args:
        args: run options; ``args.limit`` caps how many items are graded and
            ``args.input`` is only used to label the progress bar.
        data: mapping of item id to item dict. Each graded item gains a
            ``'verify_extracted'`` entry (``data`` is modified in place).

    Returns:
        ``(data, msg)`` where ``msg`` is ``"Acc=<overall>"`` followed by one
        ``".<source>=<accuracy>"`` segment per source; accuracies are
        percentages with three decimals.
    """
    items = list(data.items())[:args.limit]
    # os.cpu_count() may return None, so fall back to a single worker. There
    # is also no point in starting more workers than there are items.
    n_proc = max(1, min(os.cpu_count() or 1, 16, len(items)))
    correct_count = 0
    total_count = 0
    task_correct = defaultdict(int)
    task_total = defaultdict(int)
    with Pool(processes=n_proc, initializer=init_worker) as pool:
        for uuid, ver_list, corr, tot in tqdm(
            pool.imap(process_item, items),
            total=len(items),
            desc=f"Processing {args.input}"
        ):
            if 'original' in data[uuid] and 'source' in data[uuid]['original']:
                source = data[uuid]['original']['source']
                task_correct[source] += corr
                task_total[source] += tot
            data[uuid]['verify_extracted'] = ver_list
            correct_count += corr
            total_count += tot
    accuracy = correct_count / total_count if total_count > 0 else 0
    print("\nEvaluation Results:")
    print(f"Total examples: {total_count}")
    print(f"Correct answers: {correct_count}")
    print(f"Accuracy: {accuracy:.2%}")
    msg = f"Acc={accuracy*100:.3f}"
    for task, correct in task_correct.items():
        # A source whose items carried no generated responses has a zero
        # total; report 0 instead of dividing by zero, as done for the
        # overall accuracy above.
        acc = correct / task_total[task] if task_total[task] > 0 else 0
        print(f'task={task}, acc = {correct} / {task_total[task]} = {acc:.3%}')
        msg += f'.{task}={acc*100:.3f}'
    return data, msg
def read_data_from_folder(path: str, task: str) -> dict:
    """Load every ``question*`` JSON file in ``path`` into the grading format.

    Args:
        path: folder to scan; entries whose name does not start with
            ``question`` are skipped.
        task: label stored as each item's ``source``.

    Returns:
        A dict keyed by a fresh random UUID string per file. Each value has
        ``'original'`` (``question``, ``source``, and ``solution`` -- the
        file's ``answer`` wrapped in ``$...$``) and ``'generated'`` (a
        one-element list holding the file's ``response``).
    """
    from uuid import uuid4
    print(f'reading data from {path}, task = {task}')
    data = {}
    # os.listdir returns entries in arbitrary order; sorting makes the item
    # order, and therefore any later truncation of the items, reproducible.
    for f in tqdm(sorted(os.listdir(path))):
        if not f.startswith('question'):
            continue
        item = tools_json_load(os.path.join(path, f))
        data[str(uuid4())] = {
            'original': {
                'question': item['question'],
                'source': task,
                # f-string formatting also accepts numeric JSON answers,
                # which plain string concatenation would reject.
                'solution': f"${item['answer']}$",
            },
            'generated': [
                item['response']
            ]
        }
    return data
if __name__ == "__main__":
    # Script entry point: parse options, load the folder, grade, optionally save.
    args: Args = parse_args(Args)
    print(args)

    start_time = tools_get_time()
    data = read_data_from_folder(args.input, args.task)
    results, msg = process_answers(args, data)
    print(msg)

    # Without an explicit output path, write next to the inputs and embed the
    # accuracy summary in the file name.
    if args.output is None:
        args.output = f"{args.input}/merged.verify.{msg}.json"

    # NOTE(review): the original indentation was lost in this copy; the
    # confirmation message is assumed to belong to the save branch -- confirm.
    if args.save:
        tools_json_dump(results, args.output)
        print(f"Results are saved to={args.output}, costs {tools_elapsed_time(start_time)}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment