@vukrosic · Last active August 7, 2025
Training and inference of an image-generation diffusion model for MNIST handwritten digits
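The gist contains two scripts. The first loads a trained checkpoint from mnist_diffusion_ckpt/ and samples digits with a DDPM denoising loop; the second trains the UNet from scratch with Hugging Face diffusers and accelerate, then saves that checkpoint. Run the training script first so the checkpoint exists.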
# -*- coding: utf-8 -*-
import torch
import torchvision
import datasets
import diffusers
import accelerate
from tqdm.auto import tqdm
import os
from dataclasses import dataclass
# ---- Config ----
@dataclass
class TrainingConfig:
    image_size = 32  # Resize the digits so the side length is a power of two
    train_batch_size = 32
    eval_batch_size = 32
    num_epochs = 5
    gradient_accumulation_steps = 1
    learning_rate = 1e-4
    lr_warmup_steps = 500
    mixed_precision = 'fp16'
    seed = 0

config = TrainingConfig()
# ---- Dataset ----
def transform(dataset):
    preprocess = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize((config.image_size, config.image_size)),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Lambda(lambda x: 2 * (x - 0.5)),  # [0, 1] -> [-1, 1]
        ]
    )
    images = [preprocess(image) for image in dataset["image"]]
    return {"images": images}
def get_dataloader():
    mnist_dataset = datasets.load_dataset('mnist', split='train')
    mnist_dataset.reset_format()
    mnist_dataset.set_transform(transform)
    return torch.utils.data.DataLoader(
        mnist_dataset,
        batch_size=config.train_batch_size,
        shuffle=True,
    ), mnist_dataset
# ---- Model ----
def get_model():
    return diffusers.UNet2DModel(
        sample_size=config.image_size,
        in_channels=1,
        out_channels=1,
        layers_per_block=2,
        block_out_channels=(128, 128, 256, 512),
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )
# ---- Noise Scheduler ----
def get_noise_scheduler():
    # The original passed tensor_format='pt', an argument removed from recent
    # diffusers releases; the scheduler returns torch tensors by default.
    return diffusers.DDPMScheduler(num_train_timesteps=200)
# ---- Optimizer ----
def get_optimizer(model):
    return torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
# ---- LR Scheduler ----
def get_lr_scheduler(optimizer, train_dataloader):
    return diffusers.optimization.get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=config.lr_warmup_steps,
        num_training_steps=(len(train_dataloader) * config.num_epochs),
    )
# ---- Training Loop ----
def train_loop(
    config,
    model,
    noise_scheduler,
    optimizer,
    train_dataloader,
    lr_scheduler,
):
    accelerator = accelerate.Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
    )
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )
    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader),
                            disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")
        for step, batch in enumerate(train_dataloader):
            clean_images = batch['images']
            noise = torch.randn(clean_images.shape).to(clean_images.device)
            batch_size = clean_images.shape[0]
            # Sample a random timestep for each image in the mini-batch
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (batch_size,),
                device=clean_images.device)
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
            with accelerator.accumulate(model):
                noise_pred = model(noisy_images, timesteps)["sample"]
                loss = torch.nn.functional.mse_loss(noise_pred, noise)
                accelerator.backward(loss)
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
            progress_bar.update(1)
            logs = {
                "loss": loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
            }
            progress_bar.set_postfix(**logs)
    # Return the raw model so the caller can save or sample with it
    return accelerator.unwrap_model(model)
# ---- Sampling ----
@torch.no_grad()
def sample(unet, scheduler, seed, save_process_dir=None):
    torch.manual_seed(seed)
    if save_process_dir:
        os.makedirs(save_process_dir, exist_ok=True)
    # Denoise over the full training schedule. The scheduler was built with
    # 200 train timesteps; requesting more inference steps than that (the
    # original asked for 1000) raises a ValueError.
    scheduler.set_timesteps(scheduler.config.num_train_timesteps)
    device = next(unet.parameters()).device
    image = torch.randn((1, 1, 32, 32)).to(device)
    num_steps = max(scheduler.timesteps).cpu().numpy()
    for t in scheduler.timesteps:
        model_output = unet(image, t)['sample']
        image = scheduler.step(model_output, int(t), image, generator=None)['prev_sample']
        if save_process_dir:
            save_image = torchvision.transforms.ToPILImage()(
                image.squeeze(0).cpu().clamp(-1, 1) * 0.5 + 0.5)
            save_image.resize((256, 256)).save(
                os.path.join(save_process_dir,
                             f"seed-{seed}_{num_steps - t.cpu().numpy():03d}.png"),
                format="png",
            )
    return torchvision.transforms.ToPILImage()(image.squeeze(0).cpu().clamp(-1, 1) * 0.5 + 0.5)
# ---- Main ----
def main():
    torch.manual_seed(config.seed)
    # Data
    train_dataloader, mnist_dataset = get_dataloader()
    # Load a trained model for inference (produced by the training script below)
    model = diffusers.UNet2DModel.from_pretrained("mnist_diffusion_ckpt/unet")
    model = model.to("cuda" if torch.cuda.is_available() else "cpu")
    noise_scheduler = get_noise_scheduler()
    # Optionally, test model input/output shape
    sample_image = mnist_dataset[0]["images"].unsqueeze(0).to(model.device)
    print("Input shape:", sample_image.shape)
    print("Output shape:", model(sample_image, timestep=0)["sample"].shape)
    # Sample and save images
    for seed in [2, 5, 42, 1991, 2022]:
        test_image = sample(model, noise_scheduler, seed)
        test_image.resize((256, 256)).save(f"mnist_diffusion_sample_seed{seed}.png")
        print(f"Saved sample for seed {seed}.")

if __name__ == "__main__":
    main()
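For intuition about the objective the training loop optimizes, here is a short standalone sketch (an addition, not part of the original gist) that visualizes the forward process: one MNIST digit corrupted by scheduler.add_noise at increasing timesteps, using the same preprocessing and the same 200-step DDPMScheduler as the scripts.

import torch
import torchvision
import datasets
import diffusers
import matplotlib.pyplot as plt

# Visualize the forward (noising) process on a single digit.
scheduler = diffusers.DDPMScheduler(num_train_timesteps=200)
digit = datasets.load_dataset('mnist', split='train')[0]['image']
preprocess = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Lambda(lambda x: 2 * (x - 0.5)),  # [0, 1] -> [-1, 1]
])
x0 = preprocess(digit).unsqueeze(0)  # shape (1, 1, 32, 32)
noise = torch.randn_like(x0)
fig, axes = plt.subplots(1, 5, figsize=(10, 2))
for ax, t in zip(axes, [0, 49, 99, 149, 199]):
    noisy = scheduler.add_noise(x0, noise, torch.tensor([t]))
    ax.imshow(noisy[0, 0].clamp(-1, 1) * 0.5 + 0.5, cmap='gray')
    ax.set_title(f"t={t}")
    ax.axis('off')
fig.savefig("forward_process.png")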
# ============================================================
# Training script: trains the UNet from scratch and saves the
# checkpoint that the inference script above loads.
# ============================================================
import torch
import torchvision
import datasets
import diffusers
import accelerate
from tqdm.auto import tqdm
import os
from dataclasses import dataclass
@dataclass
class TrainingConfig:
    image_size = 32
    train_batch_size = 64
    eval_batch_size = 32
    num_epochs = 1
    gradient_accumulation_steps = 1
    learning_rate = 1e-4
    lr_warmup_steps = 500
    mixed_precision = 'fp16'
    seed = 0

config = TrainingConfig()
def transform(dataset):
    preprocess = torchvision.transforms.Compose([
        torchvision.transforms.Resize((config.image_size, config.image_size)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Lambda(lambda x: 2 * (x - 0.5)),
    ])
    images = [preprocess(image) for image in dataset["image"]]
    return {"images": images}
def get_dataloader():
    mnist_dataset = datasets.load_dataset('mnist', split='train')
    mnist_dataset.reset_format()
    mnist_dataset.set_transform(transform)
    return torch.utils.data.DataLoader(
        mnist_dataset,
        batch_size=config.train_batch_size,
        shuffle=True,
    ), mnist_dataset
def get_model():
    return diffusers.UNet2DModel(
        sample_size=config.image_size,
        in_channels=1,
        out_channels=1,
        layers_per_block=2,
        block_out_channels=(128, 128, 256, 512),
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )
def get_noise_scheduler():
    # tensor_format='pt' was removed from recent diffusers releases;
    # the scheduler returns torch tensors by default.
    return diffusers.DDPMScheduler(num_train_timesteps=200)
def get_optimizer(model):
    return torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
def get_lr_scheduler(optimizer, train_dataloader):
    return diffusers.optimization.get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=config.lr_warmup_steps,
        num_training_steps=(len(train_dataloader) * config.num_epochs),
    )
def train_loop(
    config,
    model,
    noise_scheduler,
    optimizer,
    train_dataloader,
    lr_scheduler,
    max_steps=None,
):
    accelerator = accelerate.Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
    )
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )
    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader),
                            disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")
        for step, batch in enumerate(train_dataloader):
            if max_steps is not None and step >= max_steps:
                break
            clean_images = batch['images']
            noise = torch.randn(clean_images.shape).to(clean_images.device)
            batch_size = clean_images.shape[0]
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (batch_size,),
                device=clean_images.device)
            # Recent diffusers moves the scheduler's tensors to the input device
            # inside add_noise, so the original .cpu() workaround (and the
            # matching alphas_cumprod.to("cuda") hack in main) is unnecessary.
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
            with accelerator.accumulate(model):
                noise_pred = model(noisy_images, timesteps)["sample"]
                loss = torch.nn.functional.mse_loss(noise_pred, noise)
                accelerator.backward(loss)
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
            progress_bar.update(1)
            logs = {
                "loss": loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
            }
            progress_bar.set_postfix(**logs)
    # Return the raw model so the caller can save it
    return accelerator.unwrap_model(model)
def main():
    torch.manual_seed(config.seed)
    train_dataloader, mnist_dataset = get_dataloader()
    model = get_model()
    noise_scheduler = get_noise_scheduler()
    optimizer = get_optimizer(model)
    lr_scheduler = get_lr_scheduler(optimizer, train_dataloader)
    # Optionally, test model input/output shape
    sample_image = mnist_dataset[0]["images"].unsqueeze(0)
    print("Input shape:", sample_image.shape)
    print("Output shape:", model(sample_image, timestep=0)["sample"].shape)
    # Train (set max_steps to a small number for a quick smoke test)
    model = train_loop(config, model, noise_scheduler, optimizer,
                       train_dataloader, lr_scheduler, max_steps=None)
    # Save model and scheduler so the inference script can reload them
    os.makedirs("mnist_diffusion_ckpt", exist_ok=True)
    model.save_pretrained("mnist_diffusion_ckpt/unet")
    noise_scheduler.save_pretrained("mnist_diffusion_ckpt/scheduler")
    print("Model and scheduler saved to mnist_diffusion_ckpt/")

if __name__ == "__main__":
    main()
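After training, the checkpoint can be reloaded the same way the inference script at the top does (a minimal sketch, assuming the default paths above):

import diffusers
unet = diffusers.UNet2DModel.from_pretrained("mnist_diffusion_ckpt/unet")
scheduler = diffusers.DDPMScheduler.from_pretrained("mnist_diffusion_ckpt/scheduler")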