Skip to content

Instantly share code, notes, and snippets.

@knok
Created August 1, 2019 07:56
Show Gist options
  • Select an option

  • Save knok/5b240ac74a901ad16344758bb6b90fd2 to your computer and use it in GitHub Desktop.

Select an option

Save knok/5b240ac74a901ad16344758bb6b90fd2 to your computer and use it in GitHub Desktop.
Chainer StyleGAN onnx export
import os
import sys
import re
import json
import numpy as np
from PIL import Image
import chainer
import chainer.cuda
from chainer import Variable
from chainer import training
from chainer.training import extension
from chainer.training import extensions
from net import Discriminator, StyleGenerator, MappingNetwork, SynthesisBlock, EqualizedConv2d
import common
from common.utils.record import record_setting
from common.datasets.base.base_dataset import BaseDataset
from config import FLAGS
from common.utils.save_images import convert_batch_images
from common.evaluation.fid import API as FIDAPI, fid_extension
import math
from common.networks.component.rescale import upscale2x
import onnx_chainer
class MyStyleGenerator(chainer.Chain):
    """Progressive-growing StyleGAN synthesis network, declared locally for export.

    The network is a stack of nine ``SynthesisBlock`` links (``self.blocks``)
    paired with nine ``EqualizedConv2d`` output convolutions (``self.outs``).
    ``__call__`` runs a prefix of the stack selected by ``stage`` and, on
    odd stages, blends the outputs of two neighbouring resolutions.
    """

    def __init__(self, ch=512, enable_blur=False):
        """Create the synthesis blocks and their matching output convolutions.

        Args:
            ch: Base channel count; the later blocks use ``ch // 2`` down to
                ``ch // 32``.
            enable_blur: Forwarded to every upsampling ``SynthesisBlock``.
        """
        super(MyStyleGenerator, self).__init__()
        # Upper bound for the ``stage`` argument of ``__call__``.
        self.max_stage = 17
        with self.init_scope():
            # Trailing comments give the output resolution of each block.
            self.blocks = chainer.ChainList(
                SynthesisBlock(ch, ch, upsample=False),  # 4
                SynthesisBlock(ch, ch, upsample=True, enable_blur=enable_blur),  # 8
                SynthesisBlock(ch, ch, upsample=True, enable_blur=enable_blur),  # 16
                SynthesisBlock(ch, ch, upsample=True, enable_blur=enable_blur),  # 32
                SynthesisBlock(ch // 2, ch, upsample=True, enable_blur=enable_blur),  # 64
                SynthesisBlock(ch // 4, ch // 2, upsample=True, enable_blur=enable_blur),  # 128
                SynthesisBlock(ch // 8, ch // 4, upsample=True, enable_blur=enable_blur),  # 256
                SynthesisBlock(ch // 16, ch // 8, upsample=True, enable_blur=enable_blur),  # 512
                SynthesisBlock(ch // 32, ch // 16, upsample=True, enable_blur=enable_blur),  # 1024
            )
            # outs[i] is applied to the output of blocks[i].
            self.outs = chainer.ChainList(
                EqualizedConv2d(ch, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch // 2, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch // 4, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch // 8, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch // 16, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch // 32, 3, 1, 1, 0, gain=1),
            )
        self.n_blocks = len(self.blocks)
        self.image_size = 1024
        self.enable_blur = enable_blur

    def __call__(self, w, stage=17, add_noise=True, w2=None):
        """Synthesize an image from the latent ``w`` at the given growth stage.

        ``stage`` is clamped to just below ``self.max_stage`` and split into
        an integer part and a fractional part ``alpha`` in ``[0, 1)``:

        * even integer stage ``2*k + 2``: run ``blocks[0..k+1]`` and return
          ``outs[k+1]`` of the result (``alpha`` is ignored);
        * odd integer stage ``2*k + 1``: run ``blocks[0..k]``, then return
          ``(1 - alpha) * outs[k](upscale2x(h))
          + alpha * outs[k+1](blocks[k+1](h))``.

        Args:
            w: Latent passed to every block.
            stage: Growth stage, possibly fractional.
            add_noise: Forwarded to every block.
            w2: Optional second latent for style mixing. When given, a block
                index is drawn with ``np.random.randint`` and ``w2`` replaces
                ``w`` from that block onwards.

        Returns:
            The output-convolution result. When ``chainer.config.train`` is
            false and ``h.data.shape[2]`` is below 64, the result is enlarged
            to 64x64 with ``unpooling_2d``.
        """
        # Clamp so that the default stage (== max_stage) maps to the last
        # even stage instead of running past the end of the block list.
        stage = min(stage, self.max_stage - 1e-8)
        alpha = stage - math.floor(stage)
        stage = math.floor(stage)

        h = None
        if stage % 2 == 0:
            k = (stage - 2) // 2

            # Enable Style Mixing:
            if w2 is not None and k >= 0:
                lim = np.random.randint(1, k + 2)
            else:
                # One past the last block index, so the switch never happens.
                lim = k + 2

            for i in range(0, (k + 1) + 1):  # 0 .. k+1
                if i == lim:
                    w = w2
                h = self.blocks[i](w, x=h, add_noise=add_noise)

            h = self.outs[k + 1](h)
        else:
            k = (stage - 1) // 2

            if w2 is not None and k >= 1:
                lim = np.random.randint(1, k + 1)
            else:
                lim = k + 1

            for i in range(0, k + 1):  # 0 .. k
                if i == lim:
                    w = w2
                h = self.blocks[i](w, x=h, add_noise=add_noise)

            # Blend the upscaled lower-resolution output with the output of
            # the next block.
            h_0 = self.outs[k](upscale2x(h))
            h_1 = self.outs[k + 1](self.blocks[k + 1](w, x=h, add_noise=add_noise))
            assert 0. <= alpha < 1.
            h = (1.0 - alpha) * h_0 + alpha * h_1

        if chainer.configuration.config.train:
            return h

        min_sample_image_size = 64
        if h.data.shape[2] < min_sample_image_size:  # too small
            # The original referenced ``F`` without ever importing it, which
            # raised NameError on this path; import it where it is needed.
            import chainer.functions as F

            scale = int(min_sample_image_size // h.data.shape[2])
            return F.unpooling_2d(
                h, scale, scale, 0,
                outsize=(min_sample_image_size, min_sample_image_size))
        return h
# --- Flag strings -----------------------------------------------------------
# Each hps_* value is a whitespace-separated list of command-line tokens that
# is split and handed to FLAGS below.
# NOTE(review): Python does not expand the shell-style ``$name`` references in
# these strings, so they reach FLAGS as literal tokens (for example
# "$eval_iter") - confirm whether that is intended.
hps_device = "--gpu 0"
mpi = "--use_mpi=False"
run_iter = "--dynamic_batch_size 256,256,256,128,128,64,64,32,32,8,8,4,4,2,2,1,1 --max_stage 13 --stage_interval 1250000"
eval_iter = "--evaluation_sample_interval 500 --display_interval 10 --snapshot_interval 5000"
hps_training_dynamics = "$eval_iter $mpi $run_iter"
hps_lr = "--adam_alpha_g 0.001 --adam_alpha_d 0.001 --adam_beta1 0.0 --adam_beta2 0.999"
hps_hyperparameters = "--lambda_gp 5.0 --smoothing 0.999 --keep_smoothed_gen=True"
hps_dataset = "--dataset_config $DATASET_CONFIG --dataset_worker_num 16"
hps_output = "--out $SCRATCH_ROOT/$SUBPROJ_NAME/$EXPR_ID"
hps_resume = "--auto_resume"

# Flatten the groups, in this order, into one token list.
# NOTE(review): if FLAGS follows the gflags/absl convention, the first token
# is consumed as the program name rather than parsed as a flag - verify
# against config.py.
flag_groups = (
    hps_lr,
    hps_training_dynamics,
    hps_hyperparameters,
    hps_dataset,
    hps_device,
    hps_output,
    hps_resume,
)
argv = [token for group in flag_groups for token in group.split()]
FLAGS(argv)

# --- Networks ---------------------------------------------------------------
# Freshly initialised networks. ``mapping`` and ``generator`` are rebound to
# the loaded smoothed networks further down, and ``discriminator`` is not
# referenced again in this script.
mapping = MappingNetwork(FLAGS.ch)
generator = StyleGenerator(FLAGS.ch, enable_blur=FLAGS.enable_blur)
discriminator = Discriminator(ch=FLAGS.ch, enable_blur=FLAGS.enable_blur)

# Snapshots are read from the current directory.
logdir = "."
smoothed_generator = MyStyleGenerator(FLAGS.ch, enable_blur=FLAGS.enable_blur)
smoothed_mapping = MappingNetwork(FLAGS.ch)
chainer.serializers.load_npz(
    "%s/SmoothedGenerator_405000.npz" % logdir, smoothed_generator)
chainer.serializers.load_npz(
    "%s/SmoothedMapping_405000.npz" % logdir, smoothed_mapping)

# Switch to test mode and use the loaded smoothed networks from here on.
chainer.config.train = False
mapping = smoothed_mapping
generator = smoothed_generator

# gen graph
# Trial forward pass: build one example latent and run both networks once.
with chainer.no_backprop_mode():
    xp = generator.xp
    example_input = xp.asarray(mapping.make_hidden(1))
    z = chainer.Variable(example_input)
    w = mapping(z)
    y = generator(w, stage=11, add_noise=True)

# --- ONNX export ------------------------------------------------------------
# The example tensors z and w from the trial pass serve as the export inputs.
chainer.config.train = False
onnx_mapping = onnx_chainer.export(
    mapping, z, filename="stylegan_map.onnx")
onnx_generator = onnx_chainer.export(
    generator, w, filename="stylegan_gen.onnx")
@knok
Copy link
Author

knok commented Aug 22, 2019

I haven't seen such a problem, but it seems to come from the Python version.
I tried onnx_chainer on git master's HEAD with python 3.7.3.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment