Skip to content

Instantly share code, notes, and snippets.

@NohTow
Created March 12, 2026 14:12
Show Gist options
  • Select an option

  • Save NohTow/0f99fddf595e4ea80c3f3a351fb25a23 to your computer and use it in GitHub Desktop.

Select an option

Save NohTow/0f99fddf595e4ea80c3f3a351fb25a23 to your computer and use it in GitHub Desktop.
boilerplate_rlhn.py
# Script to create and cache the rlhn-680K dataset
from __future__ import annotations
import os
from datasets import Dataset, load_dataset
def _passage_text(passage) -> str:
"""Combine title and text for a passage. Handles both dict and Arrow struct formats."""
if isinstance(passage, dict):
title = passage.get("title", "")
text = passage.get("text", "")
else:
title = getattr(passage, "title", "") or ""
text = getattr(passage, "text", "") or ""
if title:
return title + " " + text
return text
def _convert_negatives_to_text(example, max_negatives):
    """Flatten one example's passages into plain-text columns.

    Writes ``negative_0``, ``negative_1``, ... (at most ``max_negatives`` of
    them) from ``example["negative_passages"]`` and stores the first entry
    of ``example["positive_passages"]`` under ``positive_text``. The example
    is updated in place and also returned.
    """
    hard_negatives = example["negative_passages"]
    # zip() stops at the shorter input, so no more than max_negatives
    # columns are written even when more negatives are available.
    for slot, passage in zip(range(max_negatives), hard_negatives):
        example[f"negative_{slot}"] = _passage_text(passage)
    first_positive = example["positive_passages"][0]
    example["positive_text"] = _passage_text(first_positive)
    return example
def _convert_all_to_text(batch, max_negatives):
"""Convert both positives and negatives to text, creating one row per positive.
Called with batched=True, batch_size=1, so each value is a list of length 1.
"""
result = {"query": [], "positive": []}
for i in range(max_negatives):
result[f"negative_{i}"] = []
for query, positives, negatives in zip(
batch["query"], batch["positive_passages"], batch["negative_passages"]
):
n_neg = min(len(negatives), max_negatives)
neg_texts = [_passage_text(negatives[i]) for i in range(n_neg)]
neg_texts += [""] * (max_negatives - n_neg)
for pos in positives:
result["query"].append(query)
result["positive"].append(_passage_text(pos))
for i in range(max_negatives):
result[f"negative_{i}"].append(neg_texts[i])
return result
def load_train_dataset(single_positive: bool = False, *, cache_dir: str | None = None):
    """Load the processed rlhn-680K training set, building and caching it if needed.

    The raw ``rlhn/rlhn-680K`` train split is converted so that passages
    become plain-text columns: ``positive`` plus ``negative_0`` ...
    ``negative_{k-1}``, where ``k`` is the smallest number of negatives found
    in any raw example, capped at 50. The ``query_id``, ``positive_passages``,
    ``negative_passages`` and ``subset`` columns are dropped.

    Args:
        single_positive: When True, keep one row per raw example and use only
            its first positive passage. When False, emit one row per
            (query, positive passage) pair.
        cache_dir: Directory in which the processed dataset is cached. When
            None, the original hard-coded location is used (with a
            ``_single`` suffix if ``single_positive`` is True).

    Returns:
        The processed dataset: loaded from ``cache_dir`` when a saved copy is
        found there, otherwise freshly built and saved to ``cache_dir``.
    """
    if cache_dir is None:
        suffix = "_single" if single_positive else ""
        cache_dir = f"/home/antoine_chaffin/rlhn_680k_data{suffix}"
    os.makedirs(cache_dir, exist_ok=True)

    try:
        dataset = Dataset.load_from_disk(cache_dir)
    except FileNotFoundError:
        # Nothing cached yet (the directory may exist but be empty): build it below.
        pass
    else:
        print("Loaded cached rlhn-680K dataset from disk.")
        return dataset

    print("Loading rlhn/rlhn-680K dataset...")
    raw = load_dataset("rlhn/rlhn-680K", split="train", num_proc=45)

    # Find max negatives we can use (min across all examples, capped at 50).
    # NOTE(review): depending on the installed `datasets` version, reading the
    # column like this may materialize it in memory -- confirm if that matters.
    min_negatives = min(len(neg) for neg in raw["negative_passages"])
    max_negatives = min(min_negatives, 50)
    print(f"Min negatives: {min_negatives}, using {max_negatives}")

    # Raw columns that are replaced by the text columns produced below.
    raw_columns = ["query_id", "positive_passages", "negative_passages", "subset"]

    if single_positive:
        print("Converting negatives to text and keeping first positive...")
        raw = raw.map(
            lambda x: _convert_negatives_to_text(x, max_negatives),
            remove_columns=raw_columns,
            num_proc=11,
            desc="Converting to text",
        )
        raw = raw.rename_column("positive_text", "positive")
    else:
        print("Converting all passages to text (one row per positive)...")
        # batched=True lets the mapped function return a different number of
        # rows than it received, which is needed to fan out one row per positive.
        raw = raw.map(
            lambda x: _convert_all_to_text(x, max_negatives),
            remove_columns=raw_columns,
            batched=True,
            batch_size=1,
            num_proc=11,
            desc="Converting to text",
        )

    raw.save_to_disk(cache_dir)
    print(f"Saved processed dataset to {cache_dir}")
    return raw
if __name__ == "__main__":
    # Build (or load from cache) the one-row-per-positive variant, then show
    # the dataset summary and its first row as a quick sanity check.
    train_dataset = load_train_dataset(single_positive=False)
    print(train_dataset)
    print(train_dataset[0])
    print("Done!")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment