Created
March 12, 2026 14:12
-
-
Save NohTow/0f99fddf595e4ea80c3f3a351fb25a23 to your computer and use it in GitHub Desktop.
boilerplate_rlhn.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Script to create and cache the rlhn-680K dataset | |
| from __future__ import annotations | |
| import os | |
| from datasets import Dataset, load_dataset | |
| def _passage_text(passage) -> str: | |
| """Combine title and text for a passage. Handles both dict and Arrow struct formats.""" | |
| if isinstance(passage, dict): | |
| title = passage.get("title", "") | |
| text = passage.get("text", "") | |
| else: | |
| title = getattr(passage, "title", "") or "" | |
| text = getattr(passage, "text", "") or "" | |
| if title: | |
| return title + " " + text | |
| return text | |
def _convert_negatives_to_text(example, max_negatives):
    """Flatten ``negative_passages`` into ``negative_0`` ... string columns.

    Also extracts the first positive passage into a ``positive_text`` column.
    Mutates and returns *example* (the contract expected by ``Dataset.map``).
    """
    negatives = example["negative_passages"]
    kept = negatives[: min(len(negatives), max_negatives)]
    for idx, passage in enumerate(kept):
        example[f"negative_{idx}"] = _passage_text(passage)
    example["positive_text"] = _passage_text(example["positive_passages"][0])
    return example
def _convert_all_to_text(batch, max_negatives):
    """Convert positives and negatives to text, emitting one row per positive.

    Called with batched=True, batch_size=1, so each value is a list of length 1.
    Negatives are truncated to *max_negatives* and padded with "" when fewer
    are available, so every output row has the same columns.
    """
    neg_keys = [f"negative_{i}" for i in range(max_negatives)]
    out = {"query": [], "positive": []}
    for key in neg_keys:
        out[key] = []
    rows = zip(batch["query"], batch["positive_passages"], batch["negative_passages"])
    for query, positives, negatives in rows:
        texts = [_passage_text(neg) for neg in negatives[:max_negatives]]
        texts.extend([""] * (max_negatives - len(texts)))
        for positive in positives:
            out["query"].append(query)
            out["positive"].append(_passage_text(positive))
            for key, text in zip(neg_keys, texts):
                out[key].append(text)
    return out
def load_train_dataset(single_positive: bool = False, cache_dir: str | None = None):
    """Load (or build and cache) the processed rlhn/rlhn-680K training set.

    Args:
        single_positive: if True, keep only the first positive per query
            (one row per query); otherwise emit one row per positive passage.
        cache_dir: directory for the processed-dataset cache. Defaults to the
            original hard-coded path (kept for backward compatibility) so
            existing callers and caches keep working.

    Returns:
        A ``datasets.Dataset`` with ``query``, ``positive`` and
        ``negative_0 .. negative_{k-1}`` string columns.
    """
    if cache_dir is None:
        suffix = "_single" if single_positive else ""
        cache_dir = f"/home/antoine_chaffin/rlhn_680k_data{suffix}"
    os.makedirs(cache_dir, exist_ok=True)
    # Fast path: reuse a previously processed dataset if one exists on disk.
    try:
        dataset = Dataset.load_from_disk(cache_dir)
        print("Loaded cached rlhn-680K dataset from disk.")
        return dataset
    except FileNotFoundError:
        pass
    print("Loading rlhn/rlhn-680K dataset...")
    raw = load_dataset("rlhn/rlhn-680K", split="train", num_proc=45)
    # Every row must end up with the same negative_* columns, so use the
    # minimum negative count across the dataset (capped at 50).
    min_negatives = min(len(neg) for neg in raw["negative_passages"])
    max_negatives = min(min_negatives, 50)
    print(f"Min negatives: {min_negatives}, using {max_negatives}")
    if single_positive:
        print("Converting negatives to text and keeping first positive...")
        raw = raw.map(
            lambda x: _convert_negatives_to_text(x, max_negatives),
            remove_columns=["query_id", "positive_passages", "negative_passages", "subset"],
            num_proc=11,
            desc="Converting to text",
        )
        raw = raw.rename_column("positive_text", "positive")
    else:
        print("Converting all passages to text (one row per positive)...")
        # batch_size=1 lets the mapper return more rows than it received
        # (one output row per positive passage of the single input row).
        raw = raw.map(
            lambda x: _convert_all_to_text(x, max_negatives),
            remove_columns=["query_id", "positive_passages", "negative_passages", "subset"],
            batched=True,
            batch_size=1,
            num_proc=11,
            desc="Converting to text",
        )
    raw.save_to_disk(cache_dir)
    print(f"Saved processed dataset to {cache_dir}")
    return raw
if __name__ == "__main__":
    # Build (or load from cache) the multi-positive variant and show a sample.
    train_set = load_train_dataset(single_positive=False)
    print(train_set)
    print(train_set[0])
    print("Done!")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment