Skip to content

Instantly share code, notes, and snippets.

@Raibows
Created May 14, 2025 08:28
Show Gist options
  • Select an option

  • Save Raibows/abb3c77245a682ee5567d7d93075baa2 to your computer and use it in GitHub Desktop.

Select an option

Save Raibows/abb3c77245a682ee5567d7d93075baa2 to your computer and use it in GitHub Desktop.
Rule-based math answer extraction and verification.
import os
import json
from dataclasses import dataclass
from multiprocessing import Pool
from collections import defaultdict
from datetime import datetime
from math_verify import math_metric, LatexExtractionConfig, ExprExtractionConfig, StringExtractionConfig, parse, verify
from tqdm import tqdm
from simpleArgParser import parse_args
# before run, please install the following
"""
pip install math-verify[antlr4_13_2] simpleArgParser tqdm
"""
def tools_json_load(path) -> dict | list:
with open(path, 'r') as file:
return json.load(file)
def tools_json_dump(obj, path):
    """Serialize *obj* to *path* as pretty-printed JSON (4-space indent)."""
    with open(path, 'w') as fp:
        json.dump(obj, fp, indent=4)
def tools_get_time() -> str:
    """Return the current US/Eastern wall-clock time as 'yy-mm-dd-HH_MM_SS'.

    Uses the stdlib ``zoneinfo`` module (Python >= 3.9; the file already
    relies on 3.10+ union syntax) instead of third-party ``pytz``, which
    was not even listed in the pip-install instructions above.
    """
    from zoneinfo import ZoneInfo  # local import, mirroring the original style
    return datetime.now(ZoneInfo("US/Eastern")).strftime("%y-%m-%d-%H_%M_%S")
def tools_elapsed_time(previous_time_str: str) -> str:
    """Return the human-readable wall time elapsed since *previous_time_str*.

    *previous_time_str* must be a timestamp in the 'yy-mm-dd-HH_MM_SS'
    format produced by ``tools_get_time`` (US/Eastern wall time).

    Improvements over the original: the current time is no longer
    formatted to a string and immediately re-parsed; the redundant
    ``(seconds % 3600) % 60`` is replaced with ``divmod``.
    """
    from zoneinfo import ZoneInfo
    previous_dt = datetime.strptime(previous_time_str, "%y-%m-%d-%H_%M_%S")
    # Compare against the current Eastern wall time, truncated to whole
    # seconds to match the precision of the stored string.
    current_dt = datetime.now(ZoneInfo("US/Eastern")).replace(microsecond=0, tzinfo=None)
    delta = current_dt - previous_dt
    hours, remainder = divmod(delta.seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return f"{delta.days} days, {hours} hours, {minutes} minutes, {seconds} seconds"
@dataclass
class Args:
# the input results folder
input: str
# task name or note, e.g., gsm8k / aime25 / math and etc.
task: str
output: str | None = None
limit: int | None = None
save: bool = True
def post_process(self):
if self.limit is None:
self.limit = int(1e18)
def split_thinking(response: str) -> tuple[str, str]:
    """Split *response* into (thinking, conclusion) at the first '</think>'.

    If the tag is absent (normal, non-reasoning LLM output) the whole
    response is treated as the conclusion and thinking is empty.

    Bug fix: the original used ``split('</think>')`` and kept only
    ``temp[1]``, so any content after a second '</think>' tag was silently
    dropped; ``str.partition`` keeps the full remainder as the conclusion.
    """
    thinking, sep, conclusion = response.partition('</think>')
    if not sep:
        # No tag found: everything is conclusion.
        return "", response
    return thinking, conclusion
def init_worker():
    """Pool initializer: build the math_verify scorer once per worker.

    The scorer is stored in a module-level global so that `process_item`,
    running in the same worker process, can reach it.
    """
    global verify_func
    extraction_targets = (LatexExtractionConfig(), ExprExtractionConfig())
    verify_func = math_metric(
        gold_extraction_target=extraction_targets,
        pred_extraction_target=extraction_targets,
        aggregation_function=max,
        precision=6,
    )
def process_item(args_tuple: tuple) -> tuple[str, list[dict], int, int]:
    """Verify every generation of one item against its gold solution.

    *args_tuple* is a ``(uuid, item)`` pair where ``item`` holds
    ``item['original']['solution']`` (gold) and ``item['generated']``
    (list of raw model outputs). Returns
    ``(uuid, per-generation records, #correct, #total)``.
    Relies on the module-level ``verify_func`` installed by ``init_worker``.

    Bug fixes vs. the original: ``gold_ans`` was only assigned inside the
    ``try`` body, so an exception raised before that line caused a
    NameError in the ``finally`` append; ``grade``/``pred_ans``/``gold_ans``
    also leaked stale values from the previous generation into a failed
    one's record. All three are now reset at the top of each iteration.
    """
    uuid, item = args_tuple
    verify_extracted = []
    correct = total = 0
    raw_gold = item['original']['solution']
    for raw_generated in item['generated']:
        # Per-generation state, reset so a failure can never record a
        # stale grade/answer from an earlier generation.
        grade = 0
        gold_ans = pred_ans = None
        try:
            thinking, gen_answer = split_thinking(raw_generated)
            grade, extracted = verify_func([raw_gold], [gen_answer])
            if extracted is not None:
                gold_ans = extracted[0][0]
                pred_ans = extracted[1][0]
            # Multiple-choice fallback: if latex/expr matching failed and
            # the prediction looks like a single option letter, re-verify
            # against the gold answer extracted as a plain string.
            if grade != 1 and pred_ans and pred_ans.lower() in set('abcdef'):
                gold_ans = parse(raw_gold, extraction_config=[StringExtractionConfig()])[0]
                grade = int(verify(gold_ans, pred_ans))
            if grade == 1:
                correct += 1
        except Exception:
            # Best-effort scoring: a malformed generation counts as wrong
            # instead of aborting the whole item.
            pass
        finally:
            total += 1
            verify_extracted.append({
                'grade': grade,
                'gold': gold_ans,
                'generated': pred_ans,
            })
    return uuid, verify_extracted, correct, total
def process_answers(args: Args, data: dict) -> tuple[dict, str]:
    """Score all items in *data* in parallel and attach verification records.

    Mutates each ``data[uuid]`` in place by adding a 'verify_extracted'
    list, prints aggregate and per-source accuracies, and returns the data
    plus a compact accuracy message (e.g. 'Acc=87.500.gsm8k=87.500') that
    the caller embeds in the output filename.
    """
    items = list(data.items())[:args.limit]
    n_proc = min(os.cpu_count(), 16)  # cap the worker count
    correct_count = 0
    total_count = 0
    # Per-source tallies; defaultdict(int) is the idiomatic zero-counter
    # (the original used defaultdict(lambda: 0)).
    task_correct = defaultdict(int)
    task_total = defaultdict(int)
    with Pool(processes=n_proc, initializer=init_worker) as pool:
        for uuid, ver_list, corr, tot in tqdm(
            pool.imap(process_item, items),
            total=len(items),
            desc=f"Processing {args.input}"
        ):
            # Aggregate per-source accuracy when the item records a source.
            if 'original' in data[uuid] and 'source' in data[uuid]['original']:
                source = data[uuid]['original']['source']
                task_correct[source] += corr
                task_total[source] += tot
            data[uuid]['verify_extracted'] = ver_list
            correct_count += corr
            total_count += tot
    accuracy = correct_count / total_count if total_count > 0 else 0
    print("\nEvaluation Results:")
    print(f"Total examples: {total_count}")
    print(f"Correct answers: {correct_count}")
    print(f"Accuracy: {accuracy:.2%}")
    msg = f"Acc={accuracy*100:.3f}"
    for task_name, n_correct in task_correct.items():
        acc = n_correct / task_total[task_name]
        print(f'task={task_name}, acc = {n_correct} / {task_total[task_name]} = {acc:.3%}')
        msg += f'.{task_name}={acc*100:.3f}'
    return data, msg
def read_data_from_folder(path: str, task: str) -> dict:
    """Read per-question result files from *path* into the internal format.

    Only files whose names start with 'question' are read. Each file
    becomes one entry keyed by a fresh UUID; the gold answer is wrapped in
    '$...$' so the LaTeX extractor can pick it up.
    """
    from uuid import uuid4
    print(f'reading data from {path}, task = {task}')
    data = {}
    for fname in tqdm(os.listdir(path)):
        if not fname.startswith('question'):
            continue
        item = tools_json_load(f'{path}/{fname}')
        data[str(uuid4())] = {
            'original': {
                'question': item['question'],
                'source': task,
                'solution': '$' + item['answer'] + '$',
            },
            'generated': [
                item['response']
            ],
        }
    return data
if __name__ == "__main__":
    # Parse CLI args into the Args dataclass and record the start time.
    args: Args = parse_args(Args)
    print(args)
    start_time = tools_get_time()

    # Load, score in parallel, and report.
    data = read_data_from_folder(args.input, args.task)
    results, msg = process_answers(args, data)
    print(msg)

    # Default output path embeds the accuracy message in the filename.
    if args.output is None:
        args.output = f"{args.input}/merged.verify.{msg}.json"
    if args.save:
        tools_json_dump(results, args.output)
        print(f"Results are saved to={args.output}, costs {tools_elapsed_time(start_time)}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment