Skip to content

Instantly share code, notes, and snippets.

@dhbrojas
Last active June 30, 2025 12:10
Show Gist options
  • Select an option

  • Save dhbrojas/30ce33fcbdb55f973a4e64fde5e52fd7 to your computer and use it in GitHub Desktop.

Select an option

Save dhbrojas/30ce33fcbdb55f973a4e64fde5e52fd7 to your computer and use it in GitHub Desktop.
Data Processing for LLM Training
from abc import ABC, abstractmethod
from dataclasses import dataclass
from random import choices, randint
from typing import Any, Callable, Dict, Generic, List, TypeVar
import torch
from torch import Tensor
T = TypeVar("T")
@dataclass
class TextDocument:
    """A single text document"""

    text: str
    """The text content of the document"""
    document: int
    """The unique ID of this document"""
    source: int
    """The unique ID of the source this document belongs to"""
@dataclass
class Sequence:
    """A sequence of tokens corresponding to a document"""

    ids: List[int]
    """The token IDs of the document, in order"""
    document: int
    """The unique ID of the document this sequence belongs to"""
    source: int
    """The unique ID of the source this sequence belongs to"""
@dataclass
class TrainSequence:
    """A sequence of input and target tokens corresponding to a document for training"""

    x: List[int]
    """The input token IDs"""
    y: List[int]
    """The target token IDs; must have the same length as `x`"""
    document: int
    """The unique ID of the document this sequence belongs to"""
    source: int
    """The unique ID of the source this sequence belongs to"""

    def __post_init__(self):
        # `x` and `y` are parallel lists: every input position needs exactly one target.
        assert len(self.x) == len(self.y), (
            f"x and y must have the same length, got {len(self.x)} and {len(self.y)}"
        )
@dataclass
class PackedTrainSequence:
    """A sequence of input and target tokens corresponding to one or more documents for training"""

    x: List[int]
    """The input token IDs"""
    y: List[int]
    """The target token IDs, one per token in `x`"""
    document: List[int]
    """For each token in `x`, the unique ID of the document it belongs to"""
    source: List[int]
    """For each token in `x`, the unique ID of the source it belongs to"""

    def __post_init__(self):
        # All four lists are parallel, per-token lists and must stay aligned.
        assert len(self.x) == len(self.y) == len(self.document) == len(self.source)

    def __len__(self) -> int:
        """Return the number of tokens in the sequence."""
        return len(self.x)

    def __add__(self, other: "PackedTrainSequence") -> "PackedTrainSequence":
        """Return a new sequence made of `self` followed by `other`."""
        return PackedTrainSequence(
            x=self.x + other.x,
            y=self.y + other.y,
            document=self.document + other.document,
            source=self.source + other.source,
        )

    def __getitem__(self, key: slice) -> "PackedTrainSequence":
        """Return a new sequence with the same slice applied to every per-token list."""
        return PackedTrainSequence(
            x=self.x[key],
            y=self.y[key],
            document=self.document[key],
            source=self.source[key],
        )

    @staticmethod
    def from_train_sequence(sequence: "TrainSequence") -> "PackedTrainSequence":
        """Expand a single-document sequence into per-token document/source IDs.

        The token lists are copied so that mutating the result never changes
        the sequence it was built from (and vice versa).
        """
        length = len(sequence.x)
        return PackedTrainSequence(
            x=list(sequence.x),
            y=list(sequence.y),
            document=[sequence.document] * length,
            source=[sequence.source] * length,
        )

    @staticmethod
    def padding(length: int, pad: int) -> "PackedTrainSequence":
        """Return `length` padding tokens.

        Inputs are filled with `pad`, targets with -100, and the document and
        source IDs with -1.
        """
        return PackedTrainSequence(
            x=[pad] * length,
            y=[-100] * length,
            document=[-1] * length,
            source=[-1] * length,
        )
@dataclass
class Batch:
    """A batch of training data"""

    x: Tensor
    """The input token IDs"""
    y: Tensor
    """The target token IDs"""
    document: Tensor
    """For each token in `x`, the unique ID of the document it belongs to"""
    source: Tensor
    """For each token in `x`, the unique ID of the source it belongs to"""
class Tokenizer:
    """Tokenizer interface used by the pipeline.

    NOTE: both members are unimplemented stubs in this file; supply a real
    implementation before running the pipeline.
    """

    def encode(self, text: str) -> List[int]:
        """Encode a string into a list of tokens."""
        ...

    @property
    def special_tokens(self) -> Dict[str, int]:
        """Return the mapping from special token string to its ID."""
        ...
class Step(ABC, Generic[T]):
    """Base class for one stage of the data pipeline, producing items of type `T`."""

    @abstractmethod
    def next(self) -> T | None:
        """Return the next item, or `None` when there is nothing more to produce."""
        ...
class ParquetStep(Step[Dict[str, Any]]):
    """Source step that yields samples from a Parquet file.

    NOTE: `next` is an unimplemented stub in this file.
    """

    def __init__(self, path: str):
        """Remember the path of the Parquet file to read from."""
        self.path = path

    def next(self) -> Dict[str, Any] | None:
        """Load the next sample from the Parquet file."""
        ...
class FormatStep(Step[TextDocument]):
    """Turn raw samples from `parent` into `TextDocument`s using `formatter`."""

    # Next document ID to hand out. Shared by every FormatStep instance so that
    # documents coming from different sources never receive the same ID;
    # drawing IDs at random from a small range would produce collisions.
    _next_document_id: int = 0

    def __init__(
        self,
        parent: Step[Dict[str, Any]],
        *,
        formatter: Callable[[Dict[str, Any]], str],
        source: int,
    ):
        """
        Args:
            parent: Step producing the raw samples.
            formatter: Converts one raw sample into the document text.
            source: ID of the source, attached to every produced document.
        """
        self.parent = parent
        self.source = source
        self.formatter = formatter

    def next(self) -> TextDocument | None:
        """Return the next formatted document, or `None` when `parent` is exhausted."""
        item = self.parent.next()
        if item is None:
            return None
        # Format first so that a failing formatter does not consume an ID.
        text = self.formatter(item)
        document = FormatStep._next_document_id
        FormatStep._next_document_id += 1
        return TextDocument(
            text=text,
            source=self.source,
            document=document,
        )
class MaskStep(Step[TrainSequence]):
    """Step intended to mask the loss on "system" and "user" messages.

    NOTE: `next` is an unimplemented stub in this file.
    """

    def __init__(self, parent: Step[TrainSequence]):
        """Remember the step producing the sequences to mask."""
        self.parent = parent

    def next(self) -> TrainSequence | None:
        """Mask the loss on "system" and "user" messages."""
        ...
class TokenizeStep(Step[Sequence]):
    """Tokenize each document from `parent`, prepending the `bos` token."""

    def __init__(self, parent: Step[TextDocument], tokenizer: Tokenizer, bos: int):
        """
        Args:
            parent: Step producing the documents to tokenize.
            tokenizer: Used to encode the document text.
            bos: Token ID placed at the start of every sequence.
        """
        self.parent = parent
        self.tokenizer = tokenizer
        self.bos = bos

    def next(self) -> Sequence | None:
        """Return the next tokenized document, or `None` when `parent` is exhausted."""
        item = self.parent.next()
        if item is None:
            return None
        ids = [self.bos] + self.tokenizer.encode(item.text)
        return Sequence(
            ids=ids,
            document=item.document,
            source=item.source,
        )
class IntoTrainSequenceStep(Step[TrainSequence]):
    """Turn each token sequence into shifted input/target lists for next-token prediction."""

    def __init__(self, parent: Step[Sequence]):
        """Remember the step producing the token sequences."""
        self.parent = parent

    def next(self) -> TrainSequence | None:
        """Return the next training sequence, or `None` when `parent` is exhausted."""
        item = self.parent.next()
        if item is None:
            return None
        # The target at position i is the token that follows input position i.
        return TrainSequence(
            x=item.ids[:-1],
            y=item.ids[1:],
            document=item.document,
            source=item.source,
        )
class WeightedSampleStep(Step[T]):
    """Draw each item from one of several parent steps, chosen at random by weight."""

    def __init__(self, parents: List[Step[T]], weights: List[float]):
        """
        Args:
            parents: Steps to sample from.
            weights: Relative weight of each parent, in the same order.

        Raises:
            ValueError: If `parents` and `weights` differ in length.
        """
        # Fail at construction time instead of on the first call to `next`.
        if len(parents) != len(weights):
            raise ValueError(
                f"expected one weight per parent, got {len(parents)} parents "
                f"and {len(weights)} weights"
            )
        self.parents = parents
        self.weights = weights

    def next(self) -> T | None:
        """Return the next item of a randomly chosen parent.

        NOTE: returns `None` as soon as the chosen parent is exhausted, even
        if the other parents still have items.
        """
        (parent,) = choices(self.parents, self.weights)
        return parent.next()
class PackAndPadStep(Step[PackedTrainSequence]):
    """Pack training sequences into fixed-length sequences, padding the remainder.

    Incoming sequences are cut into chunks of at most `sequence_length` tokens.
    Each chunk goes into the first bucket that still has room for it, or into
    a new bucket when none does. Once more than `max_num_buckets` buckets are
    open (or the parent is exhausted), the fullest bucket is padded up to
    `sequence_length` and returned.
    """

    def __init__(
        self,
        parent: Step[TrainSequence],
        *,
        sequence_length: int,
        pad: int,
        max_num_buckets: int = 32,
    ):
        """
        Args:
            parent: Step producing the sequences to pack.
            sequence_length: Length of every returned sequence.
            pad: Token ID used to fill the unused tail of a bucket.
            max_num_buckets: Number of open buckets above which a bucket is
                emitted before more input is read.

        Raises:
            ValueError: If `sequence_length` is not positive.
        """
        # A non-positive length would make the chunking loop never terminate.
        if sequence_length <= 0:
            raise ValueError(
                f"sequence_length must be positive, got {sequence_length}"
            )
        self.parent = parent
        self.sequence_length = sequence_length
        self.pad = pad
        self.max_num_buckets = max_num_buckets
        self.buckets: List[PackedTrainSequence] = []

    def next(self) -> PackedTrainSequence | None:
        """Return the next packed and padded sequence, or `None` when all data is consumed."""
        while True:
            # `self.parent` is set to None below once the parent is exhausted.
            if len(self.buckets) > self.max_num_buckets or self.parent is None:
                if not self.buckets:
                    # We've exhausted data source and have no more buckets.
                    return None
                return self._pop_largest_bucket()

            item = self.parent.next()
            if item is None:
                self.parent = None
                continue
            self._pack(PackedTrainSequence.from_train_sequence(item))

    def _pop_largest_bucket(self) -> PackedTrainSequence:
        """Remove the fullest bucket and return it padded to `sequence_length`."""
        # `max` returns the first maximal index, so ties go to the oldest bucket.
        largest = max(range(len(self.buckets)), key=lambda i: len(self.buckets[i]))
        bucket = self.buckets.pop(largest)
        missing = self.sequence_length - len(bucket)
        return bucket + PackedTrainSequence.padding(missing, self.pad)

    def _pack(self, sequence: PackedTrainSequence) -> None:
        """Cut `sequence` into chunks and place each one into a bucket (first fit)."""
        for start in range(0, len(sequence), self.sequence_length):
            chunk = sequence[start : start + self.sequence_length]
            for i, bucket in enumerate(self.buckets):
                if len(bucket) + len(chunk) <= self.sequence_length:
                    self.buckets[i] = bucket + chunk
                    break
            else:
                # Could not find a fit, create a new bucket.
                self.buckets.append(chunk)
class CollateStep(Step[Batch]):
    """Group `batch_size` packed sequences into one `Batch` of tensors on `device`."""

    def __init__(
        self,
        parent: Step[PackedTrainSequence],
        batch_size: int,
        device: torch.device | str,
    ):
        """
        Args:
            parent: Step producing the packed sequences.
            batch_size: Number of sequences per batch.
            device: Device on which the batch tensors are created.

        Raises:
            ValueError: If `batch_size` is not positive.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.parent = parent
        self.batch_size = batch_size
        self.device = device

    def next(self) -> Batch | None:
        """Return the next batch.

        Returns `None` as soon as `parent` is exhausted; sequences already
        read for an incomplete batch are dropped.
        """
        batch = []
        for _ in range(self.batch_size):
            sequence = self.parent.next()
            if sequence is None:
                return None
            batch.append(sequence)
        return Batch(
            x=self._to_tensor([sequence.x for sequence in batch]),
            y=self._to_tensor([sequence.y for sequence in batch]),
            source=self._to_tensor([sequence.source for sequence in batch]),
            document=self._to_tensor([sequence.document for sequence in batch]),
        )

    def _to_tensor(self, rows: List[List[int]]) -> Tensor:
        """Build one tensor from equally long rows, with one row per sequence."""
        # A single tensor construction replaces one tensor per row plus a stack.
        return torch.tensor(rows, device=self.device)
class BackgroundPrefetchStep(Step[T]):
    """Step intended to move all the computation of `parent` into a worker process.

    NOTE: the worker process, the queue and `next` are unimplemented stubs in
    this file.
    """

    def __init__(self, parent: Step[T], prefetch: int = 1):
        """
        Args:
            parent: Step whose items should be produced in the background.
            prefetch: Prefetch setting, stored for the implementation to use.
        """
        # Moves all the computation in a worker process.
        self.parent = parent
        self.prefetch = prefetch
        self.process = ...
        self.queue = ...

    def next(self) -> T | None:
        """Return the next item."""
        ...
# Example usage
# NOTE: `Tokenizer` is defined in this file with `...` method bodies; substitute
# a real tokenizer before running this example.
tokenizer = Tokenizer()
def format_python_sample(sample: Dict[str, Any]) -> str:
    """Format a code sample as a file-name header line followed by the code.

    Args:
        sample: Mapping with a "file" entry and a "code" entry.

    Returns:
        The header line and the code, concatenated.
    """
    header = f"// File: {sample['file']}\n"
    body = sample["code"]
    return header + body
def format_chat_sample(sample: Dict[str, Any]) -> str:
    """Format a chat sample as one role line plus one content line per message.

    Args:
        sample: Mapping with a "messages" entry; each message has a "role"
            and a "content" entry.

    Returns:
        The formatted messages concatenated in order (empty when there are none).
    """
    # NOTE(review): the role tag is emitted as "<|role" with no closing "|>",
    # unlike special tokens such as "<|bos|>" — confirm this is intended.
    return "".join(
        f"<|{message['role']}\n" + message["content"] + "\n"
        for message in sample["messages"]
    )
# Numeric ID attached to every document of a given source.
sources = {"python": 0, "chat": 1}

# One formatted document stream per source.
python_samples = ParquetStep("python.parquet")
python_samples = FormatStep(
    python_samples, formatter=format_python_sample, source=sources["python"]
)
chat_samples = ParquetStep("chat.parquet")
chat_samples = FormatStep(
    chat_samples, formatter=format_chat_sample, source=sources["chat"]
)

# Mix the sources, then tokenize, shift, mask, pack and batch.
samples = WeightedSampleStep(
    [python_samples, chat_samples], [0.9, 0.1]
)  # 90% python, 10% chat samples
samples = TokenizeStep(samples, tokenizer, bos=tokenizer.special_tokens["<|bos|>"])
samples = IntoTrainSequenceStep(samples)
samples = MaskStep(samples)
samples = PackAndPadStep(
    samples, sequence_length=2048, pad=tokenizer.special_tokens["<|pad|>"]
)
# CollateStep annotates `device` as `torch.device`, so pass one instead of a string.
samples = CollateStep(samples, batch_size=32, device=torch.device("cuda"))
samples = BackgroundPrefetchStep(
    samples, prefetch=64
)  # Samples automatically processed in the background

for step in range(10000):
    batch = samples.next()
    if batch is None:
        # The pipeline ran out of data before reaching the step limit.
        break
    # Feed batch.x to model, use batch.y to compute cross-entropy loss
    # ...
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment