Created
November 22, 2025 15:29
-
-
Save znxkznxk1030/00eb9bd4875469ec04df4204635622b3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader | |
| import torchvision | |
| import torchvision.transforms as transforms | |
| from torchvision.utils import make_grid, save_image | |
| import matplotlib.pyplot as plt | |
| import torch.nn.functional as F | |
| from torch.nn import BCEWithLogitsLoss # Added for BCEWithLogits | |
# =========================
# 0. Basic configuration
# =========================
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

# Training hyper-parameters
batch_size = 128
nz = 100           # dimensionality of the noise vector z
num_classes = 10   # number of CIFAR-10 classes
ngf = 64           # base feature width of the generator
ndf = 64           # base feature width of the discriminator
num_epochs = 200   # adjust as needed

data_root = "/content/drive/MyDrive/datasets/cifar10"

# Directory where per-epoch sample grids are written (created if missing).
sample_dir = "generated_samples"
os.makedirs(sample_dir, exist_ok=True)
# =========================
# 1. CIFAR-10 dataset
# =========================
# ToTensor yields pixels in [0, 1]; Normalize with mean = std = 0.5 per channel
# then maps them to [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Build the train split first, then the test split, sharing root/transform.
# NOTE(review): test_dataset is not used anywhere in this script.
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR10(
        root=data_root,
        train=is_train_split,
        download=True,
        transform=transform,
    )
    for is_train_split in (True, False)
)

# drop_last=True keeps every training batch at exactly batch_size samples.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=2,
)
print("Train size:", len(train_dataset))
# =========================
# 2. Generator (cGAN)
#   - input: noise z + label embedding
#   - output: 3 x 32 x 32 image
# =========================
class Generator(nn.Module):
    """Conditional DCGAN-style generator for 32x32 images.

    A learned label embedding of size ``num_classes`` is concatenated to the
    noise tensor along the channel axis. Four transposed convolutions then
    upsample the resulting (N, nz + num_classes, 1, 1) tensor through
    4x4 -> 8x8 -> 16x16 -> 32x32. A final Tanh keeps outputs in [-1, 1].

    Args:
        nz: number of noise channels in ``z``.
        num_classes: number of conditioning classes (also the embedding width).
        ngf: base feature width; the hidden layers use 4*ngf, 2*ngf and ngf channels.
        nc: number of output image channels.
    """

    def __init__(self, nz, num_classes, ngf, nc=3):
        super().__init__()
        self.nz = nz
        self.num_classes = num_classes
        # One learned num_classes-dimensional vector per class label.
        self.label_emb = nn.Embedding(num_classes, num_classes)

        def up_block(c_in, c_out, stride, padding):
            # ConvTranspose(k=4) -> BatchNorm -> ReLU
            return [
                nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]

        self.main = nn.Sequential(
            *up_block(nz + num_classes, ngf * 4, 1, 0),        # 1x1   -> 4x4
            *up_block(ngf * 4, ngf * 2, 2, 1),                 # 4x4   -> 8x8
            *up_block(ngf * 2, ngf, 2, 1),                     # 8x8   -> 16x16
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),  # 16x16 -> 32x32
            nn.Tanh(),                                         # values in [-1, 1]
        )

    def forward(self, z, labels):
        """Generate images from noise ``z`` (N, nz, 1, 1) and integer ``labels`` (N,)."""
        cond = self.label_emb(labels)
        # (N, num_classes) -> (N, num_classes, 1, 1) so it can be stacked on z's channels.
        cond = cond.view(cond.size(0), self.num_classes, 1, 1)
        return self.main(torch.cat([z, cond], dim=1))
########################################
# 3. Conditional Discriminator (SGAN + cGAN style)
########################################
class Discriminator(nn.Module):
    """Label-conditioned discriminator with a (K+1)-way SGAN head.

    The label embedding is broadcast over every pixel and stacked onto the
    image as ``num_classes`` extra input channels. For a 32x32 input, the conv
    stack reduces the map to (N, ndf*4, 4, 4). Global average pooling then
    flattens it to (N, ndf*4), and a linear layer emits ``num_classes + 1``
    logits, whose last index is the "fake" class.

    NOTE(review): because the true label is part of the input, the first
    ``num_classes`` logits could in principle be predicted from the label
    channels alone. That may make a supervised loss on real images trivial,
    so confirm this combination is intended.

    ``forward`` returns ``(logits, features)``.
    """

    def __init__(self, ndf, num_classes):
        super().__init__()
        self.num_classes = num_classes
        # One learned num_classes-dimensional vector per class label.
        self.label_emb = nn.Embedding(num_classes, num_classes)

        def down_block(c_in, c_out):
            # stride-2 conv halves H and W, followed by BatchNorm + LeakyReLU
            return [
                nn.Conv2d(c_in, c_out, 4, 2, 1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        # Input channels: 3 image channels + num_classes label channels.
        self.features = nn.Sequential(
            nn.Conv2d(3 + num_classes, ndf, 3, 1, 1),  # keeps 32x32
            nn.LeakyReLU(0.2, inplace=True),
            *down_block(ndf, ndf),                     # -> 16x16
            *down_block(ndf, ndf * 2),                 # -> 8x8
            *down_block(ndf * 2, ndf * 4),             # -> 4x4
        )
        # SGAN-style classifier: K real classes + 1 fake class.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(ndf * 4, num_classes + 1)

    def forward(self, x, labels):
        """Score images ``x`` (N, 3, H, W) conditioned on integer ``labels`` (N,)."""
        emb = self.label_emb(labels)
        # (N, num_classes) -> (N, num_classes, H, W): same label vector at every pixel.
        label_maps = emb.view(emb.size(0), self.num_classes, 1, 1).expand(
            -1, -1, x.size(2), x.size(3)
        )
        pooled = self.avgpool(self.features(torch.cat([x, label_maps], dim=1)))
        feats = pooled.view(pooled.size(0), -1)  # (N, ndf*4)
        return self.fc(feats), feats
# Build both networks on the selected device, then print their architectures
# (generator first, discriminator second).
G = Generator(nz, num_classes, ngf).to(device)
D = Discriminator(ndf, num_classes).to(device)
print(G, D, sep="\n")
# =========================
# 4. Loss and Optimizer
# =========================
# Both networks use Adam with identical hyper-parameters (D's optimizer is built first).
lr = 2e-4
beta1, beta2 = 0.5, 0.999
optimizerD, optimizerG = (
    optim.Adam(model.parameters(), lr=lr, betas=(beta1, beta2)) for model in (D, G)
)
########################################
# 4. SGAN loss functions
########################################
def supervised_loss(logits, labels, num_classes):
    """Cross-entropy of labeled real data over the real-class logits only.

    Args:
        logits: (N, K+1) discriminator logits; the last column is the "fake" class.
        labels: (N,) integer class targets.
        num_classes: K, the number of real classes; columns from K on are ignored.

    Returns:
        Scalar mean cross-entropy computed on ``logits[:, :num_classes]``.
    """
    return F.cross_entropy(logits[:, :num_classes], labels)
# Unsupervised loss (Salimans et al. SGAN formulation); the last logit is the "fake" class.
#   real x: maximize log(1 - p(y = fake | x))
#   fake x: maximize log(    p(y = fake | x))
def unsupervised_loss(logits_real, logits_fake):
    """Real-vs-fake loss of a (K+1)-way SGAN discriminator.

    Computes ``-mean_real log(1 - p_fake(x)) - mean_fake log p_fake(x)``, where
    ``p_fake`` is the softmax probability of the last logit.

    Both log-probabilities are evaluated directly in log space:

        log p_fake       = l_last - logsumexp(l)
        log (1 - p_fake) = logsumexp(l[:, :-1]) - logsumexp(l)

    This is algebraically identical to taking ``log`` of the softmax. It avoids
    the float round-off of ``log(softmax(.) + eps)``, which clamps the loss and
    flattens the gradient once the discriminator is very confident. No epsilon
    is needed.

    Args:
        logits_real: (N, K+1) discriminator logits for real images.
        logits_fake: (M, K+1) discriminator logits for generated images.

    Returns:
        Scalar tensor ``loss_real + loss_fake``.
    """
    lse_real = torch.logsumexp(logits_real, dim=1)
    lse_fake = torch.logsumexp(logits_fake, dim=1)

    # log(1 - p(fake | real x)) = log of the total mass on the K real classes.
    log_not_fake_on_real = torch.logsumexp(logits_real[:, :-1], dim=1) - lse_real
    # log p(fake | generated x)
    log_fake_on_fake = logits_fake[:, -1] - lse_fake

    loss_real = -log_not_fake_on_real.mean()
    loss_fake = -log_fake_on_fake.mean()
    return loss_real + loss_fake
def generator_loss(logits_fake):
    """SGAN generator loss: push generated samples into the K real classes.

    The generator maximizes ``log(1 - p_fake(x)) = log sum_{k<K} p(y=k|x)`` on
    its own samples, where ``p_fake`` is the softmax probability of the last
    logit. The value is computed in log space as
    ``logsumexp(l[:, :-1]) - logsumexp(l)``. That is algebraically identical to
    ``log(1 - softmax(l)[:, -1])`` but does not clamp the loss or flatten the
    gradient when the discriminator confidently labels the samples as fake.

    Args:
        logits_fake: (N, K+1) discriminator logits on generated images.

    Returns:
        Scalar tensor ``-mean(log(1 - p_fake))``.
    """
    log_not_fake = torch.logsumexp(logits_fake[:, :-1], dim=1) - torch.logsumexp(
        logits_fake, dim=1
    )
    return -log_not_fake.mean()
# Binary real/fake criterion on a single logit (mean reduction); kept as an
# alternative to the softmax-based SGAN losses.
bce_loss_fn = BCEWithLogitsLoss()
# =========================
# 5. Sample-image saving
# =========================
# Generate images for every class to inspect how well conditioning works.
@torch.no_grad()
def sample_cgan_images(epoch, G, nz, num_classes, device, n_per_class=8):
    """Generate ``n_per_class`` images per class, then save and display them as a grid.

    Grid row ``c`` holds the samples for class ``c`` (``nrow=n_per_class``). The PNG
    is written to ``os.path.join(sample_dir, f"cgan_epoch_{epoch}.png")`` using the
    module-level ``sample_dir``.

    The generator is put in eval mode for sampling. Afterwards it is restored to
    whatever mode it was in before, even if sampling raises. The matplotlib figure
    is closed after display, so repeated per-epoch calls do not accumulate open
    figures.

    Args:
        epoch: epoch number, used in the file name and plot title.
        G: conditional generator called as ``G(z, labels)``.
        nz: noise channels; z is drawn as (n_per_class, nz, 1, 1).
        num_classes: number of classes to sample (labels 0..num_classes-1).
        device: device on which noise and labels are created.
        n_per_class: images generated per class (one grid row per class).
    """
    was_training = G.training
    G.eval()
    try:
        per_class = []
        for cls in range(num_classes):
            z = torch.randn(n_per_class, nz, 1, 1, device=device)
            labels = torch.full((n_per_class,), cls, dtype=torch.long, device=device)
            per_class.append(G(z, labels).cpu())  # (n_per_class, 3, 32, 32)

        imgs = torch.cat(per_class, dim=0)  # (num_classes * n_per_class, 3, 32, 32)
        imgs = (imgs + 1) / 2  # [-1, 1] -> [0, 1]
        grid = make_grid(imgs, nrow=n_per_class)

        save_path = os.path.join(sample_dir, f"cgan_epoch_{epoch}.png")
        save_image(grid, save_path)
        print(f"Saved sample images to: {save_path}")

        # Also show the grid inline.
        fig = plt.figure(figsize=(8, 8))
        plt.imshow(grid.permute(1, 2, 0))
        plt.axis("off")
        plt.title(f"CGAN Samples @ epoch {epoch}")
        plt.show()
        # Called once per epoch: release the figure so they don't pile up.
        plt.close(fig)
    finally:
        G.train(was_training)
# =========================
# 6. Training loop
# =========================
# NOTE(review): fixed_z / fixed_labels (8 samples per class) are not used anywhere
# in this script; sample_cgan_images draws fresh noise on every call.
fixed_z = torch.randn(num_classes * 8, nz, 1, 1, device=device)
fixed_labels = torch.tensor(
    [cls for cls in range(num_classes) for _ in range(8)],
    dtype=torch.long,
    device=device,
)

# Warm-up for the weight of the unsupervised (real-vs-fake) term in D's loss.
# It grows by LAMBDA_UNSUP_STEP every LAMBDA_UNSUP_EVERY iterations, counted over
# the whole run. It is deliberately NOT reset per epoch, so D's adversarial signal
# does not drop back to zero at each epoch boundary. It is capped at LAMBDA_UNSUP_MAX.
LAMBDA_UNSUP_STEP = 0.1
LAMBDA_UNSUP_EVERY = 5
LAMBDA_UNSUP_MAX = 1.0  # equal weighting of the supervised and unsupervised terms
global_step = 0
num_batches = len(train_loader)

for epoch in range(1, num_epochs + 1):
    G.train()
    D.train()
    running_sup = 0.0
    running_unsup = 0.0
    running_g = 0.0

    for i, (real_imgs, labels) in enumerate(train_loader):
        real_imgs = real_imgs.to(device)
        labels = labels.to(device)
        bsz = real_imgs.size(0)

        # 0 at step 0, 0.1 for steps 1-5, 0.2 for steps 6-10, ... up to the cap.
        warmup_ticks = (global_step + LAMBDA_UNSUP_EVERY - 1) // LAMBDA_UNSUP_EVERY
        lambda_unsup = min(LAMBDA_UNSUP_MAX, LAMBDA_UNSUP_STEP * warmup_ticks)

        # -----------------------------------
        # 1) Discriminator update (SGAN style)
        # -----------------------------------
        optimizerD.zero_grad()
        # Real images, conditioned on their true labels.
        logits_real, _ = D(real_imgs, labels)
        # Fake images conditioned on random labels; detached so D's loss does not reach G.
        z = torch.randn(bsz, nz, 1, 1, device=device)
        fake_labels = torch.randint(0, num_classes, (bsz,), device=device)
        fake_images = G(z, fake_labels)
        logits_fake, _ = D(fake_images.detach(), fake_labels)

        # Supervised: K-way classification of labeled real images.
        loss_sup = supervised_loss(logits_real, labels, num_classes)
        # Unsupervised: real vs fake via the last ("fake") logit.
        loss_unsup = unsupervised_loss(logits_real, logits_fake)
        loss_D = loss_sup + lambda_unsup * loss_unsup
        loss_D.backward()
        optimizerD.step()

        # -----------------------------------
        # 2) Generator update (SGAN generator_loss)
        # -----------------------------------
        optimizerG.zero_grad()
        z = torch.randn(bsz, nz, 1, 1, device=device)
        fake_labels = torch.randint(0, num_classes, (bsz,), device=device)
        fake_images = G(z, fake_labels)
        # This backward also leaves gradients on D's parameters; they are cleared
        # by optimizerD.zero_grad() at the start of the next iteration.
        logits_fake_for_G, _ = D(fake_images, fake_labels)
        loss_G = generator_loss(logits_fake_for_G)
        loss_G.backward()
        optimizerG.step()

        # Running sums for logging (the unsupervised loss is logged unweighted).
        running_sup += loss_sup.item()
        running_unsup += loss_unsup.item()
        running_g += loss_G.item()
        global_step += 1

        if (i + 1) % 100 == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] "
                f"Step [{i+1}/{num_batches}] "
                f"Sup: {running_sup/(i+1):.4f} "
                f"Unsup: {running_unsup/(i+1):.4f} "
                f"G: {running_g/(i+1):.4f} "
                f"lambda_unsup: {lambda_unsup:.2f}"
            )

    sample_cgan_images(epoch, G, nz, num_classes, device, n_per_class=8)

print("Training finished!")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment