End-to-end pipeline for hard-negative mining, Sentence-Transformers training, and evaluation
| """ | |
| Improved end-to-end pipeline for hard-negative mining, Sentence-Transformers training, | |
| and evaluation (including chain-recall) for multi-hop retrieval tasks. | |
| High-level features implemented: | |
| - Three stages implemented: (1) hard negative mining, (2) training, (3) evaluation. | |
| - Search API definition (async-friendly) that your baseline retrieval system must | |
| implement to provide prioritized baseline hard-negatives. | |
| - Baseline hard negatives are given highest priority when merging candidates. | |
| - BM25 margin-based mining (Lexical mining) is performed and merged with baseline | |
| candidates. The relative margin filtering follows the Hugging Face guidance | |
| (mitigating false-negatives by keeping high-scoring BM25 candidates close to golds). | |
| - Disk-caching of mining output keyed by corpus+queries+gold_map+mining-params | |
| (idempotent: identical inputs/params will reuse previous results). | |
| - Training supports both MultipleNegativesRankingLoss and GISTEmbedLoss (guide model). | |
| - Option to avoid duplicate examples from same original query inside a batch via a | |
| PyTorch Sampler (implemented as NoSameQuerySampler using torch.utils.data.Sampler). | |
| - Evaluation: standard recall@k and CHAIN recall@k (the fraction of queries where *all* | |
| gold docs for that query appear in top-k) and a small progressive-hop simulator | |
| helper (optional extension stub). | |
| - Improved, idiomatic usage of SentenceTransformers encoding APIs and careful batching. | |
| Design notes, rationale and best-practices (brief): | |
| - With tiny supervised sets (e.g. ~50 queries) hard negatives are THE most important | |
| signal: prioritize baseline semantic candidates re-scored by your best available | |
| reranker, then add lexical candidates from BM25 inside a margin threshold. | |
| - In-batch negatives are very effective. If you know you may have false negatives | |
| inside a batch (other positives from same original query), either use a | |
| guide model + GISTEmbedLoss (preferred) or the batch-sampler that minimizes | |
| same-query collisions. GISTEmbedLoss requires a guide model available during | |
| training; it masks false negatives dynamically. | |
| - Cache mining results: mining can be expensive; store a cache keyed by a hash | |
| of (corpus ids, query ids, gold pairs, mining params). This guarantees idempotence | |
| and reproducibility across identical runs. | |
| - Monitor recall@20 and CHAIN recall@20 as primary metrics for multi-hop retrieval. | |
| - Use cross-encoder rescoring if available to pick the hardest negatives out of | |
| the candidate set before training. This script provides an adapter point to do so. | |
| The script is inspired by the following Hugging Face blog post: | |
| https://huggingface.co/blog/dragonkue/mitigating-false-negatives-in-retriever-training | |
| """ | |
from __future__ import annotations

import argparse
import asyncio
import json
import logging
import os
import pathlib
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import numpy as np
import onnx
import onnxruntime as ort
import torch
from datasets import Dataset
from huggingface_hub import snapshot_download
from onnxruntime.transformers.float16 import convert_float_to_float16
from sentence_transformers import SentenceTransformer, losses
from sentence_transformers.trainer import SentenceTransformerTrainer
from sentence_transformers.training_args import (
    BatchSamplers,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.util import mine_hard_negatives as st_mine_hard_negatives
from torch import nn
from torch.export import Dim
from torch.nn.functional import cosine_similarity
from tqdm.auto import tqdm
from transformers import AutoTokenizer

# Logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
EMBEDDING_DIM = 256


## Utilities (in separate file normally) ##
class SentenceEmbeddingWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        # Build the features dict expected by Sentence-Transformers modules
        features = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        if token_type_ids is not None:
            features["token_type_ids"] = token_type_ids
        # Pass through all modules in the sentence transformer
        # This includes: Transformer -> Pooling -> Normalization
        for idx in range(len(self.model._modules)):
            module = self.model._modules[str(idx)]
            features = module(features)
        # Return the sentence embedding
        return features["sentence_embedding"]
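

# --- NoSameQuerySampler --------------------------------------------------------
# Hedged sketch (not present in the original gist code, only in its docstring):
# a sampler that avoids placing two examples from the same original query in one
# batch, so their positives don't act as false in-batch negatives. One possible
# implementation on top of torch.utils.data.Sampler; the training code below
# currently relies on BatchSamplers.NO_DUPLICATES instead.
import random
from collections import defaultdict


class NoSameQuerySampler(torch.utils.data.Sampler):
    """Yield example indices so that, as far as possible, no two examples
    from the same original query fall into the same batch."""

    def __init__(self, query_ids: Sequence[str], batch_size: int, seed: int = 42):
        self.query_ids = list(query_ids)
        self.batch_size = batch_size
        self.seed = seed

    def __iter__(self):
        rng = random.Random(self.seed)
        remaining: Dict[str, List[int]] = defaultdict(list)
        for idx, qid in enumerate(self.query_ids):
            remaining[qid].append(idx)
        for idxs in remaining.values():
            rng.shuffle(idxs)
        order: List[int] = []
        while remaining:
            # At most one example per query goes into the current batch slice
            batch_qids = rng.sample(
                list(remaining), min(self.batch_size, len(remaining))
            )
            for qid in batch_qids:
                order.append(remaining[qid].pop())
                if not remaining[qid]:
                    del remaining[qid]
        return iter(order)

    def __len__(self) -> int:
        return len(self.query_ids)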
def export_to_onnx(
    output_dir: str,
    device: str,
) -> str:
    onnx_dir = os.path.join(output_dir, "onnx")
    pathlib.Path(onnx_dir).mkdir(parents=True, exist_ok=True)
    logger.info(f"Exporting model to ONNX FP16 format: {onnx_dir}")
    target_device = torch.device(device)
    dummy_batch_size = 4  # Use > 1 to make dynamic shapes clearer
    dummy_seq_length = 128
    dummy_input_ids = torch.randint(
        0, 30522, (dummy_batch_size, dummy_seq_length)
    ).to(target_device)
    dummy_attention_mask = torch.ones(
        (dummy_batch_size, dummy_seq_length), dtype=torch.long
    ).to(target_device)
    st = SentenceTransformer(output_dir, device=str(target_device))
    st.eval()
    wrapper = SentenceEmbeddingWrapper(st).to(target_device)
    # Dynamic batch/sequence dimensions (torch.export-style `dynamic_shapes`,
    # which replaces the legacy `dynamic_axes` argument)
    batch = Dim("batch", min=1, max=1024)
    seq = Dim("seq", min=1, max=512)
    torch.onnx.export(
        wrapper,
        (dummy_input_ids, dummy_attention_mask),
        os.path.join(onnx_dir, "model.onnx"),
        input_names=["input_ids", "attention_mask"],
        output_names=["sentence_embedding"],
        dynamic_shapes={
            "input_ids": {0: batch, 1: seq},
            "attention_mask": {0: batch, 1: seq},
        },
        export_params=True,
        opset_version=21,
        do_constant_folding=False,
        external_data=False,
    )
    model_fp32 = onnx.load(os.path.join(onnx_dir, "model.onnx"))
    model_fp16 = convert_float_to_float16(
        model_fp32,
        keep_io_types=True,
    )
    onnx.save_model(
        model_fp16,
        os.path.join(onnx_dir, "model_fp16.onnx"),
        save_as_external_data=False,
        all_tensors_to_one_file=True,
        location="model_fp16.onnx_data",
        convert_attribute=False,
    )
    logger.info(f"ONNX model saved to {onnx_dir}")
    return onnx_dir
def sanity_check_onnx_export(
    hf_download_path: str,
    local_model_path: str,
    device: str,  # currently unused; sessions are pinned to CUDAExecutionProvider
) -> Dict[str, Any]:
    logger.info("Comparing reference ONNX export vs local ONNX export")
    logger.info(f"  HF: {hf_download_path}")
    logger.info(f"  ONNX: {local_model_path}")
    # Generate test inputs
    test_texts = [
        "This is a test sentence for comparison.",
        "Another sample text to validate model equivalence.",
        "Machine learning and deep neural networks.",
        "ONNX is the open neural network exchange format.",
        "Comparing model outputs for quality assurance.",
        "Embedding models encode text into vectors.",
        "Transfer learning with pre-trained transformers.",
        "Sentence embeddings capture semantic meaning.",
        "Model evaluation is crucial for deployment.",
        "Cross-framework compatibility testing.",
    ]
    # Tokenize
    tokenizer = AutoTokenizer.from_pretrained(hf_download_path)
    inputs = tokenizer(
        test_texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    hf_model = ort.InferenceSession(
        hf_download_path + "/onnx/model.onnx", providers=["CUDAExecutionProvider"]
    )
    hf_model_16 = ort.InferenceSession(
        hf_download_path + "/onnx/model_fp16.onnx", providers=["CUDAExecutionProvider"]
    )
    local_model = ort.InferenceSession(
        local_model_path + "/onnx/model.onnx", providers=["CUDAExecutionProvider"]
    )
    local_model_16 = ort.InferenceSession(
        local_model_path + "/onnx/model_fp16.onnx", providers=["CUDAExecutionProvider"]
    )
    # The reference HF export expects token_type_ids; the local export does not
    hf_emb = hf_model.run(
        ["sentence_embedding"],
        {
            "input_ids": inputs["input_ids"].numpy(),
            "attention_mask": inputs["attention_mask"].numpy(),
            "token_type_ids": np.zeros_like(inputs["input_ids"].numpy()),
        },
    )
    local_emb = local_model.run(
        ["sentence_embedding"],
        {
            "input_ids": inputs["input_ids"].numpy(),
            "attention_mask": inputs["attention_mask"].numpy(),
        },
    )
    hf_emb_16 = hf_model_16.run(
        ["sentence_embedding"],
        {
            "input_ids": inputs["input_ids"].numpy(),
            "attention_mask": inputs["attention_mask"].numpy(),
            "token_type_ids": np.zeros_like(inputs["input_ids"].numpy()),
        },
    )
    local_emb_16 = local_model_16.run(
        ["sentence_embedding"],
        {
            "input_ids": inputs["input_ids"].numpy(),
            "attention_mask": inputs["attention_mask"].numpy(),
        },
    )
    # `run` returns a list with one output: the (batch, dim) embedding matrix
    hf_emb_arr = np.asarray(hf_emb[0], dtype=np.float32)
    local_emb_arr = np.asarray(local_emb[0], dtype=np.float32)
    hf_emb_16_arr = np.asarray(hf_emb_16[0], dtype=np.float32)
    local_emb_16_arr = np.asarray(local_emb_16[0], dtype=np.float32)
    # Compare outputs
    diffs = np.abs(hf_emb_arr - local_emb_arr)
    diffs_16 = np.abs(hf_emb_16_arr - local_emb_16_arr)
    max_diff = np.max(diffs)
    mean_diff = np.mean(diffs)
    std_diff = np.std(diffs)
    median_diff = np.median(diffs)
    max_diff_16 = np.max(diffs_16)
    mean_diff_16 = np.mean(diffs_16)
    std_diff_16 = np.std(diffs_16)
    median_diff_16 = np.median(diffs_16)
    # Row-wise cosine similarity between matching sentence embeddings
    cosine_similarity_scores = cosine_similarity(
        torch.tensor(hf_emb_arr),
        torch.tensor(local_emb_arr),
    ).numpy()
    mean_cosine_similarity = np.mean(cosine_similarity_scores)
    cosine_similarity_scores_16 = cosine_similarity(
        torch.tensor(hf_emb_16_arr),
        torch.tensor(local_emb_16_arr),
    ).numpy()
    mean_cosine_similarity_16 = np.mean(cosine_similarity_scores_16)
    return {
        "max_diff": float(max_diff),
        "mean_diff": float(mean_diff),
        "std_diff": float(std_diff),
        "median_diff": float(median_diff),
        "mean_cosine_similarity": float(mean_cosine_similarity),
        "max_diff_16": float(max_diff_16),
        "mean_diff_16": float(mean_diff_16),
        "std_diff_16": float(std_diff_16),
        "median_diff_16": float(median_diff_16),
        "mean_cosine_similarity_16": float(mean_cosine_similarity_16),
    }
def push_to_huggingface(
    model_path: str,
    repo_id: str,
    onnx_path: Optional[str] = None,
    commit_message: str = "Upload trained model",
) -> None:
    """
    Push a trained Sentence-Transformers model and its ONNX variant to the
    Hugging Face Hub.

    Args:
        model_path: Local path to the saved PyTorch model
        repo_id: Hugging Face repo ID (format: username/repo-name)
        onnx_path: Optional path to the ONNX model directory
        commit_message: Commit message for the upload
    """
    try:
        from huggingface_hub import HfApi
    except ImportError:
        logger.error(
            "Hugging Face upload requires 'huggingface-hub'. "
            "Install with: pip install huggingface-hub"
        )
        return
    try:
        logger.info(f"Pushing PyTorch model to {repo_id}")
        model = SentenceTransformer(model_path)
        model.push_to_hub(
            repo_id,
            commit_message=commit_message,
            exist_ok=True,  # Allow pushing into an existing repo
            replace_model_card=True,  # Replace the existing model card
        )
        logger.info("PyTorch model pushed successfully")
        # Upload ONNX model if provided
        if onnx_path and os.path.exists(onnx_path):
            logger.info(f"Pushing ONNX FP16 model to {repo_id}")
            api = HfApi()
            # Upload ONNX files
            for file_name in os.listdir(onnx_path):
                file_path = os.path.join(onnx_path, file_name)
                if os.path.isfile(file_path) and file_name.endswith(
                    (".onnx", ".onnx_data")
                ):
                    api.upload_file(
                        path_or_fileobj=file_path,
                        path_in_repo=f"onnx/{file_name}",
                        repo_id=repo_id,
                        commit_message="Add ONNX FP16 files",
                    )
            logger.info("ONNX model files uploaded successfully")
        logger.info(f"Model available at: https://huggingface.co/{repo_id}")
    except Exception as e:
        logger.error(f"Error pushing to Hugging Face: {e}")
## - ##
def load_jsonl(path: str) -> List[Dict[str, Any]]:
    items = []
    with open(path, "r", encoding="utf-8") as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue
            items.append(json.loads(line))
    return items


def save_jsonl(items: Iterable[Dict[str, Any]], path: str) -> None:
    with open(path, "w", encoding="utf-8") as fh:
        for it in items:
            fh.write(json.dumps(it, ensure_ascii=False) + "\n")
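

# --- Mining cache --------------------------------------------------------------
# Hedged sketch of the disk cache the module docstring describes: mining output
# keyed by a hash of (corpus ids, query ids, gold pairs, mining params) so that
# identical runs are idempotent. The helper names below are assumptions; the
# mining functions further down do not call them yet.
import hashlib


def _mining_cache_key(
    corpus: List[Dict[str, Any]],
    queries: List[Dict[str, Any]],
    gold_pairs: List[Dict[str, Any]],
    mining_params: Dict[str, Any],
) -> str:
    payload = json.dumps(
        {
            "corpus_ids": sorted(c["id"] for c in corpus),
            "query_ids": sorted(q["id"] for q in queries),
            "gold_pairs": sorted((g["query_id"], g["doc_id"]) for g in gold_pairs),
            "params": mining_params,  # must be JSON-serializable
        },
        sort_keys=True,
    )
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


def load_mining_cache(cache_dir: str, key: str) -> Optional[List[Dict[str, Any]]]:
    # Returns the cached mining records, or None on a cache miss
    path = os.path.join(cache_dir, f"mined_{key}.jsonl")
    return load_jsonl(path) if os.path.exists(path) else None


def save_mining_cache(cache_dir: str, key: str, records: List[Dict[str, Any]]) -> None:
    pathlib.Path(cache_dir).mkdir(parents=True, exist_ok=True)
    save_jsonl(records, os.path.join(cache_dir, f"mined_{key}.jsonl"))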


def _prepare_corpus_maps(
    corpus: List[Dict[str, Any]],
) -> Tuple[List[str], List[str], Dict[str, int]]:
    corpus_texts = [c["text"] for c in corpus]
    corpus_ids = [c["id"] for c in corpus]
    id2idx = {cid: i for i, cid in enumerate(corpus_ids)}
    return corpus_texts, corpus_ids, id2idx

def _load_baseline_results() -> Dict[str, List[str]]:
    path = "./training_data/baseline_search_results.json"
    with open(path, "r", encoding="utf-8") as fh:
        json_data = json.load(fh)
    return {
        item["queryId"]: item["searchDocIds"] for item in json_data["searchResults"]
    }
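

# --- Baseline search API ---------------------------------------------------------
# Hedged sketch of the async-friendly search API the module docstring says your
# baseline retrieval system must implement. The code above sidesteps it by loading
# precomputed results from baseline_search_results.json; this Protocol is the
# adapter point if live retrieval is wired in instead.
from typing import Protocol


class BaselineSearchAPI(Protocol):
    async def search(self, query_id: str, query_text: str, top_k: int) -> List[str]:
        """Return prioritized candidate doc ids for one query."""
        ...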


def mine_hard_negatives_st(
    queries: List[Dict[str, Any]],
    corpus: List[Dict[str, Any]],
    gold_pairs: List[Dict[str, Any]],
    model: SentenceTransformer,
    max_negs_per_example: int = 8,
) -> Dataset:
    corpus_dict = {c["id"]: c["text"] for c in corpus}
    query_dict = {q["id"]: q["text"] for q in queries}
    dataset_dict = {
        "anchor": [query_dict[r["query_id"]] for r in gold_pairs],
        "positive": [corpus_dict[r["doc_id"]] for r in gold_pairs],
    }
    return st_mine_hard_negatives(
        dataset=Dataset.from_dict(dataset_dict),
        model=model,
        corpus=[c["text"] for c in corpus],
        query_prompt_name="query",
        # relative_margin=0.01,  # 0.05 means a negative may be at most 95% as
        # similar to the anchor as the positive
        num_negatives=max_negs_per_example,  # 10 or fewer is recommended
        sampling_strategy="top",  # sample the top-scoring candidates as negatives
    )
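

# --- BM25 margin-based (lexical) mining -------------------------------------------
# Hedged sketch of the BM25 mining the module docstring describes, assuming the
# third-party `rank-bm25` package (pip install rank-bm25) and simple whitespace
# tokenization. Following the relative-margin idea from the referenced Hugging Face
# post, candidates that score above (1 - margin) * best gold score are discarded as
# likely false negatives. Not yet merged into mine_hard_negatives_custom below.
def mine_bm25_negatives(
    query_text: str,
    golds: set,
    corpus: List[Dict[str, Any]],
    relative_margin: float = 0.05,
    top_k: int = 50,
) -> List[str]:
    from rank_bm25 import BM25Okapi  # local import: optional dependency

    bm25 = BM25Okapi([c["text"].lower().split() for c in corpus])
    scores = bm25.get_scores(query_text.lower().split())
    gold_scores = [scores[i] for i, c in enumerate(corpus) if c["id"] in golds]
    # Keep only negatives at most (1 - margin) as lexically close as the best gold
    cutoff = (1.0 - relative_margin) * max(gold_scores) if gold_scores else float("inf")
    ranked = np.argsort(scores)[::-1][:top_k]
    return [
        corpus[i]["id"]
        for i in ranked
        if corpus[i]["id"] not in golds and scores[i] <= cutoff
    ]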


def mine_hard_negatives_custom(
    queries: List[Dict[str, Any]],
    corpus: List[Dict[str, Any]],
    gold_pairs: List[Dict[str, Any]],
    baseline_search_results: Dict[str, List[str]],
    max_negs_per_example: int = 8,
    multiple_negs_per_record: bool = True,
) -> Dataset:
    corpus_texts, corpus_ids, id2idx = _prepare_corpus_maps(corpus)
    results: List[Dict[str, Any]] = []
    for q in tqdm(queries, desc="Mining queries"):
        qid = q["id"]
        qtext = q["text"]
        golds = set(r["doc_id"] for r in gold_pairs if r["query_id"] == qid)
        if not golds:
            continue  # skip queries with no golds
        # 1) baseline candidates (semantic), with golds removed
        baseline_candidates = set(
            cid for cid in baseline_search_results.get(qid, []) if cid not in golds
        )
        assert all(gid in id2idx for gid in golds), "Gold doc ID not in corpus"
        assert all(cid in id2idx for cid in baseline_candidates), (
            "Baseline candidate ID not in corpus"
        )
        assert len(baseline_candidates) >= max_negs_per_example, (
            "Not enough negatives mined"
        )
        # 2) explode into one record per gold doc
        for gold_id in golds:
            rec = {
                "query_text": qtext,
                "gold_text": corpus[id2idx[gold_id]]["text"],
                # Note: sets have arbitrary order, so the baseline rank order
                # is lost at this point
                "negatives": list(baseline_candidates)[:max_negs_per_example],
            }
            results.append(rec)
    if multiple_negs_per_record:
        # One row per (anchor, positive) with negative_1..negative_N columns,
        # padded with empty strings where a record has fewer negatives
        dataset_dict = {
            "anchor": [r["query_text"] for r in results],
            "positive": [r["gold_text"] for r in results],
        }
        max_negs = max(len(r.get("negatives", [])) for r in results) if results else 0
        for i in range(max_negs):
            dataset_dict[f"negative_{i + 1}"] = [
                corpus[id2idx[r.get("negatives", [])[i]]]["text"]
                if i < len(r.get("negatives", []))
                else ""
                for r in results
            ]
    else:
        # One (anchor, positive, negative) triplet row per negative
        dataset_dict = {
            "anchor": [
                r["query_text"]
                for r in results
                for n in r["negatives"][:max_negs_per_example]
            ],
            "positive": [
                r["gold_text"]
                for r in results
                for n in r["negatives"][:max_negs_per_example]
            ],
            "negative": [
                corpus[id2idx[n]]["text"]
                for r in results
                for n in r["negatives"][:max_negs_per_example]
            ],
        }
    return Dataset.from_dict(dataset_dict)
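

# --- Cross-encoder rescoring adapter -----------------------------------------------
# Hedged sketch of the adapter point mentioned in the module docstring: rescore the
# mined candidates with a cross-encoder and keep only the hardest ones. The model id
# below is just an example reranker, not one prescribed by the original gist; this
# helper is not called by the pipeline as written.
def rescore_negatives_with_cross_encoder(
    query_text: str,
    candidate_ids: List[str],
    corpus: List[Dict[str, Any]],
    id2idx: Dict[str, int],
    top_n: int = 8,
    model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
) -> List[str]:
    from sentence_transformers import CrossEncoder  # local import: heavy model

    ce = CrossEncoder(model_name)
    pairs = [(query_text, corpus[id2idx[cid]]["text"]) for cid in candidate_ids]
    scores = ce.predict(pairs)
    ranked = np.argsort(scores)[::-1][:top_n]  # highest score = hardest negative
    return [candidate_ids[i] for i in ranked]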


async def train(
    queries: List[Dict[str, Any]],
    corpus: List[Dict[str, Any]],
    gold_pairs: List[Dict[str, Any]],
    baseline_results: Dict[str, List[str]],
    args: argparse.Namespace,
) -> None:
    model = SentenceTransformer(args.model_id, device=args.device)
    dataset = mine_hard_negatives_custom(
        queries=queries,
        corpus=corpus,
        gold_pairs=gold_pairs,
        baseline_search_results=baseline_results,
        multiple_negs_per_record=False,
    )
    # Alternative: model-based mining via sentence_transformers.util
    # dataset = mine_hard_negatives_st(
    #     queries=queries, corpus=corpus, gold_pairs=gold_pairs, model=model
    # )
    # Write mined data to disk for inspection:
    # json.dump(dataset.to_dict(), open("mined_custom.json", "w", encoding="utf-8"), indent=2)
    training_args = SentenceTransformerTrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        learning_rate=args.lr,
        # warmup_ratio=0.1,
        fp16=(args.device != "cpu"),
        # Avoid duplicate texts inside a batch (a cheap guard against in-batch
        # false negatives; see also NoSameQuerySampler above)
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        logging_steps=50,
        save_strategy="no",
    )
    # Alternative: truncate-and-train with MatryoshkaLoss
    # loss = MatryoshkaLoss(
    #     model,
    #     losses.MultipleNegativesRankingLoss(model),
    #     matryoshka_dims=[EMBEDDING_DIM],
    # )
    trainer = SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        loss=losses.MultipleNegativesRankingLoss(model),
    )
    trainer.train()
    model.save(args.output_dir)
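

# --- GISTEmbedLoss alternative ------------------------------------------------------
# Hedged sketch of the guide-model option the module docstring calls "preferred":
# GISTEmbedLoss uses a separate guide model to mask likely false negatives (including
# other in-batch positives) dynamically during training. The guide model id below is
# an example assumption, not a recommendation from the original gist.
def build_gist_loss(
    model: SentenceTransformer,
    guide_model_id: str = "sentence-transformers/all-MiniLM-L6-v2",
    device: str = "cuda",
) -> losses.GISTEmbedLoss:
    guide = SentenceTransformer(guide_model_id, device=device)
    return losses.GISTEmbedLoss(model=model, guide=guide)


# Usage: pass `loss=build_gist_loss(model, device=args.device)` to the trainer
# instead of MultipleNegativesRankingLoss.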


def evaluate(
    queries: List[Dict[str, Any]],
    corpus: List[Dict[str, Any]],
    gold_pairs: List[Dict[str, Any]],
    model_id: str,
    args: argparse.Namespace,
    top_k_list: Sequence[int] = (5, 10, 20),
) -> Dict[str, Any]:
    """
    Compute standard recall@k and CHAIN recall@k, where CHAIN recall@k measures
    the fraction of queries for which *all* gold docs are present in the top-k
    retrieved documents (for multi-hop evaluation).
    """
    model = SentenceTransformer(model_id, device=args.device)
    corpus_texts, corpus_ids, id2idx = _prepare_corpus_maps(corpus)
    logger.info("Encoding corpus for evaluation (model=%s)", model.__class__.__name__)
    corpus_emb = model.encode(
        corpus_texts,
        convert_to_numpy=True,
        show_progress_bar=True,
        device=args.device,
        truncate_dim=EMBEDDING_DIM,
    )
    # Normalize the corpus embeddings once, outside the query loop
    corpus_n = corpus_emb / (
        np.linalg.norm(corpus_emb, axis=1, keepdims=True) + 1e-12
    )
    recall_at_k = {k: 0 for k in top_k_list}
    chain_recall_at_k = {k: 0 for k in top_k_list}
    total = 0
    for q in tqdm(queries, desc="Evaluation queries"):
        qid = q["id"]
        qtext = q["text"]
        golds = set(gp["doc_id"] for gp in gold_pairs if gp["query_id"] == qid)
        if not golds:
            continue  # skip queries with no golds
        total += 1
        q_emb = model.encode(
            [qtext],
            convert_to_numpy=True,
            device=args.device,
            truncate_dim=EMBEDDING_DIM,
        )
        # Brute-force cosine similarity against the full corpus
        q_emb_n = q_emb / (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-12)
        sims = (corpus_n @ q_emb_n.T).squeeze(-1)
        ranked_idx = np.argsort(sims)[::-1][: max(top_k_list)]
        retrieved = [corpus_ids[i] for i in ranked_idx]
        for k in top_k_list:
            topk = set(retrieved[:k])
            if golds & topk:
                recall_at_k[k] += 1
            # Chain recall: all golds for this query must be inside the top k
            if golds.issubset(topk):
                chain_recall_at_k[k] += 1
    if total == 0:
        # No query had gold docs; avoid division by zero
        return {"recall_at_k": {}, "chain_recall_at_k": {}, "total_queries": 0}
    recall_at_k = {k: recall_at_k[k] / total for k in recall_at_k}
    chain_recall_at_k = {k: chain_recall_at_k[k] / total for k in chain_recall_at_k}
    return {
        "recall_at_k": recall_at_k,
        "chain_recall_at_k": chain_recall_at_k,
        "total_queries": total,
    }
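

# --- Progressive-hop simulator (optional extension stub) -----------------------------
# Hedged sketch of the helper the module docstring mentions: simulate multi-hop
# retrieval by appending each hop's newly found gold doc to the query text and
# re-retrieving, counting how many hops are needed to cover all golds. This is one
# plausible reading of the stub, not a method defined by the original gist. The
# `retrieve_fn` callable (query text -> ranked doc ids) is an assumed interface.
def simulate_progressive_hops(
    query_text: str,
    golds: set,
    retrieve_fn,
    doc_text_by_id: Dict[str, str],
    max_hops: int = 3,
    top_k: int = 20,
) -> int:
    """Return the number of hops needed to retrieve every gold doc (max_hops + 1 on failure)."""
    found: set = set()
    context = query_text
    for hop in range(1, max_hops + 1):
        newly_found = [d for d in retrieve_fn(context)[:top_k] if d in golds - found]
        found.update(newly_found)
        if golds <= found:
            return hop
        if newly_found:
            # Expand the query with the first newly found gold's text
            context = context + " " + doc_text_by_id[newly_found[0]]
    return max_hops + 1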


async def main(argv: Optional[List[str]] = None) -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", default="MongoDB/mdbr-leaf-ir")
    parser.add_argument("--queries_path", required=True)
    parser.add_argument("--corpus_path", required=True)
    parser.add_argument("--gold_path", required=True)
    parser.add_argument("--output_dir", required=True)
    parser.add_argument(
        "--hf_upload_model_id", default="ashikns/mdbr-leaf-ir-finetuned"
    )
    parser.add_argument("--train", action="store_true")
    parser.add_argument("--evaluate", action="store_true")
    parser.add_argument("--export_onnx", action="store_true")
    parser.add_argument(
        "--compare_onnx", action=argparse.BooleanOptionalAction, default=True
    )
    parser.add_argument("--push_to_hub", action="store_true")
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--device", type=str, default="cuda")
    args = parser.parse_args(argv)
    pathlib.Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    hf_download_path = snapshot_download(
        repo_id=args.model_id,
        local_dir="./hf_model",
    )
    # 1) Load data
    corpus = load_jsonl(args.corpus_path)
    queries = load_jsonl(args.queries_path)
    gold_pairs = load_jsonl(args.gold_path)
    baseline_results = _load_baseline_results()
    # Optional: restrict the corpus to docs seen by the baseline or in the golds
    # baseline_seen = set(
    #     [r for results in baseline_results.values() for r in results]
    #     + [gp["doc_id"] for gp in gold_pairs]
    # )
    # corpus = [c for c in corpus if c["id"] in baseline_seen]
    # 2) Training
    if args.train:
        await train(queries, corpus, gold_pairs, baseline_results, args)
    # 3) Evaluation
    if args.evaluate:
        # Use the trained model from output_dir if training was performed,
        # otherwise use model_id
        model_path = args.output_dir if args.train else args.model_id
        metrics = evaluate(queries, corpus, gold_pairs, model_path, args)
        logger.info("Evaluation results: %s", json.dumps(metrics, indent=2))
        with open(
            os.path.join(args.output_dir, "evaluation.json"), "w", encoding="utf-8"
        ) as fh:
            json.dump(metrics, fh, indent=2)
    # 4) Export to ONNX if requested
    onnx_path = None
    if args.train and args.export_onnx:
        onnx_path = export_to_onnx(args.output_dir, args.device)
    # 4b) Compare the reference vs the local ONNX export if requested
    if args.compare_onnx and onnx_path:
        comparison_results = sanity_check_onnx_export(
            hf_download_path, args.output_dir, device=args.device
        )
        print("Reference vs local ONNX comparison results:")
        print(json.dumps(comparison_results, indent=2))
    # 5) Push to Hugging Face Hub if requested
    if args.push_to_hub:
        model_path = args.output_dir if args.train else args.model_id
        push_to_huggingface(model_path, args.hf_upload_model_id, onnx_path)


if __name__ == "__main__":
    asyncio.run(main())