# testing PT Lightning + FSDP2 + CPU offloading + compile + validation
# toy training loop
# 1. copied from https://github.com/Lightning-AI/pytorch-lightning?tab=readme-ov-file#pytorch-lightning-example
# 2. modified to compile the encoder/decoder and shard them with FSDP2 (+ CPU offload)
# main.py
# ! pip install torchvision
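# NOTE (assumption): this repro needs a host with at least 2 GPUs, since
# data_parallel_size=2 below; the Lightning Trainer launches one process
# per visible device automatically.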
import torch, torch.nn as nn, torch.utils.data as data, torchvision as tv, torch.nn.functional as F
from torch.utils.data import Subset
import lightning as L
import numpy as np
from lightning.pytorch.strategies import ModelParallelStrategy
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy, CPUOffloadPolicy
dp_size = 2

# --------------------------------
# Step 1: Define a LightningModule
# --------------------------------
# A LightningModule (nn.Module subclass) defines a full *system*
# (i.e. an LLM, diffusion model, autoencoder, or simple image classifier).
class LitAutoEncoder(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Sigmoid(),
            nn.Linear(28 * 28, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.Sigmoid(),
            nn.Linear(1024, 4096),
        )
        self.decoder = nn.Sequential(
            nn.Linear(4096, 4096),
            nn.Linear(4096, 1024),
            nn.Linear(1024, 28 * 28),
        )
        # note: the error disappears if you skip torch.compile here.
        # compiling happens at construction time, i.e. before fully_shard()
        # is applied in configure_model() below.
        self.encoder = torch.compile(self.encoder)
        self.decoder = torch.compile(self.decoder)

    def configure_model(self):
        dp_mesh = self.device_mesh["data_parallel"]
        assert dp_mesh.size() > 1
        # for reference: the equivalent FSDP1 configuration this replaces
        """
        FSDPStrategy(
            sharding_strategy="FULL_SHARD",
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            sync_module_states=True,
            limit_all_gathers=True,
            mixed_precision=MixedPrecision(
                param_dtype=torch.bfloat16,
                reduce_dtype=torch.float32,
            ),
            cpu_offload=CPUOffload(offload_params=True),
        )
        """
        fsdp_policy = dict(
            mesh=dp_mesh,
            reshard_after_forward=True,
            mp_policy=MixedPrecisionPolicy(
                param_dtype=torch.bfloat16,
                reduce_dtype=torch.float32,
                output_dtype=torch.float32,
            ),
            offload_policy=CPUOffloadPolicy(),
        )
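        # fully_shard() is the FSDP2 entry point: it shards parameters in place
        # (as DTensors over dp_mesh) rather than wrapping the module the way
        # FSDP1's FullyShardedDataParallel does.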
        fully_shard(
            self.encoder,
            **fsdp_policy,
        )
        fully_shard(
            self.decoder,
            **fsdp_policy,
        )
        print(self)

    def forward(self, x):
        # in Lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop; it is independent of forward
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("val_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

# -------------------
# Step 2: Define data
# -------------------
dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
# MNIST's train split has 60,000 samples; keep a tiny 32-sample slice for each
# of train and val so the repro runs fast
train, val = data.random_split(dataset, [32, 55000 + 5000 - 32])
train = Subset(train, np.arange(32))
val = Subset(val, np.arange(32))

# -------------------
# Step 3: Train
# -------------------
class ModelParallelStrategyWithCPUOffload(ModelParallelStrategy):
    def setup(self, trainer: "L.Trainer") -> None:
        super().setup(trainer)
        # with CPUOffloadPolicy, sharded parameters are expected to live on
        # CPU, so move the modules there after the standard setup
        self.lightning_module.encoder.to("cpu")
        self.lightning_module.decoder.to("cpu")

strategy = ModelParallelStrategyWithCPUOffload(
    data_parallel_size=dp_size,
    tensor_parallel_size=1,
)
autoencoder = LitAutoEncoder().to(torch.bfloat16)
trainer = L.Trainer(
    strategy=strategy,
    # note: the error disappears if validation is disabled
    val_check_interval=1,  # i.e. validate after every train batch; args.val_log_step
)
trainer.fit(
    autoencoder,
    data.DataLoader(train, batch_size=32),
    data.DataLoader(val, batch_size=32),
)
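# To reproduce (assumption: 2-GPU host): save as main.py and run `python main.py`.
# Per the notes above, the error only shows up when both torch.compile (in
# __init__) and validation (val_check_interval) are enabled.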