@vukrosic
Last active August 5, 2025 18:43
Small text-to-video diffusion skeleton (generates one small video and overfits on it); build from here.
"""
Memory-Efficient Wan Text-to-Video Model - Single File Implementation
Generates synthetic data, trains on it, and demonstrates inference
Optimized for Google Colab with minimal GPU memory usage
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Tuple
import numpy as np
from einops import rearrange
import torch.optim as optim
from torch.amp import autocast, GradScaler
import gc
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML, display
import os
from PIL import Image
# ============================================================================
# Core Utilities and Helpers
# ============================================================================
def sinusoidal_embedding_1d(dim, position):
    assert dim % 2 == 0
    half = dim // 2
    position = position.float()  # Use float32 instead of float64
    sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half, device=position.device, dtype=position.dtype).div(half)))
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    return x.float()  # Ensure output is float32
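
# Usage sketch (for illustration only; not executed here) - the diffusion model below calls this with dim=64:
#   t = torch.tensor([10.0])                 # batch of timesteps, shape (B,)
#   emb = sinusoidal_embedding_1d(64, t)     # -> (B, 64) float32 embedding (cos half, sin half)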

@torch.amp.autocast('cuda', enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
    assert dim % 2 == 0
    freqs = torch.outer(
        torch.arange(max_seq_len),
        1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim))
    )
    freqs = torch.polar(torch.ones_like(freqs), freqs)
    return freqs

@torch.amp.autocast('cuda', enabled=False)
def rope_apply(x, grid_sizes, freqs):
    n, c = x.size(2), x.size(3) // 2
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    output = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w
        x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2))
        freqs_i = torch.cat([
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ], dim=-1).reshape(seq_len, 1, -1)
        x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
        x_i = torch.cat([x_i, x[i, seq_len:]])
        output.append(x_i)
    return torch.stack(output).float()
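
# Usage sketch (for illustration only; these RoPE helpers are kept for extending the skeleton
# and are not called by the tiny model below):
#   freqs = rope_params(max_seq_len=1024, dim=64)    # -> complex tensor of shape (1024, 32)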

def flash_attention_fallback(q, k, v, k_lens=None):
    """Fallback attention implementation when flash attention is not available"""
    if k_lens is not None:
        # Create attention mask based on sequence lengths
        b, lq, lk = q.size(0), q.size(1), k.size(1)
        mask = torch.arange(lk, device=k.device).unsqueeze(0) < k_lens.unsqueeze(1)
        mask = mask.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, Lk]
    else:
        mask = None
    q = q.transpose(1, 2)  # [B, H, Lq, D]
    k = k.transpose(1, 2)  # [B, H, Lk, D]
    v = v.transpose(1, 2)  # [B, H, Lk, D]
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    return out.transpose(1, 2).contiguous()  # [B, Lq, H, D]
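
# Usage sketch (for illustration only; not executed here) - expected layouts for the fallback attention:
#   q = torch.randn(2, 10, 8, 64)            # (B, Lq, num_heads, head_dim)
#   k = torch.randn(2, 12, 8, 64)            # (B, Lk, num_heads, head_dim)
#   v = torch.randn(2, 12, 8, 64)
#   k_lens = torch.tensor([12, 7])           # valid key length per batch element
#   out = flash_attention_fallback(q, k, v, k_lens)  # -> (2, 10, 8, 64)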
# ============================================================================
# Memory-Efficient VAE Components (Simplified)
# ============================================================================
class SimpleConv3d(nn.Module):
    """Memory-efficient 3D convolution using 2D spatial conv + 1D temporal conv"""
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.spatial_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.temporal_conv = nn.Conv1d(out_channels, out_channels, 3, padding=1)

    def forward(self, x):
        b, c, t, h, w = x.shape
        # Spatial convolution
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.spatial_conv(x)
        _, c_new, h_new, w_new = x.shape
        x = rearrange(x, '(b t) c h w -> b c t h w', b=b, t=t)
        # Temporal convolution - use the post-convolution spatial size (h_new, w_new)
        x = rearrange(x, 'b c t h w -> (b h w) c t')
        x = self.temporal_conv(x)
        x = rearrange(x, '(b h w) c t -> b c t h w', b=b, h=h_new, w=w_new)
        return x
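
# Shape sketch (for illustration only; not executed here):
#   conv = SimpleConv3d(3, 16)               # kernel 3, stride 1, padding 1
#   y = conv(torch.randn(1, 3, 8, 32, 32))   # (B, C, T, H, W) -> (1, 16, 8, 32, 32)
#   Note: the stride argument only affects the spatial (H, W) convolution, not time.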

class TinyVAE(nn.Module):
    """Tiny VAE perfectly sized for single circle task"""
    def __init__(self, z_dim=4):  # Even smaller latent space
        super().__init__()
        self.z_dim = z_dim
        # Encoder: 3 -> 16 -> z_dim (minimal for single circle)
        self.encoder = nn.Sequential(
            SimpleConv3d(3, 16, 3, 1, 1), nn.ReLU(),
            SimpleConv3d(16, z_dim, 3, 2, 1)  # Downsample to 16x16
        )
        # Decoder: z_dim -> 16 -> 3
        self.decoder = nn.Sequential(
            SimpleConv3d(z_dim, 16, 3, 1, 1), nn.ReLU(),
            nn.Upsample(scale_factor=(1, 2, 2), mode='nearest'),  # Back to 32x32
            SimpleConv3d(16, 3, 3, 1, 1), nn.Tanh()
        )

    def encode(self, x):
        return self.encoder(x)

    def decode(self, z):
        return self.decoder(z)
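
# Shape sketch (for illustration only; not executed here):
#   vae = TinyVAE(z_dim=4)
#   x = torch.randn(1, 3, 8, 32, 32)         # (B, C, T, H, W) video in [-1, 1]
#   z = vae.encode(x)                        # -> (1, 4, 8, 16, 16): 2x spatial downsample only
#   x_hat = vae.decode(z)                    # -> (1, 3, 8, 32, 32), tanh output in [-1, 1]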
# ============================================================================
# Lightweight Text Encoder (Memory Efficient)
# ============================================================================
class TinyTextEncoder(nn.Module):
    """Tiny text encoder for single prompt overfitting"""
    def __init__(self, vocab_size=20, dim=64, num_layers=2):  # Much smaller
        super().__init__()
        self.dim = dim
        self.embedding = nn.Embedding(vocab_size, dim)
        # Just 2 simple layers for single prompt
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=dim,
                nhead=4,  # Fewer heads
                dim_feedforward=dim * 2,
                batch_first=True,
                norm_first=True
            ) for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = self.embedding(x)
        for block in self.blocks:
            x = block(x)
        return self.norm(x)
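
# Shape sketch (for illustration only; not executed here):
#   enc = TinyTextEncoder(vocab_size=20, dim=64)
#   tokens = torch.zeros(1, 16, dtype=torch.long)   # (B, L) token ids, 0 = padding/unknown
#   emb = enc(tokens)                               # -> (1, 16, 64) per-token embeddings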
# ============================================================================
# Lightweight Diffusion Model (Memory Efficient)
# ============================================================================
class TinyDiffusionModel(nn.Module):
    """Tiny diffusion model with better spatial awareness"""
    def __init__(self, in_dim=4, dim=96, text_dim=64, num_layers=4):  # More layers for better learning
        super().__init__()
        self.dim = dim
        # Input projection with spatial awareness
        self.input_proj = nn.Linear(in_dim, dim)
        # Positional encoding for spatial awareness
        self.pos_embedding = nn.Parameter(torch.randn(1, 8 * 16 * 16, dim) * 0.02)  # 8 frames, 16x16 spatial
        # Time embedding - more capacity
        self.time_embedding = nn.Sequential(
            nn.Linear(64, dim), nn.SiLU(),
            nn.Linear(dim, dim), nn.SiLU(),
            nn.Linear(dim, dim)
        )
        # Text embedding
        self.text_proj = nn.Linear(text_dim, dim)
        # More transformer blocks for better spatial-temporal learning
        self.blocks = nn.ModuleList([
            nn.TransformerDecoderLayer(
                d_model=dim,
                nhead=8,  # 96 is divisible by 8
                dim_feedforward=dim * 2,  # Smaller feedforward to prevent overfitting to "all red"
                batch_first=True,
                norm_first=True,
                dropout=0.1  # Add dropout to prevent overfitting
            ) for _ in range(num_layers)
        ])
        # Output projection head (two-layer MLP)
        self.output_proj = nn.Sequential(
            nn.Linear(dim, dim // 2), nn.SiLU(),
            nn.Dropout(0.1),
            nn.Linear(dim // 2, in_dim)
        )

    def forward(self, x, t, context):
        # Flatten spatial dimensions
        b, c, f, h, w = x.shape
        x_flat = x.permute(0, 2, 3, 4, 1).reshape(b, f * h * w, c)  # [B, L, C]
        # Project input
        x_emb = self.input_proj(x_flat)
        # Add positional encoding for spatial awareness
        x_emb = x_emb + self.pos_embedding[:, :x_emb.size(1), :]
        # Time embedding
        t_emb = self.time_embedding(sinusoidal_embedding_1d(64, t.float()).float())
        t_emb = t_emb.unsqueeze(1).expand(-1, x_emb.size(1), -1)
        x_emb = x_emb + t_emb
        # Text context
        context = self.text_proj(context)
        # Transformer blocks
        for block in self.blocks:
            x_emb = block(x_emb, context)
        # Output projection
        noise_pred = self.output_proj(x_emb)
        # Reshape back
        noise_pred = noise_pred.reshape(b, f, h, w, c).permute(0, 4, 1, 2, 3)
        return noise_pred
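
# Shape sketch of the denoiser (for illustration only; not executed here):
#   net = TinyDiffusionModel(in_dim=4, dim=96, text_dim=64)
#   latents = torch.randn(1, 4, 8, 16, 16)   # noisy VAE latents (B, C, T, H, W)
#   t = torch.tensor([10.0])                 # diffusion timestep(s)
#   ctx = torch.randn(1, 16, 64)             # text embeddings from TinyTextEncoder
#   noise_pred = net(latents, t, ctx)        # -> (1, 4, 8, 16, 16) predicted noise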
# ============================================================================
# Simple Scheduler and Data Generation
# ============================================================================
class SimpleScheduler:
    def __init__(self, num_timesteps=50):
        self.num_timesteps = num_timesteps

    def add_noise(self, x, noise, t):
        """Blend clean data with noise using a cosine schedule"""
        alpha = math.cos((float(t) / self.num_timesteps) * math.pi / 2) ** 2
        return alpha * x + (1 - alpha) * noise

    def step(self, model_output, t, sample):
        """Simple denoising step using the same cosine schedule"""
        alpha = math.cos((float(t) / self.num_timesteps) * math.pi / 2) ** 2
        return sample - (1 - alpha) * model_output
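
# Worked example of the cosine schedule with num_timesteps=50 (values rounded):
#   alpha(0)  = cos(0)^2            = 1.000  -> add_noise returns the clean latents
#   alpha(25) = cos(pi/4)^2         = 0.500  -> equal mix of signal and noise
#   alpha(49) = cos(0.98 * pi/2)^2  ≈ 0.001  -> nearly pure noise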

def create_single_training_video(frames=8, height=32, width=32):
    """Create ONE video with a clear red circle moving on black background"""
    video = torch.zeros(3, frames, height, width)
    # Create a clear red circle moving across the frame
    for f in range(frames):
        y, x = torch.meshgrid(torch.arange(height), torch.arange(width), indexing='ij')
        # Moving red circle - smaller and more defined
        center_x = 8 + (width - 16) * f / (frames - 1)  # Move from left to right
        center_y = height // 2
        # Clear red circle - much smaller and well-defined
        dist = torch.sqrt((x - center_x) ** 2 + (y - center_y) ** 2)
        red_circle = (dist < 6).float()  # Smaller, clearer circle
        # Set channels - clear contrast
        video[0, f] = red_circle                    # Red channel - only the circle
        video[1, f] = torch.zeros_like(red_circle)  # No green
        video[2, f] = torch.zeros_like(red_circle)  # No blue
        # Background stays black (zeros)
    # Normalize to [-1, 1] for training
    video = video * 2 - 1
    return video, "red circle moving"
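
# Quick check of the synthetic data (for illustration only; not executed here):
#   video, prompt = create_single_training_video()
#   video.shape   -> torch.Size([3, 8, 32, 32]); every value is -1 or +1 after normalization
#   prompt        -> "red circle moving"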

def save_video_as_gif_simple(video_tensor, filename, fps=2):
    """Simple and reliable video saving using PIL"""
    try:
        # Convert tensor to numpy
        video = video_tensor.detach().cpu().numpy().astype(np.float32)
        # Normalize to [0, 1]
        if video.min() < 0:
            video = (video + 1) / 2
        video = np.clip(video, 0, 1)
        # Convert to uint8 and rearrange dimensions
        video = (video * 255).astype(np.uint8)
        video = np.transpose(video, (1, 2, 3, 0))  # (T, H, W, C)
        # Create PIL images
        frames = []
        for i in range(video.shape[0]):
            frame = video[i]
            if frame.shape[2] == 3:  # RGB
                pil_frame = Image.fromarray(frame, 'RGB')
            else:  # Grayscale
                pil_frame = Image.fromarray(frame[:, :, 0], 'L')
            frames.append(pil_frame)
        # Save as GIF
        if len(frames) > 0:
            frames[0].save(
                filename,
                save_all=True,
                append_images=frames[1:],
                duration=int(1000 / fps),
                loop=0
            )
            print(f"✅ Video saved as: {filename}")
        else:
            print(f"❌ No frames to save for {filename}")
    except Exception as e:
        print(f"❌ Error saving {filename}: {str(e)}")
        # Save first frame as PNG
        try:
            video = video_tensor.detach().cpu().numpy().astype(np.float32)
            if video.min() < 0:
                video = (video + 1) / 2
            video = np.clip(video, 0, 1)
            first_frame = np.transpose(video[:, 0], (1, 2, 0))
            first_frame = (first_frame * 255).astype(np.uint8)
            pil_image = Image.fromarray(first_frame, 'RGB')
            static_filename = filename.replace('.gif', '_frame0.png')
            pil_image.save(static_filename)
            print(f"✅ Saved first frame as: {static_filename}")
        except Exception as e2:
            print(f"❌ Could not save any version: {str(e2)}")
    return filename

def save_video_as_gif(video_tensor, filename, fps=2):
    """Save video tensor as animated GIF"""
    try:
        # Convert from tensor to numpy and normalize to [0, 1]
        video = video_tensor.detach().cpu().numpy().astype(np.float32)
        # Handle different input ranges
        if video.min() < 0:  # Assume [-1, 1] range
            video = np.clip((video + 1) / 2, 0, 1)
        else:  # Assume [0, 1] range
            video = np.clip(video, 0, 1)
        # Rearrange from (C, T, H, W) to (T, H, W, C)
        video = np.transpose(video, (1, 2, 3, 0))
        # Convert to uint8 for proper image handling
        video = (video * 255).astype(np.uint8)
        # Create and save frames as images
        frames = []
        for i in range(video.shape[0]):
            fig, ax = plt.subplots(figsize=(4, 4))
            ax.imshow(video[i])
            ax.axis('off')
            ax.set_title(f'Frame {i + 1}/{len(video)}')
            # Convert plot to image
            fig.canvas.draw()
            buf = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
            buf = buf.reshape(fig.canvas.get_width_height()[::-1] + (3,))
            frames.append(buf)
            plt.close(fig)
        # Save as GIF using matplotlib animation
        if len(frames) > 0:
            fig, ax = plt.subplots(figsize=(4, 4))
            ax.axis('off')
            im = ax.imshow(frames[0])

            def animate(frame_idx):
                im.set_array(frames[frame_idx])
                ax.set_title(f'Frame {frame_idx + 1}/{len(frames)}')
                return [im]

            anim = animation.FuncAnimation(fig, animate, frames=len(frames),
                                           interval=1000 // fps, repeat=True, blit=False)
            # Save with pillow writer
            anim.save(filename, writer='pillow', fps=fps)
            plt.close(fig)
            print(f"✅ Video saved as: {filename}")
        else:
            print(f"❌ No frames to save for {filename}")
    except Exception as e:
        print(f"❌ Error saving {filename}: {str(e)}")
        # Fallback: save first frame as static image
        try:
            video = video_tensor.detach().cpu().numpy().astype(np.float32)
            if video.min() < 0:
                video = np.clip((video + 1) / 2, 0, 1)
            else:
                video = np.clip(video, 0, 1)
            first_frame = np.transpose(video[:, 0], (1, 2, 0))
            plt.figure(figsize=(4, 4))
            plt.imshow(first_frame)
            plt.axis('off')
            plt.title('First Frame (Static)')
            static_filename = filename.replace('.gif', '_static.png')
            plt.savefig(static_filename, bbox_inches='tight', dpi=150)
            plt.close()
            print(f"✅ Saved static image as: {static_filename}")
        except Exception:
            print(f"❌ Could not save any version of {filename}")
    return filename

def display_video_grid(videos, titles, save_dir="generated_videos"):
    """Display multiple videos in a grid and save them"""
    os.makedirs(save_dir, exist_ok=True)
    n_videos = len(videos)
    saved_files = []
    # Save individual videos first (try simple method first)
    for i, (video, title) in enumerate(zip(videos, titles)):
        filename = os.path.join(save_dir, f"{title.replace(' ', '_')}.gif")
        try:
            save_video_as_gif_simple(video, filename)
        except Exception:
            print(f"Simple method failed, trying alternative for {title}")
            save_video_as_gif(video, filename)
        saved_files.append(filename)
    # Create grid of first frames
    try:
        fig, axes = plt.subplots(1, n_videos, figsize=(4 * n_videos, 4))
        if n_videos == 1:
            axes = [axes]
        for i, (video, title) in enumerate(zip(videos, titles)):
            try:
                # Show first frame in grid
                video_np = video.detach().cpu().numpy().astype(np.float32)
                # Handle different input ranges
                if video_np.min() < 0:
                    video_np = np.clip((video_np + 1) / 2, 0, 1)
                else:
                    video_np = np.clip(video_np, 0, 1)
                first_frame = np.transpose(video_np[:, 0], (1, 2, 0))
                axes[i].imshow(first_frame)
                axes[i].set_title(title)
                axes[i].axis('off')
            except Exception as e:
                print(f"❌ Error displaying video {i}: {str(e)}")
                axes[i].text(0.5, 0.5, f'Error\n{title}', ha='center', va='center', transform=axes[i].transAxes)
                axes[i].axis('off')
        plt.tight_layout()
        grid_filename = os.path.join(save_dir, "video_grid.png")
        plt.savefig(grid_filename, dpi=150, bbox_inches='tight')
        plt.show()
    except Exception as e:
        print(f"❌ Error creating video grid: {str(e)}")
    return saved_files
# ============================================================================
# Complete Training and Inference Pipeline
# ============================================================================
class TinyWanT2V:
    def __init__(self, device='cuda'):
        self.device = device
        # Initialize tiny components perfectly sized for single circle
        self.vae = TinyVAE(z_dim=4).to(device).float()
        self.text_encoder = TinyTextEncoder(vocab_size=20, dim=64).to(device).float()
        self.diffusion_model = TinyDiffusionModel(in_dim=4, dim=96, text_dim=64).to(device).float()
        self.scheduler = SimpleScheduler(num_timesteps=50)  # Fewer timesteps
        # Minimal vocabulary for single prompt
        self.vocab = {
            'red': 1, 'circle': 2, 'moving': 3, 'area': 4, 'blob': 5, 'color': 6, 'bright': 7
        }

    def tokenize(self, text, max_length=16):
        """Simple tokenization"""
        words = text.lower().split()
        tokens = [self.vocab.get(word, 0) for word in words]
        tokens = tokens[:max_length] + [0] * (max_length - len(tokens))
        return torch.tensor(tokens, device=self.device).unsqueeze(0)
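
    # Tokenization sketch (for illustration only; not executed here):
    #   tokenize("red circle moving") maps words through self.vocab and zero-pads,
    #   giving tensor([[1, 2, 3, 0, ..., 0]]) of shape (1, 16) on self.device.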

    def train(self, num_epochs=300):
        """Train on the SINGLE video - intensive overfitting for perfect memorization"""
        print("🚀 Starting INTENSIVE overfitting on single red circle video...")
        print(f"🔥 Training for {num_epochs} epochs for perfect memorization!")
        # Create THE training video
        train_video, train_prompt = create_single_training_video()
        train_video = train_video.unsqueeze(0).to(self.device).float()  # Add batch dim
        print(f"Training data shape: {train_video.shape}")
        print(f"Training prompt: '{train_prompt}'")
        print(f"📈 Video range: [{train_video.min():.3f}, {train_video.max():.3f}]")
        # Save original training video
        self.original_video = train_video.clone()
        # Set models to training mode
        self.vae.train()
        self.text_encoder.train()
        self.diffusion_model.train()
        # Optimizers with learning rates suited to intensive overfitting
        vae_optimizer = optim.Adam(self.vae.parameters(), lr=2e-3)  # Higher LR for intensive training
        text_optimizer = optim.Adam(self.text_encoder.parameters(), lr=2e-3)
        diffusion_optimizer = optim.Adam(self.diffusion_model.parameters(), lr=2e-3)
        # Learning rate schedulers for long training
        vae_scheduler = optim.lr_scheduler.CosineAnnealingLR(vae_optimizer, T_max=num_epochs)
        text_scheduler = optim.lr_scheduler.CosineAnnealingLR(text_optimizer, T_max=num_epochs)
        diff_scheduler = optim.lr_scheduler.CosineAnnealingLR(diffusion_optimizer, T_max=num_epochs)
        # Tokenize the single prompt
        text_tokens = self.tokenize(train_prompt)
        print("\n🎯 Starting INTENSIVE overfitting loop...")
        print("📈 Progress will be shown every 100 epochs...")
        print("🎬 Diffusion generations will be shown at 20%, 40%, 60%, 80%, and 100% completion...")
        best_total_loss = float('inf')
        best_vae_loss = float('inf')
        best_diff_loss = float('inf')
        # Create directory for training progress videos
        os.makedirs("training_progress", exist_ok=True)
        # Define checkpoints for generation (every 1/5th of training)
        generation_checkpoints = [num_epochs // 5 * i for i in range(1, 6)]  # 20%, 40%, 60%, 80%, 100%
        for epoch in range(num_epochs):
            # Clear gradients
            vae_optimizer.zero_grad()
            text_optimizer.zero_grad()
            diffusion_optimizer.zero_grad()
            # VAE reconstruction loss
            latents = self.vae.encode(train_video)
            reconstructed = self.vae.decode(latents)
            vae_loss = F.mse_loss(reconstructed, train_video)
            # Text encoding
            text_embeddings = self.text_encoder(text_tokens)
            # Diffusion training - sample multiple timesteps per epoch for a richer signal
            total_diff_loss = 0
            for _ in range(25):  # 25 diffusion steps per epoch
                t = torch.randint(0, self.scheduler.num_timesteps, (1,), device=self.device, dtype=torch.float32)
                noise = torch.randn_like(latents)
                noisy_latents = self.scheduler.add_noise(latents[0], noise[0], t[0]).unsqueeze(0)
                predicted_noise = self.diffusion_model(noisy_latents, t, text_embeddings)
                diff_loss = F.mse_loss(predicted_noise, noise)
                total_diff_loss += diff_loss
            diffusion_loss = total_diff_loss / 25  # Average over multiple steps
            # Combined loss with equal weighting
            total_loss = diffusion_loss + vae_loss
            # Backward pass
            total_loss.backward()
            # Gradient clipping for stability during long training
            torch.nn.utils.clip_grad_norm_(self.vae.parameters(), 0.5)
            torch.nn.utils.clip_grad_norm_(self.text_encoder.parameters(), 0.5)
            torch.nn.utils.clip_grad_norm_(self.diffusion_model.parameters(), 0.5)
            vae_optimizer.step()
            text_optimizer.step()
            diffusion_optimizer.step()
            # Update learning rates
            vae_scheduler.step()
            text_scheduler.step()
            diff_scheduler.step()
            # Track best losses
            if total_loss.item() < best_total_loss:
                best_total_loss = total_loss.item()
            if vae_loss.item() < best_vae_loss:
                best_vae_loss = vae_loss.item()
            if diffusion_loss.item() < best_diff_loss:
                best_diff_loss = diffusion_loss.item()
            # Progress reporting every 100 epochs
            if epoch % 100 == 0:
                current_lr = vae_optimizer.param_groups[0]['lr']
                print(f"Epoch {epoch:4d}/{num_epochs}, Total: {total_loss.item():.6f}, VAE: {vae_loss.item():.6f}, Diff: {diffusion_loss.item():.6f}, LR: {current_lr:.6f}")
                # Show reconstruction quality
                with torch.no_grad():
                    recon_error = F.mse_loss(reconstructed, train_video).item()
                    if recon_error < 0.001:
                        print(f" 🔥 VAE reconstruction PERFECT: {recon_error:.8f}")
                    elif recon_error < 0.01:
                        print(f" ✅ VAE reconstruction excellent: {recon_error:.6f}")
                    elif recon_error < 0.1:
                        print(f" 🟡 VAE reconstruction good: {recon_error:.6f}")
                    else:
                        print(f" 🔴 VAE reconstruction poor: {recon_error:.6f}")
            # Memory cleanup during long training
            if epoch % 500 == 0 and epoch > 0:
                torch.cuda.empty_cache()
                gc.collect()
                print(f" 🧹 Memory cleanup at epoch {epoch}")
            # Generate video at checkpoints to show training progress
            if (epoch + 1) in generation_checkpoints:
                checkpoint_idx = generation_checkpoints.index(epoch + 1) + 1
                progress_percent = checkpoint_idx * 20
                print(f"\n🎬 CHECKPOINT {checkpoint_idx}/5 ({progress_percent}% complete) - Generating video...")
                # Temporarily set to eval mode for generation
                self.vae.eval()
                self.text_encoder.eval()
                self.diffusion_model.eval()
                with torch.no_grad():
                    # Generate video with current model state
                    generated_video = self.generate(train_prompt, num_steps=20, show_progress=False)
                    # Save the generated video
                    filename = f"training_progress/checkpoint_{checkpoint_idx}_{progress_percent}percent.gif"
                    save_video_as_gif_simple(generated_video, filename)
                    # Calculate similarity to training data
                    similarity = 1 - F.mse_loss(generated_video, train_video[0].cpu()).item()
                    print(f" 📊 Similarity to training data: {similarity:.4f}")
                    print(f" 💾 Saved: {filename}")
                # Set back to training mode
                self.vae.train()
                self.text_encoder.train()
                self.diffusion_model.train()
                print("🔄 Resuming training...\n")
        print("✅ INTENSIVE overfitting completed!")
        print("🏆 Best losses achieved:")
        print(f" Total: {best_total_loss:.8f}")
        print(f" VAE: {best_vae_loss:.8f}")
        print(f" Diffusion: {best_diff_loss:.8f}")
        # Set to eval mode
        self.vae.eval()
        self.text_encoder.eval()
        self.diffusion_model.eval()
        # Store final reconstruction for comparison
        with torch.no_grad():
            self.final_latents = self.vae.encode(train_video)
            self.vae_reconstruction = self.vae.decode(self.final_latents)

    def show_training_results(self):
        """Show original data, VAE reconstruction, and diffusion generation"""
        print("\n📊 TRAINING RESULTS COMPARISON")
        print("=" * 50)
        with torch.no_grad():
            # 1. Original training data
            original = self.original_video[0].cpu()
            print(f"1️⃣ Original video range: [{original.min():.3f}, {original.max():.3f}]")
            # 2. VAE reconstruction
            vae_recon = self.vae_reconstruction[0].cpu()
            vae_error = F.mse_loss(vae_recon, original).item()
            print(f"2️⃣ VAE reconstruction range: [{vae_recon.min():.3f}, {vae_recon.max():.3f}]")
            print(f" VAE reconstruction error: {vae_error:.6f}")
            # 3. Diffusion generation
            generated = self.generate("red circle moving", num_steps=20, show_progress=False)
            print(f"3️⃣ Generated video range: [{generated.min():.3f}, {generated.max():.3f}]")
            # Save all three for comparison
            videos_to_save = [original, vae_recon, generated]
            titles = ["1_Original_Training_Data", "2_VAE_Reconstruction", "3_Diffusion_Generated"]
            print("\n💾 Saving comparison videos...")
            saved_files = display_video_grid(videos_to_save, titles, save_dir="training_results")
        return saved_files

    def generate(self, prompt, num_steps=20, frames=8, height=32, width=32, show_progress=True):
        """Generate video from text prompt"""
        if show_progress:
            print(f"🎬 Generating video for prompt: '{prompt}'")
        with torch.no_grad():
            # Tokenize prompt
            text_tokens = self.tokenize(prompt)
            text_embeddings = self.text_encoder(text_tokens)
            # Start with random noise in latent space (half the spatial resolution)
            latents = torch.randn(1, 4, frames, height // 2, width // 2, device=self.device, dtype=torch.float32)
            # Simple denoising loop
            for step in range(num_steps):
                t = torch.tensor([num_steps - step - 1], device=self.device, dtype=torch.float32)
                # Predict noise
                predicted_noise = self.diffusion_model(latents, t, text_embeddings)
                # Simple denoising step
                latents = self.scheduler.step(predicted_noise, t.item(), latents)
                if show_progress and step % 5 == 0:
                    print(f" Denoising step {step}/{num_steps}")
            # Decode to video
            video = self.vae.decode(latents)
        if show_progress:
            print("✅ Video generation completed!")
        return video.squeeze(0).cpu()
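
# Rough end-to-end usage sketch (mirrors main() below; not executed here):
#   model = TinyWanT2V(device='cuda' if torch.cuda.is_available() else 'cpu')
#   model.train(num_epochs=300)                      # overfit on the single synthetic video
#   video = model.generate("red circle moving")      # -> (3, 8, 32, 32) tensor in [-1, 1]
#   save_video_as_gif_simple(video, "sample.gif")    # "sample.gif" is an arbitrary example path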
# ============================================================================
# Main Execution (Google Colab Ready)
# ============================================================================
def main():
    """Main function - overfit on the single circle video and show all steps"""
    print("🎯 Initializing Tiny Wan Text-to-Video Model for Single Circle Overfitting")
    print("=" * 70)
    # Check device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")
    if device == 'cuda':
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    # Show model sizes
    print("\nModel Architecture:")
    model = TinyWanT2V(device=device)
    vae_params = sum(p.numel() for p in model.vae.parameters())
    text_params = sum(p.numel() for p in model.text_encoder.parameters())
    diff_params = sum(p.numel() for p in model.diffusion_model.parameters())
    total_params = vae_params + text_params + diff_params
    print(f" VAE parameters: {vae_params:,}")
    print(f" Text Encoder parameters: {text_params:,}")
    print(f" Diffusion Model parameters: {diff_params:,}")
    print(f" Total parameters: {total_params:,}")
    # Show training data
    print("\n📊 STEP 1: TRAINING DATA")
    print("-" * 30)
    train_video, train_prompt = create_single_training_video()
    print(f"Video shape: {train_video.shape}")
    print(f"Video range: [{train_video.min():.3f}, {train_video.max():.3f}]")
    print(f"Prompt: '{train_prompt}'")
    # Save training data
    os.makedirs("training_results", exist_ok=True)
    save_video_as_gif_simple(train_video, "training_results/0_training_data.gif")
    print("✅ Training data saved as: training_results/0_training_data.gif")
    # Train the model (overfit on the single video)
    print("\n🎯 STEP 2: INTENSIVE OVERFITTING TRAINING")
    print("-" * 30)
    model.train(num_epochs=300)  # Intensive overfitting on the single video
    # Show training results
    print("\n📈 STEP 3: RESULTS COMPARISON")
    print("-" * 30)
    saved_files = model.show_training_results()
    # Test generation with the training prompt and variations of it
    print("\nSTEP 4: GENERATION TESTS")
    print("-" * 30)
    test_prompts = [
        "red circle moving",   # Same as training
        "moving red circle",   # Slightly different word order
        "red circle",          # Shorter
    ]
    for i, prompt in enumerate(test_prompts):
        print(f"\nTest {i + 1}: '{prompt}'")
        video = model.generate(prompt, num_steps=20, show_progress=False)
        # Calculate similarity to training data
        similarity = 1 - F.mse_loss(video, model.original_video[0].cpu()).item()
        print(f" Similarity to training: {similarity:.4f}")
        print(f" Video std: {video.std().item():.4f}")
        # Save
        os.makedirs("training_results", exist_ok=True)
        filename = f"training_results/test_{i + 1}_{prompt.replace(' ', '_')}.gif"
        save_video_as_gif_simple(video, filename)
        print(f" Saved: {filename}")
print("\nπŸŽ‰ COMPLETE PIPELINE DEMONSTRATION FINISHED!")
print("="*70)
print("πŸ“‚ Check 'training_results/' folder for all videos:")
print(" - 0_training_data.gif (original)")
print(" - 1_Original_Training_Data.gif (same as above)")
print(" - 2_VAE_Reconstruction.gif (autoencoder output)")
print(" - 3_Diffusion_Generated.gif (full pipeline)")
print(" - test_*.gif (generation tests)")
print("\nοΏ½ Check 'twraining_progress/' folder for diffusion training evolution:")
print(" - checkpoint_1_20percent.gif (20% trained)")
print(" - checkpoint_2_40percent.gif (40% trained)")
print(" - checkpoint_3_60percent.gif (60% trained)")
print(" - checkpoint_4_80percent.gif (80% trained)")
print(" - checkpoint_5_100percent.gif (100% trained)")
print("\nπŸ’‘ This shows: Data β†’ VAE β†’ Diffusion β†’ Generation pipeline")
print("🎯 Perfect overfitting means VAE and Diffusion should reproduce the training circle!")
print("πŸ“ˆ Training progress videos show how diffusion model learns over time!")
if __name__ == "__main__":
main()