Skip to content

Instantly share code, notes, and snippets.

@zzhuolun
Created February 16, 2022 18:34
Show Gist options
  • Select an option

  • Save zzhuolun/559a4dcc5d50729c13c845ba98dcc15c to your computer and use it in GitHub Desktop.

Select an option

Save zzhuolun/559a4dcc5d50729c13c845ba98dcc15c to your computer and use it in GitHub Desktop.
For learning the basic concepts of GANs (generative adversarial networks)
### Learning GAN
### From this tutorial: https://towardsdatascience.com/build-a-super-simple-gan-in-pytorch-54ba349920e4 (code at https://github.com/nbertagnolli/pytorch-simple-gan)
import torch
import torch.nn as nn
import numpy as np
import math
class Generator(nn.Module):
    """Map a noise vector to ``output_length`` values in (0, 1).

    The network is one linear layer followed by a sigmoid; in this script each
    output is later thresholded into one bit of a binary number.

    Args:
        input_length: Size of the last dimension of the noise input.
        middle_length: Accepted for interface compatibility but currently
            unused -- there is no hidden layer.
        output_length: Number of outputs (bits) per sample.
    """

    def __init__(self, input_length, middle_length, output_length):
        super().__init__()
        # Attribute names are kept stable because they determine state_dict keys.
        self.lin_1 = nn.Linear(input_length, output_length)
        self.activation = nn.Sigmoid()

    def forward(self, x):
        """Return ``sigmoid(lin_1(x))``; presumably x is (batch, input_length)."""
        projected = self.lin_1(x)
        return self.activation(projected)
class Discriminator(nn.Module):
    """Score each input row with a probability-like value in (0, 1).

    A single linear layer reduces ``input_length`` features to one logit,
    which a sigmoid squashes; the training loop treats 1 as "real".

    Args:
        input_length: Number of features (bits) per sample.
    """

    def __init__(self, input_length):
        super().__init__()
        # Attribute names are kept stable because they determine state_dict keys.
        self.lin = nn.Linear(input_length, 1)
        self.activation = nn.Sigmoid()

    def forward(self, x):
        """Return ``sigmoid(lin(x))``, one column per sample."""
        logit = self.lin(x)
        return self.activation(logit)
def bin_list_from_int(number):
    """Return the binary digits of a non-negative integer, most significant first.

    Args:
        number: A non-negative integer (a Python ``int`` or anything ``bin``
            accepts, such as a NumPy integer scalar).

    Returns:
        A list of 0/1 ints without leading zeros, e.g. ``5 -> [1, 0, 1]``;
        ``0`` maps to ``[0]``.

    Raises:
        ValueError: If ``number`` is negative.
    """
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, after which bin() of a negative number ("-0b...") would
    # fail later with a confusing parse error.
    if number < 0:
        raise ValueError(f"number must be non-negative, got {number}")
    return [int(digit) for digit in bin(number)[2:]]
def convert_float_matrix_to_int_list(
    float_matrix: "np.ndarray | torch.Tensor", threshold: float = 0.5
) -> "list[int]":
    """Convert rows of bit probabilities into the integers they encode.

    Each row is thresholded element-wise (``value >= threshold`` -> 1, else 0)
    and read as a binary number, most significant bit first.

    Args:
        float_matrix: A 2-D NumPy array or torch tensor (one sample per row)
            of values between 0 and 1. ``np.array`` is a factory function,
            not a type, so it is not used as the annotation; the training
            loop passes a torch tensor here.
        threshold: The cutoff value for 0 and 1 thresholding; values equal to
            the threshold count as 1.

    Returns:
        A list of integers, one per row (empty for a matrix with no rows).
    """
    # The comparison yields a boolean matrix; for a torch tensor the result no
    # longer carries autograd history.
    bit_rows = float_matrix >= threshold
    return [int("".join(str(int(bit)) for bit in row), 2) for row in bit_rows]
def odd_percentage(x):
    """Return the percentage (0-100) of odd integers in ``x``.

    Args:
        x: A sized iterable of integers (e.g. the list produced by
            ``convert_float_matrix_to_int_list``).

    Returns:
        ``100 * (#odd / len(x))`` as a float, or ``0.0`` for empty input
        instead of raising ``ZeroDivisionError``.
    """
    # len() rather than truthiness so NumPy arrays are accepted as well.
    if len(x) == 0:
        return 0.0
    odd_count = sum(1 for value in x if value % 2 != 0)
    # Divide first, then scale, to keep float results identical to the
    # original formula.
    return odd_count / len(x) * 100
def gt_data_gen(max_int, batch_size):
    """Sample a batch of "real" data: random odd integers encoded in binary.

    Args:
        max_int: Exclusive upper bound for the sampled odd integers; must be
            at least 2.
        batch_size: Number of samples to draw (uses NumPy's global RNG).

    Returns:
        ``(data, labels)`` where ``data`` is a list of ``batch_size`` lists of
        0/1 ints, all of the same length (the bit width of the largest odd
        number below ``max_int``), most significant bit first, and ``labels``
        is ``[1] * batch_size``.

    Raises:
        ValueError: If ``max_int < 2`` (there is no odd number to sample).
    """
    if max_int < 2:
        raise ValueError(f"max_int must be at least 2, got {max_int}")
    # The largest value drawn below is 2 * (max_int // 2 - 1) + 1. Sizing the
    # width from its bit length (pure integer arithmetic) guarantees every
    # sample fits. floor(log2(max_int)) is one bit too short whenever max_int
    # is not a power of two, which produced ragged rows. For powers of two
    # both equal log2(max_int).
    max_length = (2 * (max_int // 2) - 1).bit_length()
    sampled_int = 2 * np.random.randint(max_int // 2, size=batch_size) + 1
    labels = [1] * batch_size
    # format(..., "0{w}b") emits the zero-padded binary string in one step.
    data = [
        [int(bit) for bit in format(int(value), f"0{max_length}b")]
        for value in sampled_int
    ]
    return data, labels
def train(
    max_int=128,
    epoch=300,
    batch_size=32,
    lr=0.001,
    gaussian_len=3,
    middle_length=5,
    print_freq=10,
):
    """Train a tiny GAN whose generator learns to emit odd numbers below max_int in binary.

    Args:
        max_int: Exclusive upper bound of the "real" odd integers; must be >= 2.
        epoch: Number of training steps; each step does one generator update
            followed by one discriminator update.
        batch_size: Samples per step (1 is supported).
        lr: Adam learning rate for both networks.
        gaussian_len: Dimension of the Gaussian noise fed to the generator.
        middle_length: Forwarded to ``Generator``'s ``middle_length`` argument.
        print_freq: Print a progress report every ``print_freq`` steps;
            ``0`` disables printing.

    Returns:
        The trained ``(generator, discriminator)`` pair.

    Raises:
        ValueError: If ``max_int < 2``.
    """
    if max_int < 2:
        raise ValueError(f"max_int must be at least 2, got {max_int}")
    # Bit width of the largest odd number below max_int, computed with
    # integers so it matches the width of the real samples. It equals
    # log2(max_int) for powers of two and avoids float-log rounding.
    digit_length = (2 * (max_int // 2) - 1).bit_length()
    generator = Generator(gaussian_len, middle_length, digit_length)
    discriminator = Discriminator(digit_length)
    generator_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
    discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
    loss = nn.BCELoss()
    for i in range(epoch):
        generator_optimizer.zero_grad()
        noise = torch.randn((batch_size, gaussian_len))
        generated_data = generator(noise)
        true_data, true_labels = gt_data_gen(max_int, batch_size)
        true_data = torch.tensor(true_data).float()
        true_labels = torch.tensor(true_labels).float()

        # Train generator: its loss is low when the discriminator labels the
        # generated samples as real (1). squeeze(1) rather than squeeze() keeps
        # the batch dimension when batch_size == 1, so predictions always have
        # the same (batch_size,) shape as the targets, as BCELoss requires.
        gen_discriminator_output = discriminator(generated_data).squeeze(1)
        generator_loss = loss(gen_discriminator_output, true_labels)
        generator_loss.backward()
        generator_optimizer.step()

        # Train discriminator: real samples -> 1, generated samples -> 0.
        # zero_grad() also discards the gradients that generator_loss.backward()
        # left on the discriminator's parameters.
        discriminator_optimizer.zero_grad()
        true_discriminator_output = discriminator(true_data).squeeze(1)
        true_discriminator_loss = loss(true_discriminator_output, true_labels)
        # detach() so this loss does not backpropagate into the generator.
        gen_discriminator_output = discriminator(generated_data.detach()).squeeze(1)
        gen_discriminator_loss = loss(
            gen_discriminator_output, torch.zeros_like(true_labels)
        )
        discriminator_loss = (true_discriminator_loss + gen_discriminator_loss) / 2
        discriminator_loss.backward()
        discriminator_optimizer.step()

        if print_freq and i % print_freq == 0:
            # .item() yields a plain float for logging and avoids building an
            # autograd node on every step.
            all_loss = generator_loss.item() + discriminator_loss.item()
            print(f"{i}_th step:")
            generated_int = convert_float_matrix_to_int_list(generated_data.detach())
            print(sorted(generated_int))
            print(
                "odd number percentage: ",
                odd_percentage(generated_int),
                "| loss: ",
                all_loss,
            )
            print("-" * 15)
    return generator, discriminator
if __name__ == "__main__":
    # Demo run: 500 training steps with a progress report every 10 steps.
    train(epoch=500, print_freq=10)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment