Run DeepSeek-OCR-2 on OlmOCR-bench
#!/usr/bin/env python3
"""
DeepSeek-OCR-2 markdown extraction for olmocr-bench.

Generates markdown files from images. For scoring, see: https://github.com/allenai/olmocr

config:
    Prompt: "<image>\n<|grounding|>Convert the document to markdown."
    Repeats: 3

setup:
    git clone https://github.com/deepseek-ai/DeepSeek-OCR-2.git deepseek_ocr2_repo
    cd deepseek_ocr2_repo && git checkout 9529503f1765eb5a87a0f5ef5cdf3dc81d246dd1 && cd ..
    uv venv .venv-deepseek-ocr2 --python 3.12 && source .venv-deepseek-ocr2/bin/activate
    uv pip install torch==2.6.0 --index-url https://download.pytorch.org/whl/cu118
    wget https://github.com/vllm-project/vllm/releases/download/v0.8.5/vllm-0.8.5+cu118-cp38-abi3-manylinux1_x86_64.whl
    uv pip install ./vllm-0.8.5+cu118-cp38-abi3-manylinux1_x86_64.whl
    uv pip install -r deepseek_ocr2_repo/requirements.txt
    uv pip install flash-attn==2.7.3 --no-build-isolation matplotlib
"""
import os
import sys
import argparse
import re
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor

from tqdm import tqdm

# Must set before importing vllm
os.environ['VLLM_USE_V1'] = '0'

import torch

if torch.version.cuda == '11.8':
    os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda-11.8/bin/ptxas"

from PIL import Image, ExifTags

# Add DeepSeek-OCR-2 vllm code to path
DEEPSEEK_VLLM_PATH = "/home/said_taghadouini/vision/scripts/lightonocr/benchmark/vllm_combined_benches/deepseek_ocr2_repo/DeepSeek-OCR2-master/DeepSeek-OCR2-vllm"
sys.path.insert(0, DEEPSEEK_VLLM_PATH)

# Import and register custom model
from deepseek_ocr2 import DeepseekOCR2ForCausalLM
from vllm.model_executor.models.registry import ModelRegistry
ModelRegistry.register_model("DeepseekOCR2ForCausalLM", DeepseekOCR2ForCausalLM)

from vllm import LLM, SamplingParams
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
from process.image_process import DeepseekOCR2Processor
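# Note: `deepseek_ocr2`, `process.ngram_norepeat`, `process.image_process`, and the `config`
# module imported inside main() are expected to resolve from DEEPSEEK_VLLM_PATH (the repo
# checked out above), not from installed packages.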


def correct_image_orientation(image):
    """Fix image orientation based on EXIF data."""
    try:
        exif = image._getexif()
        if exif is not None:
            for tag, value in ExifTags.TAGS.items():
                if value == 'Orientation':
                    orientation_key = tag
                    break
            orientation = exif.get(orientation_key, 1)
            if orientation == 3:
                image = image.rotate(180, expand=True)
            elif orientation == 6:
                image = image.rotate(270, expand=True)
            elif orientation == 8:
                image = image.rotate(90, expand=True)
    except Exception:
        pass
    return image


def clean_formula(text):
    """Clean up LaTeX formulas in the output."""
    formula_pattern = r'\\\[(.*?)\\\]'

    def process_formula(match):
        formula = match.group(1)
        formula = re.sub(r'\\quad\s*\([^)]*\)', '', formula)
        formula = formula.strip()
        return r'\[' + formula + r'\]'

    return re.sub(formula_pattern, process_formula, text)


def re_match(text):
    """Find and extract reference matches."""
    pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
    matches = re.findall(pattern, text, re.DOTALL)
    matches_other = [m[0] for m in matches]
    return matches, matches_other
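# How the two helpers above are used in main(): clean_formula() strips "\quad (...)" fragments
# (typically equation labels) inside \[...\] display math, and re_match() collects the
# <|ref|>...<|/ref|><|det|>...<|/det|> grounding spans produced by the grounding prompt so
# they can be removed from the final markdown.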


def process_single_image(args):
    """Process a single image for vLLM batch input."""
    image, prompt, crop_mode = args
    cache_item = {
        "prompt": prompt,
        "multi_modal_data": {
            "image": DeepseekOCR2Processor().tokenize_with_images(
                images=[image], bos=True, eos=True, cropping=crop_mode
            )
        },
    }
    return cache_item
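# main() flow: parse CLI args, configure the DeepSeek processor via its config module, start a
# single-GPU vLLM engine, then for each repeat batch the still-missing pages, preprocess them in
# a thread pool, run greedy generation, clean the outputs, and write one markdown file per page.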


def main():
    parser = argparse.ArgumentParser(description="DeepSeek-OCR-2 benchmark for olmocr-bench")
    parser.add_argument("--model-path", default="deepseek-ai/DeepSeek-OCR-2", help="Model path")
    parser.add_argument("--olmo-root", required=True, help="Path to olmo images")
    parser.add_argument("--results-dir", required=True, help="Results directory")
    parser.add_argument("--model-name", default="deepseek-ocr-2", help="Model name for results folder")
    parser.add_argument("--olmo-repeats", type=int, default=3, help="Number of repeats per image")
    parser.add_argument("--max-concurrency", type=int, default=64, help="Max concurrent requests")
    parser.add_argument("--num-workers", type=int, default=8, help="Workers for image preprocessing")
    parser.add_argument("--crop-mode", action="store_true", default=True, help="Use crop mode")
    parser.add_argument("--gpu-memory-utilization", type=float, default=0.7, help="GPU memory utilization")
    parser.add_argument("--batch-size", type=int, default=100, help="Batch size for processing (to avoid OOM)")
    args = parser.parse_args()

    # Update config for DeepSeek processor
    import config
    config.CROP_MODE = args.crop_mode
    config.PROMPT = "<image>\n<|grounding|>Convert the document to markdown."
    # config.PROMPT = "<image>\nFree OCR."
    prompt = config.PROMPT

    print(f"Model: {args.model_path}")
    print(f"Prompt: {prompt}")
    print(f"OLMO root: {args.olmo_root}")
    print(f"Results dir: {args.results_dir}")
    print(f"Repeats: {args.olmo_repeats}")

    # Initialize vLLM
    print("Initializing vLLM...")
    llm = LLM(
        model=args.model_path,
        hf_overrides={"architectures": ["DeepseekOCR2ForCausalLM"]},
        block_size=256,
        enforce_eager=False,
        trust_remote_code=True,
        max_model_len=8192,
        swap_space=0,
        max_num_seqs=args.max_concurrency,
        tensor_parallel_size=1,
        gpu_memory_utilization=args.gpu_memory_utilization,
    )
    logits_processors = [
        NoRepeatNGramLogitsProcessor(
            ngram_size=40, window_size=90,
            whitelist_token_ids={128821, 128822}
        )
    ]
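    # Greedy decoding; special tokens are kept so the <|ref|>/<|det|> grounding tags survive for
    # the cleanup step below. The whitelisted token IDs above are assumed to be DeepSeek special
    # tokens that the no-repeat n-gram guard must not suppress.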
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=8192,
        logits_processors=logits_processors,
        skip_special_tokens=False,
    )

    # Gather images
    olmo_root = Path(args.olmo_root)
    image_files = sorted(olmo_root.rglob("*.png"))
    print(f"Found {len(image_files)} images")

    # Results directory
    results_base = Path(args.results_dir) / "olmo" / args.model_name

    # Process each repeat
    for repeat_idx in range(1, args.olmo_repeats + 1):
        print(f"\n=== Repeat {repeat_idx}/{args.olmo_repeats} ===")

        # Check which images need processing
        to_process = []
        for img_path in image_files:
            rel_path = img_path.relative_to(olmo_root)
            out_path = results_base / rel_path.parent / f"{img_path.stem}_pg1_repeat{repeat_idx}.md"
            if not out_path.exists():
                to_process.append((img_path, out_path))

        if not to_process:
            print(f"All images already processed for repeat {repeat_idx}")
            continue

        print(f"Processing {len(to_process)} images...")

        # Process in batches to avoid OOM
        batch_size = args.batch_size
        num_batches = (len(to_process) + batch_size - 1) // batch_size

        for batch_idx in range(num_batches):
            start_idx = batch_idx * batch_size
            end_idx = min((batch_idx + 1) * batch_size, len(to_process))
            batch_items = to_process[start_idx:end_idx]
            print(f"\nBatch {batch_idx + 1}/{num_batches} ({len(batch_items)} images)")

            # Load images for this batch
            images = []
            out_paths = []
            for img_path, out_path in batch_items:
                image = Image.open(img_path)
                image = correct_image_orientation(image)
                images.append(image.convert('RGB'))
                out_paths.append(out_path)

            # Preprocess for vLLM
            with ThreadPoolExecutor(max_workers=args.num_workers) as executor:
                batch_inputs = list(tqdm(
                    executor.map(
                        process_single_image,
                        [(img, prompt, args.crop_mode) for img in images]
                    ),
                    total=len(images),
                    desc="Preprocessing"
                ))

            outputs_list = llm.generate(batch_inputs, sampling_params=sampling_params)

            # Save results and free memory
            for output, out_path in zip(outputs_list, out_paths):
                content = output.outputs[0].text

                # Clean up output
                content = clean_formula(content)
                matches_ref, matches_other = re_match(content)
                for match in matches_other:
                    content = content.replace(match, '').replace('\n\n\n\n', '\n\n').replace('\n\n\n', '\n\n')

                out_path.parent.mkdir(parents=True, exist_ok=True)
                out_path.write_text(content.strip(), encoding='utf-8')

            del images, batch_inputs, outputs_list
            import gc
            gc.collect()

    print("\nDone!")


if __name__ == "__main__":
    main()