Last active
June 30, 2025 12:10
-
-
Save dhbrojas/30ce33fcbdb55f973a4e64fde5e52fd7 to your computer and use it in GitHub Desktop.
Data Processing for LLM Training
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from abc import ABC, abstractmethod | |
| from dataclasses import dataclass | |
| from random import choices, randint | |
| from typing import Any, Callable, Dict, Generic, List, TypeVar | |
| import torch | |
| from torch import Tensor | |
# Type variable for the kind of item a pipeline ``Step`` produces.
T = TypeVar("T")
@dataclass
class TextDocument:
    """One raw (untokenized) text document plus the IDs describing where it came from."""

    # The document's text content.
    text: str
    # Unique ID of this document.
    document: int
    # Unique ID of the data source this document was drawn from.
    source: int
@dataclass
class Sequence:
    """Token IDs for a single document, tagged with its document and source IDs."""

    # Token IDs of the document.
    ids: List[int]
    # Unique ID of the document this sequence belongs to.
    document: int
    # Unique ID of the source this sequence belongs to.
    source: int
@dataclass
class TrainSequence:
    """Next-token-prediction inputs and targets for a single document.

    ``x`` holds the input token IDs and ``y`` the target token IDs, aligned
    position by position (``IntoTrainSequenceStep`` builds them as
    ``ids[:-1]`` and ``ids[1:]``).

    Raises:
        ValueError: If ``x`` and ``y`` differ in length.
    """

    # Input token IDs.
    x: List[int]
    # Target token IDs, aligned with ``x``.
    y: List[int]
    # Unique ID of the document this sequence belongs to.
    document: int
    # Unique ID of the source this sequence belongs to.
    source: int

    def __post_init__(self) -> None:
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        if len(self.x) != len(self.y):
            raise ValueError(
                f"x and y must have the same length, got {len(self.x)} and {len(self.y)}"
            )
@dataclass
class PackedTrainSequence:
    """Per-token inputs, targets and IDs for one packed row spanning one or more documents.

    All four lists are per-token and must have equal length. Padding tokens
    (see ``padding``) carry target ``-100``, which is the default
    ``ignore_index`` of PyTorch's cross-entropy loss, and document/source ID ``-1``.

    Raises:
        ValueError: If the four per-token lists differ in length.
    """

    # Input token IDs.
    x: List[int]
    # Target token IDs, aligned with ``x``.
    y: List[int]
    # For each token in ``x``, the unique ID of the document it belongs to.
    document: List[int]
    # For each token in ``x``, the unique ID of the source it belongs to.
    source: List[int]

    def __post_init__(self) -> None:
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        lengths = {len(self.x), len(self.y), len(self.document), len(self.source)}
        if len(lengths) != 1:
            raise ValueError(
                "x, y, document and source must have the same length, got "
                f"{len(self.x)}, {len(self.y)}, {len(self.document)}, {len(self.source)}"
            )

    def __len__(self) -> int:
        """Number of tokens in the row."""
        return len(self.x)

    def __add__(self, other: "PackedTrainSequence") -> "PackedTrainSequence":
        """Concatenate two packed rows token-wise."""
        if not isinstance(other, PackedTrainSequence):
            return NotImplemented
        return PackedTrainSequence(
            x=self.x + other.x,
            y=self.y + other.y,
            document=self.document + other.document,
            source=self.source + other.source,
        )

    def __getitem__(self, key: slice) -> "PackedTrainSequence":
        """Slice every per-token list with the same ``key``."""
        return PackedTrainSequence(
            x=self.x[key],
            y=self.y[key],
            document=self.document[key],
            source=self.source[key],
        )

    @staticmethod
    def from_train_sequence(sequence: "TrainSequence") -> "PackedTrainSequence":
        """Wrap a single-document ``TrainSequence``, repeating its IDs for every token."""
        return PackedTrainSequence(
            x=sequence.x,
            y=sequence.y,
            document=[sequence.document] * len(sequence.x),
            source=[sequence.source] * len(sequence.x),
        )

    @staticmethod
    def padding(length: int, pad: int) -> "PackedTrainSequence":
        """Return ``length`` padding tokens.

        Inputs are ``pad``, targets are ``-100`` (the default ``ignore_index`` of
        PyTorch's cross-entropy), and document/source IDs are ``-1``.

        Raises:
            ValueError: If ``length`` is negative. This means the row being padded
                is already longer than its target length, and it would otherwise go
                unnoticed because ``[pad] * -n`` is simply empty.
        """
        if length < 0:
            raise ValueError(f"padding length must be >= 0, got {length}")
        return PackedTrainSequence(
            x=[pad] * length,
            y=[-100] * length,
            document=[-1] * length,
            source=[-1] * length,
        )
@dataclass
class Batch:
    """A collated batch of packed training rows.

    As built by ``CollateStep``, each tensor stacks one row per
    ``PackedTrainSequence``, so all four tensors share the same shape.
    """

    # Input token IDs.
    x: Tensor
    # Target token IDs, aligned with ``x``.
    y: Tensor
    # For each token in ``x``, the unique ID of the document it belongs to.
    document: Tensor
    # For each token in ``x``, the unique ID of the source it belongs to.
    source: Tensor
class Tokenizer:
    """Placeholder tokenizer interface; the concrete implementation is omitted here."""

    def encode(self, text: str) -> List[int]:
        """Encode ``text`` into a list of token IDs (not implemented in this sketch)."""

    @property
    def special_tokens(self) -> Dict[str, int]:
        """Mapping from special-token string to its token ID (not implemented in this sketch)."""
class Step(ABC, Generic[T]):
    """One stage of a pull-based data pipeline.

    Subclasses implement ``next``. Returning ``None`` signals to the consumer
    that this stage has nothing more to produce.
    """

    @abstractmethod
    def next(self) -> T | None:
        """Produce the next item, or ``None`` once the stage is exhausted."""
class ParquetStep(Step[Dict[str, Any]]):
    """Source step that reads samples (as dicts) from the Parquet file at ``path``.

    The reader itself is omitted in this sketch.
    """

    def __init__(self, path: str):
        self.path = path

    def next(self) -> Dict[str, Any] | None:
        """Return the next sample loaded from ``self.path`` (not implemented in this sketch)."""
class FormatStep(Step[TextDocument]):
    """Turn raw samples into ``TextDocument`` objects using a caller-supplied formatter.

    Args:
        parent: Step producing raw samples.
        formatter: Renders one sample into the document text.
        source: Source ID stamped on every document this step produces.

    Every document gets its ID from a counter shared by all ``FormatStep``
    instances in the process. IDs are therefore distinct within the process,
    even across sources. A random ID drawn from a small range could give two
    documents the same ID, which makes the per-token document IDs of a packed
    row ambiguous.
    """

    # Class-level (not per-instance) so that several FormatSteps, e.g. one per
    # source feeding a WeightedSampleStep, never hand out the same ID.
    _next_document_id: int = 0

    def __init__(
        self,
        parent: Step[Dict[str, Any]],
        *,
        formatter: Callable[[Dict[str, Any]], str],
        source: int,
    ):
        self.parent = parent
        self.source = source
        self.formatter = formatter

    def next(self) -> TextDocument | None:
        item = self.parent.next()
        if item is None:
            return None
        text = self.formatter(item)
        # Refer to FormatStep explicitly: ``type(self)`` would give subclasses
        # their own shadowing counter and reintroduce collisions.
        document = FormatStep._next_document_id
        FormatStep._next_document_id += 1
        return TextDocument(text=text, source=self.source, document=document)
class MaskStep(Step[TrainSequence]):
    """Mask the loss on "system" and "user" messages of train sequences.

    Placeholder: the masking itself is not implemented yet. The step
    propagates end-of-stream from ``parent`` but raises for any real item.
    Downstream steps (e.g. ``PackAndPadStep``) treat ``None`` as "no more
    data", so a stub that quietly returned ``None`` would make the whole
    pipeline yield nothing without any error.

    Raises:
        NotImplementedError: When asked to mask an actual item.
    """

    def __init__(self, parent: Step[TrainSequence]):
        self.parent = parent

    def next(self) -> TrainSequence | None:
        item = self.parent.next()
        if item is None:
            return None
        # TODO: mask system/user message targets in ``item.y`` (presumably with
        # -100, the value ``PackedTrainSequence.padding`` uses for ignored targets).
        raise NotImplementedError("MaskStep: loss masking is not implemented")
class TokenizeStep(Step[Sequence]):
    """Tokenize each ``TextDocument`` and prepend a BOS token.

    Args:
        parent: Step producing text documents.
        tokenizer: Tokenizer whose ``encode`` converts text into token IDs.
        bos: Token ID placed at the start of every sequence.
    """

    def __init__(self, parent: Step[TextDocument], tokenizer: Tokenizer, bos: int):
        self.parent = parent
        self.tokenizer = tokenizer
        self.bos = bos

    def next(self) -> Sequence | None:
        document = self.parent.next()
        if document is None:
            return None
        encoded = self.tokenizer.encode(document.text)
        return Sequence(
            ids=[self.bos] + encoded,
            document=document.document,
            source=document.source,
        )
class IntoTrainSequenceStep(Step[TrainSequence]):
    """Shift token sequences into next-token-prediction input/target pairs.

    For ids ``[t0, t1, ..., tn]`` the result has ``x = [t0, ..., t(n-1)]``
    and ``y = [t1, ..., tn]``.
    """

    def __init__(self, parent: Step[Sequence]):
        self.parent = parent

    def next(self) -> TrainSequence | None:
        sequence = self.parent.next()
        if sequence is None:
            return None
        ids = sequence.ids
        inputs, targets = ids[:-1], ids[1:]
        return TrainSequence(
            x=inputs,
            y=targets,
            document=sequence.document,
            source=sequence.source,
        )
class WeightedSampleStep(Step[T]):
    """Mix several parent steps by drawing each item from a randomly chosen parent.

    Each call picks one parent with probability proportional to its weight
    (via ``random.choices``) and returns that parent's next item.

    Args:
        parents: Upstream steps to mix.
        weights: Relative weight of each parent; must match ``parents`` in length.
        drop_exhausted: Controls what happens when the chosen parent returns ``None``.
            With ``False`` (the default), ``None`` is returned, so the mixed stream
            ends as soon as an exhausted source is picked and never continues with
            only a subset of sources. With ``True``, that parent is removed from the
            rotation and a new draw is made among the remaining parents, so the
            stream ends only once every parent is exhausted.

    Raises:
        ValueError: If ``parents`` and ``weights`` differ in length.
    """

    def __init__(
        self,
        parents: List[Step[T]],
        weights: List[float],
        *,
        drop_exhausted: bool = False,
    ):
        if len(parents) != len(weights):
            raise ValueError(
                f"got {len(parents)} parents but {len(weights)} weights"
            )
        self.parents = parents
        self.weights = weights
        self.drop_exhausted = drop_exhausted
        # Indices of parents that may still produce items.
        self._active = list(range(len(parents)))

    def next(self) -> T | None:
        while self._active:
            active_weights = [self.weights[i] for i in self._active]
            index = choices(self._active, active_weights)[0]
            item = self.parents[index].next()
            if item is not None or not self.drop_exhausted:
                return item
            # The chosen parent is exhausted: stop sampling from it and redraw.
            self._active.remove(index)
        return None
class PackAndPadStep(Step[PackedTrainSequence]):
    """Pack variable-length train sequences into fixed-length rows by first-fit bucketing.

    Each incoming sequence is wrapped as a ``PackedTrainSequence`` and split
    into chunks of at most ``sequence_length`` tokens. Each chunk is appended
    to the first open bucket with enough room, or opens a new bucket. A row is
    emitted when more than ``max_num_buckets`` buckets are open, or when the
    parent is exhausted. The emitted row is the fullest bucket, padded to
    ``sequence_length``.

    Args:
        parent: Step producing train sequences.
        sequence_length: Length of every emitted row; must be at least 1.
        pad: Token ID used for padded inputs.
        max_num_buckets: The fullest bucket is emitted once more than this many
            buckets are open.

    Raises:
        ValueError: If ``sequence_length`` is smaller than 1. Chunking could then
            never make progress, and ``next`` would loop forever.
    """

    def __init__(
        self,
        parent: Step[TrainSequence],
        *,
        sequence_length: int,
        pad: int,
        max_num_buckets: int = 32,
    ):
        if sequence_length < 1:
            raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")
        self.parent = parent
        self.sequence_length = sequence_length
        self.pad = pad
        self.max_num_buckets = max_num_buckets
        self.buckets: List[PackedTrainSequence] = []

    def next(self) -> PackedTrainSequence | None:
        while True:
            # Emit when too many buckets are open, or drain once the parent is gone.
            if len(self.buckets) > self.max_num_buckets or self.parent is None:
                if not self.buckets:
                    # Parent exhausted and nothing left to drain.
                    return None
                return self._pop_fullest_padded()
            sequence = self.parent.next()
            if sequence is None:
                # Forget the exhausted parent; later calls only drain buckets.
                self.parent = None
                continue
            self._add_to_buckets(PackedTrainSequence.from_train_sequence(sequence))

    def _pop_fullest_padded(self) -> PackedTrainSequence:
        """Remove the longest open bucket and pad it to ``sequence_length``."""
        # max() returns the first index of maximal length, so ties go to the oldest bucket.
        fullest = max(range(len(self.buckets)), key=lambda i: len(self.buckets[i]))
        bucket = self.buckets.pop(fullest)
        missing = self.sequence_length - len(bucket)
        return bucket + PackedTrainSequence.padding(missing, self.pad)

    def _add_to_buckets(self, sequence: PackedTrainSequence) -> None:
        """Split ``sequence`` into chunks and first-fit each chunk into a bucket."""
        limit = self.sequence_length
        # Slice by offset instead of repeatedly re-slicing the remainder, which
        # copied long documents over and over.
        for start in range(0, len(sequence), limit):
            chunk = sequence[start : start + limit]
            for i, bucket in enumerate(self.buckets):
                if len(bucket) + len(chunk) <= limit:
                    self.buckets[i] = bucket + chunk
                    break
            else:
                # No open bucket has room: start a new one.
                self.buckets.append(chunk)
class CollateStep(Step[Batch]):
    """Group ``batch_size`` packed rows into a ``Batch`` of tensors on ``device``.

    A trailing partial batch is dropped: if the parent runs out before
    ``batch_size`` rows are collected, ``next`` returns ``None``.

    Args:
        parent: Step producing packed rows. All rows must have the same length,
            as ``PackAndPadStep`` produces.
        batch_size: Number of rows per batch; must be at least 1.
        device: Device the tensors are created on, either a ``torch.device`` or a
            string such as ``"cuda"``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    def __init__(
        self, parent: Step[PackedTrainSequence], batch_size: int, device: torch.device | str
    ):
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.parent = parent
        self.batch_size = batch_size
        self.device = device

    def next(self) -> Batch | None:
        rows = []
        for _ in range(self.batch_size):
            row = self.parent.next()
            if row is None:
                return None
            rows.append(row)
        return Batch(
            x=self._to_tensor([row.x for row in rows]),
            y=self._to_tensor([row.y for row in rows]),
            source=self._to_tensor([row.source for row in rows]),
            document=self._to_tensor([row.document for row in rows]),
        )

    def _to_tensor(self, values: List[List[int]]) -> Tensor:
        """Build a 2-D int64 tensor on ``self.device`` from equal-length rows.

        Uses a single ``torch.tensor`` call rather than one small device tensor
        per row followed by ``torch.stack``.
        """
        return torch.tensor(values, device=self.device)
class BackgroundPrefetchStep(Step[T]):
    """Run the upstream pipeline in a worker process and prefetch its output.

    Sketch only: the worker process and queue are placeholders, and
    ``prefetch`` is accepted but not yet used.
    """

    def __init__(self, parent: Step[T], prefetch: int = 1):
        self.parent = parent
        # Placeholders for the worker process and the hand-off queue.
        self.process = ...
        self.queue = ...

    def next(self) -> T | None:
        """Return the next prefetched item (not implemented in this sketch)."""
# Example usage: build a two-source pipeline and iterate over training batches.
# ``Tokenizer`` is a placeholder in this file; substitute a real implementation.
tokenizer = Tokenizer()
def format_python_sample(sample: Dict[str, Any]) -> str:
    """Render a Python code sample as a one-line file header followed by the code.

    The header is ``# File: <file>`` plus a newline. It uses Python's ``#``
    comment syntax so the formatted sample stays valid Python; a ``//``
    header is a syntax error in Python source.

    Args:
        sample: Row with a ``"file"`` key (presumably the file path) and a
            ``"code"`` key holding the source text.
    """
    header = f"# File: {sample['file']}\n"
    return header + sample["code"]
def format_chat_sample(sample: Dict[str, Any]) -> str:
    """Render a chat sample as a sequence of role-tagged messages.

    Each message becomes a ``<|role|>`` line followed by the message content
    and a newline. The closing ``|>`` matches the ``<|name|>`` form used for
    the tokenizer's special tokens (``<|bos|>``, ``<|pad|>``).

    Args:
        sample: Row with a ``"messages"`` list whose items have ``"role"`` and
            ``"content"`` keys.
    """
    # str.join avoids the quadratic cost of repeated ``+=`` on a growing string.
    return "".join(
        f"<|{message['role']}|>\n" + message["content"] + "\n"
        for message in sample["messages"]
    )
# Integer ID for each data source; FormatStep stamps it on every document.
sources = {"python": 0, "chat": 1}
# Python code samples: Parquet rows -> formatted TextDocuments.
python_samples = ParquetStep("python.parquet")
python_samples = FormatStep(
    python_samples, formatter=format_python_sample, source=sources["python"]
)
# Chat samples: Parquet rows -> formatted TextDocuments.
chat_samples = ParquetStep("chat.parquet")
chat_samples = FormatStep(
    chat_samples, formatter=format_chat_sample, source=sources["chat"]
)
# Mix both sources. Each call picks a parent at random by weight and returns
# its next item, so a None from an exhausted parent ends the mixed stream.
samples = WeightedSampleStep(
    [python_samples, chat_samples], [0.9, 0.1]
)  # 90% python, 10% chat samples
# Tokenize each document, prepending the BOS token.
samples = TokenizeStep(samples, tokenizer, bos=tokenizer.special_tokens["<|bos|>"])
# Shift ids into (x, y) next-token-prediction pairs.
samples = IntoTrainSequenceStep(samples)
# Mask the loss on system/user chat messages (MaskStep is a placeholder here).
samples = MaskStep(samples)
# Pack documents into 2048-token rows, padding whatever space is left.
samples = PackAndPadStep(
    samples, sequence_length=2048, pad=tokenizer.special_tokens["<|pad|>"]
)
# Stack 32 rows per batch into tensors on the GPU.
# NOTE(review): BackgroundPrefetchStep is described as moving the upstream
# computation into a worker process, and CUDA generally cannot be used from a
# forked subprocess. Consider collating on "cpu" and moving each batch to the
# GPU in the training loop. Confirm once the worker is implemented.
samples = CollateStep(samples, batch_size=32, device="cuda")
samples = BackgroundPrefetchStep(
    samples, prefetch=64
)  # Samples automatically processed in the background
# Run at most 10000 steps, stopping early when the pipeline returns None.
for step in range(10000):
    batch = samples.next()
    if batch is None:
        break
    # Feed batch.x to model, use batch.y to compute cross-entropy loss
    # ...
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment