Training and inference of an image-generation diffusion model (DDPM) for MNIST handwritten digits. Run the training script (second file) first to produce the mnist_diffusion_ckpt/ checkpoint; the inference script (first file) then loads it and samples digits from pure noise.
# -*- coding: utf-8 -*-
# Inference script: loads a trained UNet checkpoint and samples MNIST digits.
import torch
import torchvision
import datasets
import diffusers
import accelerate
from tqdm.auto import tqdm
import os
from dataclasses import dataclass
# ---- Config ----
@dataclass
class TrainingConfig:
    image_size: int = 32  # resize the digits to a power of two
    train_batch_size: int = 32
    eval_batch_size: int = 32
    num_epochs: int = 5
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    mixed_precision: str = 'fp16'
    seed: int = 0

config = TrainingConfig()
# ---- Dataset ----
def transform(examples):
    # set_transform passes a batch dict; "image" is a list of PIL images.
    preprocess = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize((config.image_size, config.image_size)),
            torchvision.transforms.ToTensor(),
            # Map pixel values from [0, 1] to [-1, 1].
            torchvision.transforms.Lambda(lambda x: 2 * (x - 0.5)),
        ]
    )
    images = [preprocess(image) for image in examples["image"]]
    return {"images": images}

def get_dataloader():
    mnist_dataset = datasets.load_dataset('mnist', split='train')
    mnist_dataset.reset_format()
    mnist_dataset.set_transform(transform)
    return torch.utils.data.DataLoader(
        mnist_dataset,
        batch_size=config.train_batch_size,
        shuffle=True,
    ), mnist_dataset
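# Quick sanity check (a sketch, not part of the original pipeline): after the
# Lambda above, batches should be (batch, 1, 32, 32) with values in [-1, 1]:
#   loader, _ = get_dataloader()
#   batch = next(iter(loader))["images"]
#   assert batch.shape[1:] == (1, 32, 32) and batch.min() >= -1 and batch.max() <= 1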
# ---- Model ----
def get_model():
    return diffusers.UNet2DModel(
        sample_size=config.image_size,
        in_channels=1,   # grayscale MNIST
        out_channels=1,
        layers_per_block=2,
        block_out_channels=(128, 128, 256, 512),
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )
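# A reading of the config above (an assumption about UNet2DModel's default
# downsampling, not a documented guarantee): each down block except the last
# halves the spatial size, so the AttnDownBlock2D / AttnUpBlock2D pair applies
# self-attention on reduced-resolution feature maps rather than the full 32x32.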
# ---- Noise Scheduler ----
def get_noise_scheduler():
    # `tensor_format` was removed from diffusers long ago; recent versions
    # take only the schedule arguments.
    return diffusers.DDPMScheduler(num_train_timesteps=200)
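# For intuition, a minimal sketch of the DDPM forward process that
# DDPMScheduler.add_noise implements (alphas_cumprod is the cumulative product
# of the noise schedule's alphas; reading it directly is an assumption about
# the scheduler's internals):
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
def add_noise_sketch(scheduler, x0, eps, t):
    alpha_bar = scheduler.alphas_cumprod.to(x0.device)[t].view(-1, 1, 1, 1)
    return alpha_bar.sqrt() * x0 + (1 - alpha_bar).sqrt() * eps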
# ---- Optimizer ----
def get_optimizer(model):
    return torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

# ---- LR Scheduler ----
def get_lr_scheduler(optimizer, train_dataloader):
    return diffusers.optimization.get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=config.lr_warmup_steps,
        num_training_steps=(len(train_dataloader) * config.num_epochs),
    )
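# With MNIST's 60,000 training images, batch size 32 gives 1875 steps per
# epoch, so 5 epochs are about 9375 optimizer steps and the 500 warmup steps
# cover roughly 5% of training.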
# ---- Training Loop ----
def train_loop(
    config,
    model,
    noise_scheduler,
    optimizer,
    train_dataloader,
    lr_scheduler,
):
    accelerator = accelerate.Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
    )
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )
    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader),
                            disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")
        for step, batch in enumerate(train_dataloader):
            clean_images = batch['images']
            noise = torch.randn(clean_images.shape, device=clean_images.device)
            batch_size = clean_images.shape[0]
            # Sample a random timestep for each image in the mini-batch.
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (batch_size,),
                device=clean_images.device)
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
            with accelerator.accumulate(model):
                # The model is trained to predict the added noise (epsilon).
                noise_pred = model(noisy_images, timesteps)["sample"]
                loss = torch.nn.functional.mse_loss(noise_pred, noise)
                accelerator.backward(loss)
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
            progress_bar.update(1)
            logs = {
                "loss": loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
            }
            progress_bar.set_postfix(**logs)
    return accelerator.unwrap_model(model)
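# Note: train_loop is kept here for completeness, but main() below only runs
# inference. A sketch of training from scratch with these helpers (see the
# training script below for the full version):
#   dataloader, _ = get_dataloader()
#   model = get_model()
#   optimizer = get_optimizer(model)
#   train_loop(config, model, get_noise_scheduler(), optimizer, dataloader,
#              get_lr_scheduler(optimizer, dataloader))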
# ---- Sampling ----
@torch.no_grad()
def sample(unet, scheduler, seed, save_process_dir=None):
    torch.manual_seed(seed)
    if save_process_dir:
        os.makedirs(save_process_dir, exist_ok=True)
    # Use as many inference steps as the scheduler was trained on; DDPMScheduler
    # rejects a step count larger than num_train_timesteps.
    scheduler.set_timesteps(scheduler.config.num_train_timesteps)
    device = next(unet.parameters()).device
    # Start from pure Gaussian noise and iteratively denoise.
    image = torch.randn((1, 1, 32, 32), device=device)
    num_steps = int(scheduler.timesteps.max())
    for t in scheduler.timesteps:
        model_output = unet(image, t)['sample']
        image = scheduler.step(model_output, int(t), image, generator=None)['prev_sample']
        if save_process_dir:
            save_image = torchvision.transforms.ToPILImage()(image.squeeze(0).cpu().clamp(-1, 1) * 0.5 + 0.5)
            save_image.resize((256, 256)).save(
                os.path.join(save_process_dir, f"seed-{seed}_{num_steps - int(t):03d}.png"),
                format="png"
            )
    return torchvision.transforms.ToPILImage()(image.squeeze(0).cpu().clamp(-1, 1) * 0.5 + 0.5)
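# For reference, the reverse update that scheduler.step applies at each
# timestep (standard DDPM ancestral sampling; sigma_t * z is fresh Gaussian
# noise, omitted at t = 0):
#   x_{t-1} = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta(x_t, t)) + sigma_t * z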
# ---- Main ----
def main():
    torch.manual_seed(config.seed)
    # Data
    train_dataloader, mnist_dataset = get_dataloader()
    # Load the trained model for inference (produced by the training script below).
    model = diffusers.UNet2DModel.from_pretrained("mnist_diffusion_ckpt/unet")
    model = model.to("cuda" if torch.cuda.is_available() else "cpu")
    noise_scheduler = get_noise_scheduler()
    # Optionally, test the model's input/output shapes.
    sample_image = mnist_dataset[0]["images"].unsqueeze(0).to(model.device)
    print("Input shape:", sample_image.shape)
    print("Output shape:", model(sample_image, timestep=0)["sample"].shape)
    # Sample and save images
    for seed in [2, 5, 42, 1991, 2022]:
        test_image = sample(model, noise_scheduler, seed)
        test_image.resize((256, 256)).save(f"mnist_diffusion_sample_seed{seed}.png")
        print(f"Saved sample for seed {seed}.")

if __name__ == "__main__":
    main()
Training script (run this first): trains the UNet from scratch and writes the mnist_diffusion_ckpt/ checkpoint that the inference script above loads.
# Training script: trains the UNet and saves it to mnist_diffusion_ckpt/.
import torch
import torchvision
import datasets
import diffusers
import accelerate
from tqdm.auto import tqdm
import os
from dataclasses import dataclass

@dataclass
class TrainingConfig:
    image_size: int = 32
    train_batch_size: int = 64
    eval_batch_size: int = 32
    num_epochs: int = 1
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    mixed_precision: str = 'fp16'
    seed: int = 0

config = TrainingConfig()
def transform(examples):
    # set_transform passes a batch dict; "image" is a list of PIL images.
    preprocess = torchvision.transforms.Compose([
        torchvision.transforms.Resize((config.image_size, config.image_size)),
        torchvision.transforms.ToTensor(),
        # Map pixel values from [0, 1] to [-1, 1].
        torchvision.transforms.Lambda(lambda x: 2 * (x - 0.5)),
    ])
    images = [preprocess(image) for image in examples["image"]]
    return {"images": images}

def get_dataloader():
    mnist_dataset = datasets.load_dataset('mnist', split='train')
    mnist_dataset.reset_format()
    mnist_dataset.set_transform(transform)
    return torch.utils.data.DataLoader(
        mnist_dataset,
        batch_size=config.train_batch_size,
        shuffle=True,
    ), mnist_dataset
def get_model():
    return diffusers.UNet2DModel(
        sample_size=config.image_size,
        in_channels=1,   # grayscale MNIST
        out_channels=1,
        layers_per_block=2,
        block_out_channels=(128, 128, 256, 512),
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )
def get_noise_scheduler():
    # `tensor_format` was removed from diffusers; pass only schedule arguments.
    return diffusers.DDPMScheduler(num_train_timesteps=200)

def get_optimizer(model):
    return torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

def get_lr_scheduler(optimizer, train_dataloader):
    return diffusers.optimization.get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=config.lr_warmup_steps,
        num_training_steps=(len(train_dataloader) * config.num_epochs),
    )
def train_loop(
    config,
    model,
    noise_scheduler,
    optimizer,
    train_dataloader,
    lr_scheduler,
    max_steps=None,
):
    accelerator = accelerate.Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
    )
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )
    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader),
                            disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")
        for step, batch in enumerate(train_dataloader):
            if max_steps is not None and step >= max_steps:
                break
            clean_images = batch['images']
            noise = torch.randn(clean_images.shape, device=clean_images.device)
            batch_size = clean_images.shape[0]
            # Sample a random timestep for each image in the mini-batch.
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (batch_size,),
                device=clean_images.device)
            # Recent diffusers versions move the scheduler's buffers to the
            # sample's device inside add_noise, so no manual .cpu()/.cuda()
            # shuffling is needed here.
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
            with accelerator.accumulate(model):
                noise_pred = model(noisy_images, timesteps)["sample"]
                loss = torch.nn.functional.mse_loss(noise_pred, noise)
                accelerator.backward(loss)
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
            progress_bar.update(1)
            logs = {
                "loss": loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
            }
            progress_bar.set_postfix(**logs)
    return accelerator.unwrap_model(model)
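# Note on accelerator.accumulate() (behavior as of recent accelerate versions):
# with gradient_accumulation_steps > 1, accelerate skips the real optimizer
# step on non-sync iterations and only applies gradients every N batches; at
# the default of 1 used here, every batch performs a full step.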
def main():
    torch.manual_seed(config.seed)
    train_dataloader, mnist_dataset = get_dataloader()
    model = get_model()
    noise_scheduler = get_noise_scheduler()
    optimizer = get_optimizer(model)
    lr_scheduler = get_lr_scheduler(optimizer, train_dataloader)
    # Optionally, test the model's input/output shapes.
    sample_image = mnist_dataset[0]["images"].unsqueeze(0)
    print("Input shape:", sample_image.shape)
    print("Output shape:", model(sample_image, timestep=0)["sample"].shape)
    # Train
    model = train_loop(config, model, noise_scheduler, optimizer,
                       train_dataloader, lr_scheduler, max_steps=None)
    # Save model and scheduler
    os.makedirs("mnist_diffusion_ckpt", exist_ok=True)
    model.save_pretrained("mnist_diffusion_ckpt/unet")
    noise_scheduler.save_pretrained("mnist_diffusion_ckpt/scheduler")
    print("Model and scheduler saved to mnist_diffusion_ckpt/")

if __name__ == "__main__":
    main()