Small text-to-video diffusion skeleton (generates one small video and overfits on it); build it out from here.
| """ | |
| Memory-Efficient Wan Text-to-Video Model - Single File Implementation | |
| Generates synthetic data, trains on it, and demonstrates inference | |
| Optimized for Google Colab with minimal GPU memory usage | |
| """ | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from typing import List, Optional, Tuple | |
| import numpy as np | |
| from einops import rearrange | |
| import torch.optim as optim | |
| from torch.amp import autocast, GradScaler | |
| import gc | |
| import matplotlib.pyplot as plt | |
| import matplotlib.animation as animation | |
| from IPython.display import HTML, display | |
| import os | |
| from PIL import Image | |
| # ============================================================================ | |
| # Core Utilities and Helpers | |
| # ============================================================================ | |
| def sinusoidal_embedding_1d(dim, position): | |
| assert dim % 2 == 0 | |
| half = dim // 2 | |
| position = position.float() # Use float32 instead of float64 | |
| sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half, device=position.device, dtype=position.dtype).div(half))) | |
| x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) | |
| return x.float() # Ensure output is float32 | |
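| # A minimal, uncalled shape check for the embedding helper above (illustrative only; the | |
| # helper name and example sizes are arbitrary): an (N,) position tensor maps to (N, dim), | |
| # with the first dim // 2 columns holding cosines and the rest sines. | |
| def _sketch_sinusoidal_embedding_shapes(): | |
|     emb = sinusoidal_embedding_1d(64, torch.arange(4))  # positions 0..3 | |
|     assert emb.shape == (4, 64) | |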
| @torch.amp.autocast('cuda', enabled=False) | |
| def rope_params(max_seq_len, dim, theta=10000): | |
| assert dim % 2 == 0 | |
| freqs = torch.outer( | |
| torch.arange(max_seq_len), | |
| 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)) | |
| ) | |
| freqs = torch.polar(torch.ones_like(freqs), freqs) | |
| return freqs | |
| @torch.amp.autocast('cuda', enabled=False) | |
| def rope_apply(x, grid_sizes, freqs): | |
| n, c = x.size(2), x.size(3) // 2 | |
| freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) | |
| output = [] | |
| for i, (f, h, w) in enumerate(grid_sizes.tolist()): | |
| seq_len = f * h * w | |
| x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2)) | |
| freqs_i = torch.cat([ | |
| freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), | |
| freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), | |
| freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) | |
| ], dim=-1).reshape(seq_len, 1, -1) | |
| x_i = torch.view_as_real(x_i * freqs_i).flatten(2) | |
| x_i = torch.cat([x_i, x[i, seq_len:]]) | |
| output.append(x_i) | |
| return torch.stack(output).float() | |
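| # The two RoPE helpers above follow the Wan-style 3D factorization (frame, height, width) | |
| # but are not called by the tiny model below. A minimal, uncalled usage sketch, assuming a | |
| # head dim of 24 split 8/8/8 across the three axes (sizes here are arbitrary examples): | |
| def _sketch_rope_usage(): | |
|     q = torch.randn(1, 2 * 4 * 4, 2, 24)    # [B, L = f*h*w, num_heads, head_dim] | |
|     grid_sizes = torch.tensor([[2, 4, 4]])  # one sample: 2 frames of 4x4 tokens | |
|     freqs = rope_params(1024, 24)           # complex rotation table, max_seq_len=1024 | |
|     q_rot = rope_apply(q, grid_sizes, freqs) | |
|     assert q_rot.shape == q.shape           # rotation preserves the layout | |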
| def flash_attention_fallback(q, k, v, k_lens=None): | |
| """Fallback attention implementation when flash attention is not available""" | |
| if k_lens is not None: | |
| # Create attention mask based on sequence lengths | |
| b, lq, lk = q.size(0), q.size(1), k.size(1) | |
| mask = torch.arange(lk, device=k.device).unsqueeze(0) < k_lens.unsqueeze(1) | |
| mask = mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, Lk] | |
| else: | |
| mask = None | |
| q = q.transpose(1, 2) # [B, H, Lq, D] | |
| k = k.transpose(1, 2) # [B, H, Lk, D] | |
| v = v.transpose(1, 2) # [B, H, Lk, D] | |
| out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) | |
| return out.transpose(1, 2).contiguous() # [B, Lq, H, D] | |
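| # A minimal, uncalled shape check for the SDPA fallback above (example sizes are arbitrary): | |
| # inputs and output use [B, L, num_heads, head_dim] layout, with optional per-sample key | |
| # lengths turned into a boolean padding mask. | |
| def _sketch_attention_fallback(): | |
|     q = torch.randn(2, 10, 4, 16) | |
|     k = torch.randn(2, 12, 4, 16) | |
|     v = torch.randn(2, 12, 4, 16) | |
|     out = flash_attention_fallback(q, k, v, k_lens=torch.tensor([12, 7])) | |
|     assert out.shape == (2, 10, 4, 16) | |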
| # ============================================================================ | |
| # Memory-Efficient VAE Components (Simplified) | |
| # ============================================================================ | |
| class SimpleConv3d(nn.Module): | |
| """Memory-efficient 3D convolution using 2D conv + 1D temporal""" | |
| def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1): | |
| super().__init__() | |
| self.spatial_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) | |
| self.temporal_conv = nn.Conv1d(out_channels, out_channels, 3, padding=1) | |
| def forward(self, x): | |
| b, c, t, h, w = x.shape | |
| # Spatial convolution | |
| x = rearrange(x, 'b c t h w -> (b t) c h w') | |
| x = self.spatial_conv(x) | |
| _, c_new, h_new, w_new = x.shape | |
| x = rearrange(x, '(b t) c h w -> b c t h w', b=b, t=t) | |
| # Temporal convolution - fix the shape calculation | |
| x = rearrange(x, 'b c t h w -> (b h w) c t') | |
| x = self.temporal_conv(x) | |
| x = rearrange(x, '(b h w) c t -> b c t h w', b=b, h=h_new, w=w_new) | |
| return x | |
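| # A minimal, uncalled shape check for the factorized conv above (example sizes are arbitrary): | |
| # the stride only reaches the spatial conv, so H and W are downsampled while T is preserved. | |
| def _sketch_simpleconv3d_shapes(): | |
|     conv = SimpleConv3d(3, 16, kernel_size=3, stride=2, padding=1) | |
|     y = conv(torch.randn(1, 3, 8, 32, 32))  # (B, C, T, H, W) | |
|     assert y.shape == (1, 16, 8, 16, 16)    # H, W halved; T unchanged | |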
| class TinyVAE(nn.Module): | |
| """Tiny VAE perfectly sized for single circle task""" | |
| def __init__(self, z_dim=4): # Even smaller latent space | |
| super().__init__() | |
| self.z_dim = z_dim | |
| # Encoder: 3 -> 16 -> z_dim (minimal for single circle) | |
| self.encoder = nn.Sequential( | |
| SimpleConv3d(3, 16, 3, 1, 1), nn.ReLU(), | |
| SimpleConv3d(16, z_dim, 3, 2, 1) # Downsample to 16x16 | |
| ) | |
| # Decoder: z_dim -> 16 -> 3 | |
| self.decoder = nn.Sequential( | |
| SimpleConv3d(z_dim, 16, 3, 1, 1), nn.ReLU(), | |
| nn.Upsample(scale_factor=(1, 2, 2), mode='nearest'), # Back to 32x32 | |
| SimpleConv3d(16, 3, 3, 1, 1), nn.Tanh() | |
| ) | |
| def encode(self, x): | |
| return self.encoder(x) | |
| def decode(self, z): | |
| return self.decoder(z) | |
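| # A minimal, uncalled round-trip check for the tiny autoencoder (sizes match the training | |
| # clip used later): a (B, 3, T, 32, 32) video encodes to (B, z_dim, T, 16, 16) latents and | |
| # decodes back to the input resolution. | |
| def _sketch_tinyvae_roundtrip(): | |
|     vae = TinyVAE(z_dim=4) | |
|     x = torch.randn(1, 3, 8, 32, 32) | |
|     z = vae.encode(x) | |
|     assert z.shape == (1, 4, 8, 16, 16) | |
|     assert vae.decode(z).shape == x.shape | |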
| # ============================================================================ | |
| # Lightweight Text Encoder (Memory Efficient) | |
| # ============================================================================ | |
| class TinyTextEncoder(nn.Module): | |
| """Tiny text encoder for single prompt overfitting""" | |
| def __init__(self, vocab_size=20, dim=64, num_layers=2): # Much smaller | |
| super().__init__() | |
| self.dim = dim | |
| self.embedding = nn.Embedding(vocab_size, dim) | |
| # Just 2 simple layers for single prompt | |
| self.blocks = nn.ModuleList([ | |
| nn.TransformerEncoderLayer( | |
| d_model=dim, | |
| nhead=4, # Fewer heads | |
| dim_feedforward=dim*2, | |
| batch_first=True, | |
| norm_first=True | |
| ) for _ in range(num_layers) | |
| ]) | |
| self.norm = nn.LayerNorm(dim) | |
| def forward(self, x): | |
| x = self.embedding(x) | |
| for block in self.blocks: | |
| x = block(x) | |
| return self.norm(x) | |
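| # A minimal, uncalled shape check for the text encoder (example sizes are arbitrary): | |
| # integer token ids of shape (B, L) become (B, L, dim) contextual embeddings. | |
| def _sketch_text_encoder_shapes(): | |
|     enc = TinyTextEncoder(vocab_size=20, dim=64) | |
|     tokens = torch.randint(0, 20, (1, 16)) | |
|     assert enc(tokens).shape == (1, 16, 64) | |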
| # ============================================================================ | |
| # Lightweight Diffusion Model (Memory Efficient) | |
| # ============================================================================ | |
| class TinyDiffusionModel(nn.Module): | |
| """Tiny diffusion model with better spatial awareness""" | |
| def __init__(self, in_dim=4, dim=96, text_dim=64, num_layers=4): # More layers for better learning | |
| super().__init__() | |
| self.dim = dim | |
| # Input projection with spatial awareness | |
| self.input_proj = nn.Linear(in_dim, dim) | |
| # Positional encoding for spatial awareness | |
| self.pos_embedding = nn.Parameter(torch.randn(1, 8*16*16, dim) * 0.02) # 8 frames, 16x16 spatial | |
| # Time embedding - more capacity | |
| self.time_embedding = nn.Sequential( | |
| nn.Linear(64, dim), nn.SiLU(), | |
| nn.Linear(dim, dim), nn.SiLU(), | |
| nn.Linear(dim, dim) | |
| ) | |
| # Text embedding | |
| self.text_proj = nn.Linear(text_dim, dim) | |
| # More transformer blocks for better spatial-temporal learning | |
| self.blocks = nn.ModuleList([ | |
| nn.TransformerDecoderLayer( | |
| d_model=dim, | |
| nhead=8, # 96 is divisible by 8 | |
| dim_feedforward=dim*2, # Smaller feedforward to prevent overfitting to "all red" | |
| batch_first=True, | |
| norm_first=True, | |
| dropout=0.1 # Add dropout to prevent overfitting | |
| ) for _ in range(num_layers) | |
| ]) | |
| # Output projection with skip connection | |
| self.output_proj = nn.Sequential( | |
| nn.Linear(dim, dim//2), nn.SiLU(), | |
| nn.Dropout(0.1), | |
| nn.Linear(dim//2, in_dim) | |
| ) | |
| def forward(self, x, t, context): | |
| # Flatten spatial dimensions | |
| b, c, f, h, w = x.shape | |
| x_flat = x.permute(0, 2, 3, 4, 1).reshape(b, f*h*w, c) # [B, L, C] | |
| # Project input | |
| x_emb = self.input_proj(x_flat) | |
| # Add positional encoding for spatial awareness | |
| x_emb = x_emb + self.pos_embedding[:, :x_emb.size(1), :] | |
| # Time embedding | |
| t_emb = self.time_embedding(sinusoidal_embedding_1d(64, t.float()).float()) | |
| t_emb = t_emb.unsqueeze(1).expand(-1, x_emb.size(1), -1) | |
| x_emb = x_emb + t_emb | |
| # Text context | |
| context = self.text_proj(context) | |
| # Transformer blocks | |
| for block in self.blocks: | |
| x_emb = block(x_emb, context) | |
| # Output projection | |
| noise_pred = self.output_proj(x_emb) | |
| # Reshape back | |
| noise_pred = noise_pred.reshape(b, f, h, w, c).permute(0, 4, 1, 2, 3) | |
| return noise_pred | |
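| # A minimal, uncalled end-to-end shape check for the denoiser (sizes match the latent grid | |
| # implied by the positional embedding above: 8 frames of 16x16 tokens): it takes noisy | |
| # latents, a timestep, and text features, and predicts noise shaped like the latents. | |
| def _sketch_diffusion_model_shapes(): | |
|     net = TinyDiffusionModel(in_dim=4, dim=96, text_dim=64) | |
|     x = torch.randn(1, 4, 8, 16, 16) | |
|     t = torch.tensor([10.0]) | |
|     ctx = torch.randn(1, 16, 64)  # (B, text_len, text_dim) | |
|     assert net(x, t, ctx).shape == x.shape | |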
| # ============================================================================ | |
| # Simple Scheduler and Data Generation | |
| # ============================================================================ | |
| class SimpleScheduler: | |
| def __init__(self, num_timesteps=50): | |
| self.num_timesteps = num_timesteps | |
| def add_noise(self, x, noise, t): | |
| """Blend clean data with noise using a cosine schedule""" | |
| # Cosine schedule: alpha goes from 1 (clean) at t=0 to 0 (pure noise) at t=num_timesteps | |
| alpha = math.cos((float(t) / self.num_timesteps) * math.pi / 2) ** 2 | |
| return alpha * x + (1 - alpha) * noise | |
| def step(self, model_output, t, sample): | |
| """Simple denoising step using the same cosine alpha""" | |
| alpha = math.cos((float(t) / self.num_timesteps) * math.pi / 2) ** 2 | |
| return sample - (1 - alpha) * model_output | |
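| # Worked values of the cosine blend above with num_timesteps = 50 (for orientation only): | |
| #   t = 0  -> alpha = cos(0)^2    = 1.0  (add_noise returns the clean latent) | |
| #   t = 25 -> alpha = cos(pi/4)^2 = 0.5  (equal mix of latent and noise) | |
| #   t = 50 -> alpha = cos(pi/2)^2 = 0.0  (pure noise) | |
| # Note this is a plain convex blend, not the variance-preserving sqrt-alpha DDPM form. | |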
| def create_single_training_video(frames=8, height=32, width=32): | |
| """Create ONE video with a clear red circle moving on black background""" | |
| video = torch.zeros(3, frames, height, width) | |
| # Create a clear red circle moving across the frame | |
| for f in range(frames): | |
| y, x = torch.meshgrid(torch.arange(height), torch.arange(width), indexing='ij') | |
| # Moving red circle - smaller and more defined | |
| center_x = 8 + (width - 16) * f / (frames - 1) # Move from left to right | |
| center_y = height // 2 | |
| # Clear red circle - much smaller and well-defined | |
| dist = torch.sqrt((x - center_x)**2 + (y - center_y)**2) | |
| red_circle = (dist < 6).float() # Smaller, clearer circle | |
| # Set channels - clear contrast | |
| video[0, f] = red_circle # Red channel - only the circle | |
| video[1, f] = torch.zeros_like(red_circle) # No green | |
| video[2, f] = torch.zeros_like(red_circle) # No blue | |
| # Background stays black (zeros) | |
| # Normalize to [-1, 1] for training | |
| video = video * 2 - 1 | |
| return video, "red circle moving" | |
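| # A minimal, uncalled sanity check of the synthetic clip above (values follow from the | |
| # function as written): shape (3, frames, H, W), values in {-1, 1} after rescaling. | |
| def _sketch_training_video(): | |
|     video, prompt = create_single_training_video() | |
|     assert video.shape == (3, 8, 32, 32) | |
|     assert prompt == "red circle moving" | |
|     assert float(video.min()) == -1.0 and float(video.max()) == 1.0 | |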
| def save_video_as_gif_simple(video_tensor, filename, fps=2): | |
| """Simple and reliable video saving using PIL""" | |
| try: | |
| # Convert tensor to numpy | |
| video = video_tensor.detach().cpu().numpy().astype(np.float32) | |
| # Normalize to [0, 1] | |
| if video.min() < 0: | |
| video = (video + 1) / 2 | |
| video = np.clip(video, 0, 1) | |
| # Convert to uint8 and rearrange dimensions | |
| video = (video * 255).astype(np.uint8) | |
| video = np.transpose(video, (1, 2, 3, 0)) # (T, H, W, C) | |
| # Create PIL images | |
| frames = [] | |
| for i in range(video.shape[0]): | |
| frame = video[i] | |
| if frame.shape[2] == 3: # RGB | |
| pil_frame = Image.fromarray(frame, 'RGB') | |
| else: # Grayscale | |
| pil_frame = Image.fromarray(frame[:,:,0], 'L') | |
| frames.append(pil_frame) | |
| # Save as GIF | |
| if len(frames) > 0: | |
| frames[0].save( | |
| filename, | |
| save_all=True, | |
| append_images=frames[1:], | |
| duration=int(1000/fps), | |
| loop=0 | |
| ) | |
| print(f"β Video saved as: {filename}") | |
| else: | |
| print(f"β No frames to save for {filename}") | |
| except Exception as e: | |
| print(f"β Error saving {filename}: {str(e)}") | |
| # Save first frame as PNG | |
| try: | |
| video = video_tensor.detach().cpu().numpy().astype(np.float32) | |
| if video.min() < 0: | |
| video = (video + 1) / 2 | |
| video = np.clip(video, 0, 1) | |
| first_frame = np.transpose(video[:, 0], (1, 2, 0)) | |
| first_frame = (first_frame * 255).astype(np.uint8) | |
| pil_image = Image.fromarray(first_frame, 'RGB') | |
| static_filename = filename.replace('.gif', '_frame0.png') | |
| pil_image.save(static_filename) | |
| print(f"β Saved first frame as: {static_filename}") | |
| except Exception as e2: | |
| print(f"β Could not save any version: {str(e2)}") | |
| return filename | |
| def save_video_as_gif(video_tensor, filename, fps=2): | |
| """Save video tensor as animated GIF""" | |
| try: | |
| # Convert from tensor to numpy and normalize to [0, 1] | |
| video = video_tensor.detach().cpu().numpy().astype(np.float32) | |
| # Handle different input ranges | |
| if video.min() < 0: # Assume [-1, 1] range | |
| video = np.clip((video + 1) / 2, 0, 1) | |
| else: # Assume [0, 1] range | |
| video = np.clip(video, 0, 1) | |
| # Rearrange from (C, T, H, W) to (T, H, W, C) | |
| video = np.transpose(video, (1, 2, 3, 0)) | |
| # Convert to uint8 for proper image handling | |
| video = (video * 255).astype(np.uint8) | |
| # Create and save frames as images | |
| frames = [] | |
| for i in range(video.shape[0]): | |
| fig, ax = plt.subplots(figsize=(4, 4)) | |
| ax.imshow(video[i]) | |
| ax.axis('off') | |
| ax.set_title(f'Frame {i+1}/{len(video)}') | |
| # Convert plot to image | |
| fig.canvas.draw() | |
| buf = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) | |
| buf = buf.reshape(fig.canvas.get_width_height()[::-1] + (3,)) | |
| frames.append(buf) | |
| plt.close(fig) | |
| # Save as GIF using matplotlib animation | |
| if len(frames) > 0: | |
| fig, ax = plt.subplots(figsize=(4, 4)) | |
| ax.axis('off') | |
| im = ax.imshow(frames[0]) | |
| def animate(frame_idx): | |
| im.set_array(frames[frame_idx]) | |
| ax.set_title(f'Frame {frame_idx+1}/{len(frames)}') | |
| return [im] | |
| anim = animation.FuncAnimation(fig, animate, frames=len(frames), | |
| interval=1000//fps, repeat=True, blit=False) | |
| # Save with pillow writer | |
| anim.save(filename, writer='pillow', fps=fps) | |
| plt.close(fig) | |
| print(f"β Video saved as: {filename}") | |
| else: | |
| print(f"β No frames to save for {filename}") | |
| except Exception as e: | |
| print(f"β Error saving {filename}: {str(e)}") | |
| # Fallback: save first frame as static image | |
| try: | |
| video = video_tensor.detach().cpu().numpy().astype(np.float32) | |
| if video.min() < 0: | |
| video = np.clip((video + 1) / 2, 0, 1) | |
| else: | |
| video = np.clip(video, 0, 1) | |
| first_frame = np.transpose(video[:, 0], (1, 2, 0)) | |
| plt.figure(figsize=(4, 4)) | |
| plt.imshow(first_frame) | |
| plt.axis('off') | |
| plt.title('First Frame (Static)') | |
| static_filename = filename.replace('.gif', '_static.png') | |
| plt.savefig(static_filename, bbox_inches='tight', dpi=150) | |
| plt.close() | |
| print(f"β Saved static image as: {static_filename}") | |
| except: | |
| print(f"β Could not save any version of {filename}") | |
| return filename | |
| def display_video_grid(videos, titles, save_dir="generated_videos"): | |
| """Display multiple videos in a grid and save them""" | |
| os.makedirs(save_dir, exist_ok=True) | |
| n_videos = len(videos) | |
| saved_files = [] | |
| # Save individual videos first (try simple method first) | |
| for i, (video, title) in enumerate(zip(videos, titles)): | |
| filename = os.path.join(save_dir, f"{title.replace(' ', '_')}.gif") | |
| try: | |
| save_video_as_gif_simple(video, filename) | |
| except Exception: | |
| print(f"Simple method failed, trying alternative for {title}") | |
| save_video_as_gif(video, filename) | |
| saved_files.append(filename) | |
| # Create grid of first frames | |
| try: | |
| fig, axes = plt.subplots(1, n_videos, figsize=(4*n_videos, 4)) | |
| if n_videos == 1: | |
| axes = [axes] | |
| for i, (video, title) in enumerate(zip(videos, titles)): | |
| try: | |
| # Show first frame in grid | |
| video_np = video.detach().cpu().numpy().astype(np.float32) | |
| # Handle different input ranges | |
| if video_np.min() < 0: | |
| video_np = np.clip((video_np + 1) / 2, 0, 1) | |
| else: | |
| video_np = np.clip(video_np, 0, 1) | |
| first_frame = np.transpose(video_np[:, 0], (1, 2, 0)) | |
| axes[i].imshow(first_frame) | |
| axes[i].set_title(title) | |
| axes[i].axis('off') | |
| except Exception as e: | |
| print(f"β Error displaying video {i}: {str(e)}") | |
| axes[i].text(0.5, 0.5, f'Error\n{title}', ha='center', va='center', transform=axes[i].transAxes) | |
| axes[i].axis('off') | |
| plt.tight_layout() | |
| grid_filename = os.path.join(save_dir, "video_grid.png") | |
| plt.savefig(grid_filename, dpi=150, bbox_inches='tight') | |
| plt.show() | |
| except Exception as e: | |
| print(f"β Error creating video grid: {str(e)}") | |
| return saved_files | |
| # ============================================================================ | |
| # Complete Training and Inference Pipeline | |
| # ============================================================================ | |
| class TinyWanT2V: | |
| def __init__(self, device='cuda'): | |
| self.device = device | |
| # Initialize tiny components perfectly sized for single circle | |
| self.vae = TinyVAE(z_dim=4).to(device).float() | |
| self.text_encoder = TinyTextEncoder(vocab_size=20, dim=64).to(device).float() | |
| self.diffusion_model = TinyDiffusionModel(in_dim=4, dim=96, text_dim=64).to(device).float() | |
| self.scheduler = SimpleScheduler(num_timesteps=50) # Fewer timesteps | |
| # Minimal vocabulary for single prompt | |
| self.vocab = { | |
| 'red': 1, 'circle': 2, 'moving': 3, 'area': 4, 'blob': 5, 'color': 6, 'bright': 7 | |
| } | |
| def tokenize(self, text, max_length=16): | |
| """Simple tokenization""" | |
| words = text.lower().split() | |
| tokens = [self.vocab.get(word, 0) for word in words] | |
| tokens = tokens[:max_length] + [0] * (max_length - len(tokens)) | |
| return torch.tensor(tokens, device=self.device).unsqueeze(0) | |
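| # For orientation: with the vocabulary above, the training prompt tokenizes as | |
| #   "red circle moving" -> [1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] | |
| # (unknown words map to 0; sequences are zero-padded/truncated to max_length). | |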
| def train(self, num_epochs=300): | |
| """Train on SINGLE video - INTENSIVE overfitting for 20x longer""" | |
| print("π Starting INTENSIVE overfitting on single red circle video...") | |
| print(f"π₯ Training for {num_epochs} epochs (20x normal) for perfect memorization!") | |
| # Create THE training video | |
| train_video, train_prompt = create_single_training_video() | |
| train_video = train_video.unsqueeze(0).to(self.device).float() # Add batch dim | |
| print(f"οΏ½ TTraining data shape: {train_video.shape}") | |
| print(f"οΏ½ Training prompt: '{train_prompt}'") | |
| print(f"π Video range: [{train_video.min():.3f}, {train_video.max():.3f}]") | |
| # Save original training video | |
| self.original_video = train_video.clone() | |
| # Set models to training mode | |
| self.vae.train() | |
| self.text_encoder.train() | |
| self.diffusion_model.train() | |
| # Optimizers with good learning rates for intensive overfitting | |
| vae_optimizer = optim.Adam(self.vae.parameters(), lr=2e-3) # Higher LR for intensive training | |
| text_optimizer = optim.Adam(self.text_encoder.parameters(), lr=2e-3) | |
| diffusion_optimizer = optim.Adam(self.diffusion_model.parameters(), lr=2e-3) | |
| # Learning rate schedulers for long training | |
| vae_scheduler = optim.lr_scheduler.CosineAnnealingLR(vae_optimizer, T_max=num_epochs) | |
| text_scheduler = optim.lr_scheduler.CosineAnnealingLR(text_optimizer, T_max=num_epochs) | |
| diff_scheduler = optim.lr_scheduler.CosineAnnealingLR(diffusion_optimizer, T_max=num_epochs) | |
| # Tokenize the single prompt | |
| text_tokens = self.tokenize(train_prompt) | |
| print("\nπ― Starting INTENSIVE overfitting loop...") | |
| print("π Progress will be shown every 100 epochs...") | |
| print("π¬ Diffusion generations will be shown at 20%, 40%, 60%, 80%, and 100% completion...") | |
| best_total_loss = float('inf') | |
| best_vae_loss = float('inf') | |
| best_diff_loss = float('inf') | |
| # Create directory for training progress videos | |
| os.makedirs("training_progress", exist_ok=True) | |
| # Define checkpoints for generation (every 1/5th of training) | |
| generation_checkpoints = [num_epochs // 5 * i for i in range(1, 6)] # 20%, 40%, 60%, 80%, 100% | |
| for epoch in range(num_epochs): | |
| # Clear gradients | |
| vae_optimizer.zero_grad() | |
| text_optimizer.zero_grad() | |
| diffusion_optimizer.zero_grad() | |
| # VAE reconstruction loss | |
| latents = self.vae.encode(train_video) | |
| reconstructed = self.vae.decode(latents) | |
| vae_loss = F.mse_loss(reconstructed, train_video) | |
| # Text encoding | |
| text_embeddings = self.text_encoder(text_tokens) | |
| # Diffusion training - train on multiple timesteps per epoch for better learning (5x longer) | |
| total_diff_loss = 0 | |
| for _ in range(25): # 5x more diffusion steps per epoch for much richer signal | |
| t = torch.randint(0, self.scheduler.num_timesteps, (1,), device=self.device, dtype=torch.float32) | |
| noise = torch.randn_like(latents) | |
| noisy_latents = self.scheduler.add_noise(latents[0], noise[0], t[0]).unsqueeze(0) | |
| predicted_noise = self.diffusion_model(noisy_latents, t, text_embeddings) | |
| diff_loss = F.mse_loss(predicted_noise, noise) | |
| total_diff_loss += diff_loss | |
| diffusion_loss = total_diff_loss / 25 # Average over multiple steps | |
| # Combined loss with equal weighting for intensive training | |
| total_loss = diffusion_loss + vae_loss | |
| # Backward pass | |
| total_loss.backward() | |
| # Gradient clipping for stability during long training | |
| torch.nn.utils.clip_grad_norm_(self.vae.parameters(), 0.5) | |
| torch.nn.utils.clip_grad_norm_(self.text_encoder.parameters(), 0.5) | |
| torch.nn.utils.clip_grad_norm_(self.diffusion_model.parameters(), 0.5) | |
| vae_optimizer.step() | |
| text_optimizer.step() | |
| diffusion_optimizer.step() | |
| # Update learning rates | |
| vae_scheduler.step() | |
| text_scheduler.step() | |
| diff_scheduler.step() | |
| # Track best losses | |
| if total_loss.item() < best_total_loss: | |
| best_total_loss = total_loss.item() | |
| if vae_loss.item() < best_vae_loss: | |
| best_vae_loss = vae_loss.item() | |
| if diffusion_loss.item() < best_diff_loss: | |
| best_diff_loss = diffusion_loss.item() | |
| # Progress reporting every 100 epochs | |
| if epoch % 100 == 0: | |
| current_lr = vae_optimizer.param_groups[0]['lr'] | |
| print(f"Epoch {epoch:4d}/{num_epochs}, Total: {total_loss.item():.6f}, VAE: {vae_loss.item():.6f}, Diff: {diffusion_loss.item():.6f}, LR: {current_lr:.6f}") | |
| # Show reconstruction quality | |
| with torch.no_grad(): | |
| recon_error = F.mse_loss(reconstructed, train_video).item() | |
| if recon_error < 0.001: | |
| print(f" π₯ VAE reconstruction PERFECT: {recon_error:.8f}") | |
| elif recon_error < 0.01: | |
| print(f" β VAE reconstruction excellent: {recon_error:.6f}") | |
| elif recon_error < 0.1: | |
| print(f" π‘ VAE reconstruction good: {recon_error:.6f}") | |
| else: | |
| print(f" π΄ VAE reconstruction poor: {recon_error:.6f}") | |
| # Memory cleanup during long training | |
| if epoch % 500 == 0 and epoch > 0: | |
| torch.cuda.empty_cache() | |
| gc.collect() | |
| print(f" π§Ή Memory cleanup at epoch {epoch}") | |
| # Generate video at checkpoints to show training progress | |
| if (epoch + 1) in generation_checkpoints: | |
| checkpoint_idx = generation_checkpoints.index(epoch + 1) + 1 | |
| progress_percent = checkpoint_idx * 20 | |
| print(f"\n㪠CHECKPOINT {checkpoint_idx}/5 ({progress_percent}% complete) - Generating video...") | |
| # Temporarily set to eval mode for generation | |
| self.vae.eval() | |
| self.text_encoder.eval() | |
| self.diffusion_model.eval() | |
| with torch.no_grad(): | |
| # Generate video with current model state | |
| generated_video = self.generate(train_prompt, num_steps=20, show_progress=False) | |
| # Save the generated video | |
| filename = f"training_progress/checkpoint_{checkpoint_idx}_{progress_percent}percent.gif" | |
| save_video_as_gif_simple(generated_video, filename) | |
| # Calculate similarity to training data | |
| similarity = 1 - F.mse_loss(generated_video, train_video[0].cpu()).item() | |
| print(f" π Similarity to training data: {similarity:.4f}") | |
| print(f" πΎ Saved: {filename}") | |
| # Set back to training mode | |
| self.vae.train() | |
| self.text_encoder.train() | |
| self.diffusion_model.train() | |
| print(f"π Resuming training...\n") | |
| print("β INTENSIVE overfitting completed!") | |
| print(f"π Best losses achieved:") | |
| print(f" Total: {best_total_loss:.8f}") | |
| print(f" VAE: {best_vae_loss:.8f}") | |
| print(f" Diffusion: {best_diff_loss:.8f}") | |
| # Set to eval mode | |
| self.vae.eval() | |
| self.text_encoder.eval() | |
| self.diffusion_model.eval() | |
| # Store final reconstruction for comparison | |
| with torch.no_grad(): | |
| self.final_latents = self.vae.encode(train_video) | |
| self.vae_reconstruction = self.vae.decode(self.final_latents) | |
| def show_training_results(self): | |
| """Show original data, VAE reconstruction, and diffusion generation""" | |
| print("\nπ TRAINING RESULTS COMPARISON") | |
| print("="*50) | |
| with torch.no_grad(): | |
| # 1. Original training data | |
| original = self.original_video[0].cpu() | |
| print(f"1οΈβ£ Original video range: [{original.min():.3f}, {original.max():.3f}]") | |
| # 2. VAE reconstruction | |
| vae_recon = self.vae_reconstruction[0].cpu() | |
| vae_error = F.mse_loss(vae_recon, original).item() | |
| print(f"2οΈβ£ VAE reconstruction range: [{vae_recon.min():.3f}, {vae_recon.max():.3f}]") | |
| print(f" VAE reconstruction error: {vae_error:.6f}") | |
| # 3. Diffusion generation | |
| generated = self.generate("red circle moving", num_steps=20, show_progress=False) | |
| print(f"3οΈβ£ Generated video range: [{generated.min():.3f}, {generated.max():.3f}]") | |
| # Save all three for comparison | |
| videos_to_save = [original, vae_recon, generated] | |
| titles = ["1_Original_Training_Data", "2_VAE_Reconstruction", "3_Diffusion_Generated"] | |
| print("\nπΎ Saving comparison videos...") | |
| saved_files = display_video_grid(videos_to_save, titles, save_dir="training_results") | |
| return saved_files | |
| def generate(self, prompt, num_steps=20, frames=8, height=32, width=32, show_progress=True): | |
| """Generate video from text prompt""" | |
| if show_progress: | |
| print(f"π¬ Generating video for prompt: '{prompt}'") | |
| with torch.no_grad(): | |
| # Tokenize prompt | |
| text_tokens = self.tokenize(prompt) | |
| text_embeddings = self.text_encoder(text_tokens) | |
| # Start with random noise in latent space (smaller now) | |
| latents = torch.randn(1, 4, frames, height//2, width//2, device=self.device, dtype=torch.float32) | |
| # Simple denoising loop | |
| for step in range(num_steps): | |
| t = torch.tensor([num_steps - step - 1], device=self.device, dtype=torch.float32) | |
| # Predict noise | |
| predicted_noise = self.diffusion_model(latents, t, text_embeddings) | |
| # Simple denoising step | |
| latents = self.scheduler.step(predicted_noise, t.item(), latents) | |
| if show_progress and step % 5 == 0: | |
| print(f" Denoising step {step}/{num_steps}") | |
| # Decode to video | |
| video = self.vae.decode(latents) | |
| if show_progress: | |
| print("β Video generation completed!") | |
| return video.squeeze(0).cpu() | |
| # ============================================================================ | |
| # Main Execution (Google Colab Ready) | |
| # ============================================================================ | |
| def main(): | |
| """Main function - overfit on single circle and show all steps""" | |
| print("π― Initializing Tiny Wan Text-to-Video Model for Single Circle Overfitting") | |
| print("="*70) | |
| # Check device | |
| device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
| print(f"Using device: {device}") | |
| if device == 'cuda': | |
| print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB") | |
| # Show model sizes | |
| print("\nπ Model Architecture:") | |
| model = TinyWanT2V(device=device) | |
| vae_params = sum(p.numel() for p in model.vae.parameters()) | |
| text_params = sum(p.numel() for p in model.text_encoder.parameters()) | |
| diff_params = sum(p.numel() for p in model.diffusion_model.parameters()) | |
| total_params = vae_params + text_params + diff_params | |
| print(f" VAE parameters: {vae_params:,}") | |
| print(f" Text Encoder parameters: {text_params:,}") | |
| print(f" Diffusion Model parameters: {diff_params:,}") | |
| print(f" Total parameters: {total_params:,}") | |
| # Show training data | |
| print("\nπ STEP 1: TRAINING DATA") | |
| print("-" * 30) | |
| train_video, train_prompt = create_single_training_video() | |
| print(f"Video shape: {train_video.shape}") | |
| print(f"Video range: [{train_video.min():.3f}, {train_video.max():.3f}]") | |
| print(f"Prompt: '{train_prompt}'") | |
| # Save training data | |
| os.makedirs("training_results", exist_ok=True) | |
| save_video_as_gif_simple(train_video, "training_results/0_training_data.gif") | |
| print("β Training data saved as: training_results/0_training_data.gif") | |
| # Train the model (overfit on single video) - MUCH longer training | |
| print("\nπ― STEP 2: INTENSIVE OVERFITTING TRAINING") | |
| print("-" * 30) | |
| model.train(num_epochs=300) # 20x longer training! | |
| # Show training results | |
| print("\nπ STEP 3: RESULTS COMPARISON") | |
| print("-" * 30) | |
| saved_files = model.show_training_results() | |
| # Test generation with same and different prompts | |
| print("\nοΏ½ STEP 4: GENERATION TESTS") | |
| print("-" * 30) | |
| test_prompts = [ | |
| "red circle moving", # Same as the training prompt | |
| "moving red circle", # Slightly different word order | |
| "red circle", # Shorter | |
| ] | |
| for i, prompt in enumerate(test_prompts): | |
| print(f"\nTest {i+1}: '{prompt}'") | |
| video = model.generate(prompt, num_steps=20, show_progress=False) | |
| # Calculate similarity to training data | |
| similarity = 1 - F.mse_loss(video, model.original_video[0].cpu()).item() | |
| print(f" Similarity to training: {similarity:.4f}") | |
| print(f" Video std: {video.std().item():.4f}") | |
| # Save | |
| os.makedirs("training_results", exist_ok=True) | |
| filename = f"training_results/test_{i+1}_{prompt.replace(' ', '_')}.gif" | |
| save_video_as_gif_simple(video, filename) | |
| print(f" Saved: {filename}") | |
| print("\nπ COMPLETE PIPELINE DEMONSTRATION FINISHED!") | |
| print("="*70) | |
| print("π Check 'training_results/' folder for all videos:") | |
| print(" - 0_training_data.gif (original)") | |
| print(" - 1_Original_Training_Data.gif (same as above)") | |
| print(" - 2_VAE_Reconstruction.gif (autoencoder output)") | |
| print(" - 3_Diffusion_Generated.gif (full pipeline)") | |
| print(" - test_*.gif (generation tests)") | |
| print("\nοΏ½ Check 'twraining_progress/' folder for diffusion training evolution:") | |
| print(" - checkpoint_1_20percent.gif (20% trained)") | |
| print(" - checkpoint_2_40percent.gif (40% trained)") | |
| print(" - checkpoint_3_60percent.gif (60% trained)") | |
| print(" - checkpoint_4_80percent.gif (80% trained)") | |
| print(" - checkpoint_5_100percent.gif (100% trained)") | |
| print("\nπ‘ This shows: Data β VAE β Diffusion β Generation pipeline") | |
| print("π― Perfect overfitting means VAE and Diffusion should reproduce the training circle!") | |
| print("π Training progress videos show how diffusion model learns over time!") | |
| if __name__ == "__main__": | |
| main() |