# Source: GitHub gist znxkznxk1030/00eb9bd4875469ec04df4204635622b3
# Author: @znxkznxk1030 (created November 22, 2025 15:29)
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import make_grid, save_image
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss # Added for BCEWithLogits
# =========================
# 0. Basic configuration
# =========================
# Run on the GPU whenever CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)

# Hyper-parameters shared by the models and the training loop.
batch_size = 128
nz = 100  # dimensionality of the noise vector
num_classes = 10  # number of CIFAR-10 classes
ngf = 64  # generator feature width
ndf = 64  # discriminator feature width
num_epochs = 200  # adjust as needed

# Where the dataset is stored / downloaded.
data_root = "/content/drive/MyDrive/datasets/cifar10"

# Directory that receives the generated sample grids; created up front.
sample_dir = "generated_samples"
os.makedirs(sample_dir, exist_ok=True)
# =========================
# 1. CIFAR-10 dataset
# =========================
# Convert images to tensors, then shift/scale every RGB channel with
# mean 0.5 and std 0.5 so pixel values land in the [-1, 1] range.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Build the train split first and the test split second; both share the
# same root directory, download flag and preprocessing.
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR10(
        root=data_root,
        train=is_train,
        download=True,
        transform=transform,
    )
    for is_train in (True, False)
)

# Shuffled mini-batches; the last incomplete batch is dropped so every batch
# has exactly `batch_size` samples.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    drop_last=True,
)
print("Train size:", len(train_dataset))
# =========================
# 2. Generator (cGAN)
# - input: noise z + label embedding
# - output: 3 x 32 x 32 image
# =========================
class Generator(nn.Module):
    """Conditional DCGAN-style generator that emits 32x32 images.

    The class label is embedded into a ``num_classes``-dimensional vector,
    concatenated with the noise vector along the channel axis, and upsampled
    by four transposed convolutions (1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32).
    The final ``Tanh`` keeps every output value in ``[-1, 1]``.

    Args:
        nz: Number of noise channels.
        num_classes: Number of class labels (also the embedding width).
        ngf: Base feature width of the transposed-convolution stack.
        nc: Number of output image channels.
    """

    def __init__(self, nz, num_classes, ngf, nc=3):
        super().__init__()
        self.nz = nz
        self.num_classes = num_classes

        # Embed each label into a num_classes-dimensional vector.
        self.label_emb = nn.Embedding(num_classes, num_classes)

        # Input channels = noise channels + label-embedding channels.
        in_channels = nz + num_classes

        self.main = nn.Sequential(
            # input: (N, in_channels, 1, 1)
            nn.ConvTranspose2d(in_channels, ngf * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),  # 4x4
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),  # 8x8
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),  # 16x16
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),  # 32x32, values in [-1, 1]
        )

    def forward(self, z, labels):
        """Generate one image per (noise, label) pair.

        Args:
            z: Noise of shape ``(N, nz, 1, 1)``; a flat ``(N, nz)`` tensor is
                accepted as well and reshaped internally.
            labels: Integer class indices of shape ``(N,)``.

        Returns:
            Tensor of shape ``(N, nc, 32, 32)``.
        """
        # (N,) -> (N, num_classes) -> (N, num_classes, 1, 1)
        y = self.label_emb(labels)
        y = y.view(y.size(0), self.num_classes, 1, 1)

        # No-op for the canonical (N, nz, 1, 1) layout; lifts flat noise to 4-D.
        z = z.reshape(z.size(0), self.nz, 1, 1)

        # Concatenate along channels: (N, nz + num_classes, 1, 1).
        return self.main(torch.cat([z, y], dim=1))
########################################
# 3. Conditional Discriminator (SGAN + cGAN style)
########################################
class Discriminator(nn.Module):
    """Label-conditioned discriminator with an SGAN-style (K+1)-way head.

    The label is embedded, broadcast over the spatial grid and concatenated
    with the image as extra input channels. A strided convolution stack
    extracts features, which are globally average-pooled and mapped to
    ``num_classes + 1`` logits; the last logit is the "fake" class.

    Args:
        ndf: Base feature width of the convolution stack.
        num_classes: Number of real classes K (also the embedding width).
        nc: Number of image channels.
    """

    def __init__(self, ndf, num_classes, nc=3):
        super().__init__()
        self.num_classes = num_classes

        # Label embedding: class index -> num_classes-dimensional vector.
        self.label_emb = nn.Embedding(num_classes, num_classes)

        # Image channels + label-embedding channels.
        in_channels = nc + num_classes

        # Feature extractor; the size comments assume 32x32 inputs.
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, ndf, 3, 1, 1),  # (N, ndf, 32, 32)
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf, 4, 2, 1),  # 16x16
            nn.BatchNorm2d(ndf),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1),  # 8x8
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1),  # 4x4
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # SGAN-style classifier: K real classes + 1 fake class.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(ndf * 4, num_classes + 1)

    def forward(self, x, labels):
        """Score a batch of images under their conditioning labels.

        Args:
            x: Images of shape ``(N, nc, H, W)``.
            labels: Integer class indices of shape ``(N,)``.

        Returns:
            Tuple ``(logits, h)`` where ``logits`` has shape
            ``(N, num_classes + 1)`` and ``h`` is the pooled feature vector of
            shape ``(N, ndf * 4)``.
        """
        # 1. Embed the label and broadcast it over the image's spatial grid.
        y = self.label_emb(labels)  # (N, num_classes)
        y = y.view(y.size(0), self.num_classes, 1, 1)
        y = y.expand(-1, -1, x.size(2), x.size(3))  # (N, num_classes, H, W)

        # 2. Stack image and label planes along the channel axis.
        d_in = torch.cat([x, y], dim=1)  # (N, nc + num_classes, H, W)

        # 3. Extract features and pool them down to one vector per sample.
        h = self.avgpool(self.features(d_in))  # (N, ndf*4, 1, 1)
        h = h.view(h.size(0), -1)  # (N, ndf*4)

        # 4. (K+1)-way classification head.
        logits = self.fc(h)  # (N, num_classes + 1)
        return logits, h
# Build both networks, move them to the selected device and print their
# layer layout.
G = Generator(nz=nz, num_classes=num_classes, ngf=ngf).to(device)
D = Discriminator(ndf=ndf, num_classes=num_classes).to(device)
print(G)
print(D)

# =========================
# 4. Loss and optimizers
# =========================
# Both networks use Adam with the same learning rate and betas.
lr = 2e-4
beta1 = 0.5
beta2 = 0.999
optimizerD, optimizerG = (
    optim.Adam(net.parameters(), lr=lr, betas=(beta1, beta2))
    for net in (D, G)
)
########################################
# 4. SGAN loss functions
########################################
def supervised_loss(logits, labels, num_classes):
    """Cross-entropy of labelled samples over the real classes only.

    Args:
        logits: Discriminator outputs of shape ``(N, K + 1)``; the trailing
            column is the "fake" class.
        labels: Ground-truth class indices of shape ``(N,)`` in
            ``[0, num_classes)``.
        num_classes: Number of leading columns (real classes) to keep.

    Returns:
        Scalar tensor with the mean cross-entropy.
    """
    # Drop the fake-class column so the softmax runs over real classes only.
    real_class_logits = logits[:, :num_classes]
    return F.cross_entropy(real_class_logits, labels)
# Unsupervised loss (Salimans-style SGAN)
# real x: maximize log(1 - p_model(y = fake | x))
# fake x: maximize log(    p_model(y = fake | x))
def unsupervised_loss(logits_real, logits_fake):
    """Real-vs-fake discriminator loss on (K+1)-way logits.

    The last column of each logits tensor is the "fake" class. The loss is
    ``-E_real[log(1 - p(fake | x))] - E_fake[log p(fake | x)]``.

    Everything is evaluated in log-space with ``logsumexp`` rather than as
    ``log(softmax(...) + eps)``: the probability form rounds to exactly 0 or 1
    in float32 once the discriminator is confident, which clamps the loss and
    zeroes its gradient.

    Args:
        logits_real: Logits for real images, shape ``(N, K + 1)``.
        logits_fake: Logits for generated images, shape ``(M, K + 1)``.

    Returns:
        Scalar tensor: real term plus fake term.
    """
    # log(1 - p_fake) = log(sum_{k<K} p_k) = lse(real-class logits) - lse(all)
    log_p_real_is_real = (
        torch.logsumexp(logits_real[:, :-1], dim=1)
        - torch.logsumexp(logits_real, dim=1)
    )
    # log p_fake = fake logit - lse(all)
    log_p_fake_is_fake = logits_fake[:, -1] - torch.logsumexp(logits_fake, dim=1)

    loss_real = -log_p_real_is_real.mean()
    loss_fake = -log_p_fake_is_fake.mean()
    return loss_real + loss_fake
def generator_loss(logits_fake):
    """Generator loss: push fakes towards any of the real classes.

    Minimizes ``-E[log(1 - p(fake | G(z)))]``, where ``1 - p(fake | x)`` is
    the total probability mass on the K real classes and the fake class is the
    last logit column.

    The term is computed in log-space with ``logsumexp``. The
    ``log(1 - softmax + eps)`` form saturates in float32 when the
    discriminator confidently rejects a sample, leaving the generator without
    gradient exactly when it needs it most.

    Args:
        logits_fake: Discriminator logits for generated images, shape
            ``(N, K + 1)``.

    Returns:
        Scalar tensor with the mean loss.
    """
    # log(sum_{k<K} p_k) = lse(real-class logits) - lse(all logits)
    log_p_not_fake = (
        torch.logsumexp(logits_fake[:, :-1], dim=1)
        - torch.logsumexp(logits_fake, dim=1)
    )
    return -log_p_not_fake.mean()
# Binary cross-entropy on raw logits. Only the commented-out BCE alternative
# in the training loop refers to it; no active code visible here calls it.
bce_loss_fn = nn.BCEWithLogitsLoss()
# =========================
# 5. Sample-image helper
# =========================
# Generates a few images per class for visual inspection.
@torch.no_grad()
def sample_cgan_images(epoch, G, nz, num_classes, device, n_per_class=8):
    """Generate, save and display a grid of class-conditional samples.

    One grid row is produced per class (``n_per_class`` images each, drawn
    from fresh random noise). The grid is written to
    ``cgan_epoch_{epoch}.png`` inside the module-level ``sample_dir`` and also
    shown with matplotlib.

    The generator is switched to eval mode while sampling and is put back into
    whatever mode it was in on entry, even if sampling raises.

    Args:
        epoch: Epoch number used in the file name and the figure title.
        G: Generator called as ``G(z, labels)``.
        nz: Number of noise channels.
        num_classes: Number of classes (grid rows).
        device: Device on which noise and labels are created.
        n_per_class: Images per class (grid columns).
    """
    was_training = G.training
    G.eval()
    try:
        per_class_batches = []
        for cls in range(num_classes):
            z = torch.randn(n_per_class, nz, 1, 1, device=device)
            labels = torch.full(
                (n_per_class,), cls, dtype=torch.long, device=device
            )
            # (n_per_class, C, H, W), moved to the CPU for plotting/saving.
            per_class_batches.append(G(z, labels).cpu())

        # (num_classes * n_per_class, C, H, W)
        imgs = torch.cat(per_class_batches, dim=0)
        imgs = (imgs + 1) / 2  # [-1, 1] -> [0, 1]
        grid = make_grid(imgs, nrow=n_per_class)

        save_path = os.path.join(sample_dir, f"cgan_epoch_{epoch}.png")
        save_image(grid, save_path)
        print(f"Saved sample images to: {save_path}")

        # Also show the grid on screen.
        fig = plt.figure(figsize=(8, 8))
        plt.imshow(grid.permute(1, 2, 0))  # CHW -> HWC for matplotlib
        plt.axis("off")
        plt.title(f"CGAN Samples @ epoch {epoch}")
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    finally:
        G.train(was_training)
# =========================
# 6. Training loop
# =========================
# NOTE(review): fixed_z / fixed_labels are not consumed by any code visible
# here; sample_cgan_images is called without them.
fixed_z = torch.randn(num_classes * 8, nz, 1, 1, device=device)
# Eight consecutive copies of each class index: 0,0,...,0,1,1,...
fixed_labels = torch.arange(
    num_classes, dtype=torch.long, device=device
).repeat_interleave(8)

for epoch in range(1, num_epochs + 1):
    G.train()
    D.train()

    # Running sums used only for the periodic log line.
    running_sup = 0.0
    running_unsup = 0.0
    running_g = 0.0

    # Weight of the unsupervised (real-vs-fake) term in the D loss.
    # NOTE(review): it restarts at 0 every epoch and grows by 0.1 after every
    # fifth step with no upper bound -- confirm this schedule is intended.
    lambda_unsup = 0

    for i, (real_imgs, labels) in enumerate(train_loader):
        real_imgs = real_imgs.to(device)
        labels = labels.to(device)
        bsz = real_imgs.size(0)

        # -----------------------------------
        # 1) Discriminator update (SGAN style)
        # -----------------------------------
        optimizerD.zero_grad()

        # Real images with their true labels as the condition.
        logits_real, _ = D(real_imgs, labels)

        # Fake images with randomly drawn condition labels. The generator is
        # not updated in this step, so its forward pass skips autograd.
        z = torch.randn(bsz, nz, 1, 1, device=device)
        fake_labels = torch.randint(0, num_classes, (bsz,), device=device)
        with torch.no_grad():
            fake_images = G(z, fake_labels)
        logits_fake, _ = D(fake_images, fake_labels)

        # Supervised loss: labelled real images over classes 0..K-1.
        loss_sup = supervised_loss(logits_real, labels, num_classes)
        # Unsupervised loss: real vs fake via the last ("fake") logit.
        loss_unsup = unsupervised_loss(logits_real, logits_fake)

        loss_D = loss_sup + loss_unsup * lambda_unsup
        loss_D.backward()
        # Disabled alternative to loss_unsup (BCE on the last logit):
        #   d_real_bce = bce_loss_fn(logits_real[:, -1], torch.ones_like(logits_real[:, -1]))
        #   d_fake_bce = bce_loss_fn(logits_fake[:, -1], torch.zeros_like(logits_fake[:, -1]))
        #   (d_real_bce + d_fake_bce).backward()
        optimizerD.step()

        # -----------------------------------
        # 2) Generator update (SGAN generator_loss)
        # -----------------------------------
        optimizerG.zero_grad()
        z = torch.randn(bsz, nz, 1, 1, device=device)
        fake_labels = torch.randint(0, num_classes, (bsz,), device=device)
        fake_images = G(z, fake_labels)
        logits_fake_for_G, _ = D(fake_images, fake_labels)
        loss_G = generator_loss(logits_fake_for_G)
        loss_G.backward()
        # Disabled alternative (BCE on the last logit):
        #   g_loss_bce = bce_loss_fn(logits_fake_for_G[:, -1], torch.ones_like(logits_fake_for_G[:, -1]))
        #   g_loss_bce.backward()
        optimizerG.step()

        # Accumulate for logging.
        running_sup += loss_sup.item()
        running_unsup += loss_unsup.item()
        running_g += loss_G.item()

        if (i + 1) % 100 == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] "
                f"Step [{i+1}/{len(train_loader)}] "
                f"Sup: {running_sup/(i+1):.4f} "
                f"Unsup: {running_unsup/(i+1):.4f} "
                f"G: {running_g/(i+1):.4f}"
            )

        # Ramp the unsupervised weight after steps 0, 5, 10, ...
        if i % 5 == 0:
            lambda_unsup += 0.1

    # Save and display a sample grid once per epoch.
    sample_cgan_images(epoch, G, nz, num_classes, device, n_per_class=8)

print("Training finished!")
# (GitHub page footer) Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment