Created August 1, 2019 07:56
-
-
Save knok/5b240ac74a901ad16344758bb6b90fd2 to your computer and use it in GitHub Desktop.
Chainer StyleGAN onnx export
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| import sys | |
| import re | |
| import json | |
| import numpy as np | |
| from PIL import Image | |
| import chainer | |
| import chainer.cuda | |
| from chainer import Variable | |
| from chainer import training | |
| from chainer.training import extension | |
| from chainer.training import extensions | |
| from net import Discriminator, StyleGenerator, MappingNetwork, SynthesisBlock, EqualizedConv2d | |
| import common | |
| from common.utils.record import record_setting | |
| from common.datasets.base.base_dataset import BaseDataset | |
| from config import FLAGS | |
| from common.utils.save_images import convert_batch_images | |
| from common.evaluation.fid import API as FIDAPI, fid_extension | |
| import math | |
| from common.networks.component.rescale import upscale2x | |
| import onnx_chainer | |
class MyStyleGenerator(chainer.Chain):
    """Progressive StyleGAN synthesis network used as the ONNX export target.

    Holds nine ``SynthesisBlock`` stages (4x4 up to 1024x1024 per the
    resolution comments below), each paired with an ``EqualizedConv2d``
    head producing 3 output channels. Its layer layout must match the
    ``SmoothedGenerator_*.npz`` checkpoint the export script loads into it.

    Args:
        ch: Base channel count; the deeper blocks use ``ch // 2`` down to
            ``ch // 32``.
        enable_blur: Forwarded to every upsampling ``SynthesisBlock``.
    """

    def __init__(self, ch=512, enable_blur=False):
        super(MyStyleGenerator, self).__init__()
        self.max_stage = 17
        with self.init_scope():
            self.blocks = chainer.ChainList(
                SynthesisBlock(ch, ch, upsample=False),  # 4
                SynthesisBlock(ch, ch, upsample=True, enable_blur=enable_blur),  # 8
                SynthesisBlock(ch, ch, upsample=True, enable_blur=enable_blur),  # 16
                SynthesisBlock(ch, ch, upsample=True, enable_blur=enable_blur),  # 32
                SynthesisBlock(ch // 2, ch, upsample=True, enable_blur=enable_blur),  # 64
                SynthesisBlock(ch // 4, ch // 2, upsample=True, enable_blur=enable_blur),  # 128
                SynthesisBlock(ch // 8, ch // 4, upsample=True, enable_blur=enable_blur),  # 256
                SynthesisBlock(ch // 16, ch // 8, upsample=True, enable_blur=enable_blur),  # 512
                SynthesisBlock(ch // 32, ch // 16, upsample=True, enable_blur=enable_blur),  # 1024
            )
            self.outs = chainer.ChainList(
                EqualizedConv2d(ch, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch // 2, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch // 4, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch // 8, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch // 16, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch // 32, 3, 1, 1, 0, gain=1),
            )
        self.n_blocks = len(self.blocks)
        self.image_size = 1024
        self.enable_blur = enable_blur

    def __call__(self, w, stage=17, add_noise=True, w2=None):
        '''Synthesize images from style latents at a progressive-growing stage.

        for alpha in [0, 1), and 2*k+2 + alpha < self.max_stage (-1 <= k <= ...):
        stage 0 + alpha       : z -> block[0] -> out[0] * 1
        stage 2*k+1 + alpha   : z -> ... -> block[k] -> (up -> out[k]) * (1 - alpha)
            .................... -> (block[k+1] -> out[k+1]) * (alpha)
        stage 2*k+2 + alpha   : z -> ............... -> (block[k+1] -> out[k+1]) * 1
        over flow stages continues.

        Args:
            w: Style latent passed to every synthesis block.
            stage: Progressive-growing stage (may be fractional); clamped to
                just below ``self.max_stage``. Must be non-negative.
            add_noise: Forwarded to every ``SynthesisBlock``.
            w2: Optional second latent for style mixing. When given (and the
                stage is deep enough), blocks from a randomly drawn index
                onward use ``w2`` instead of ``w``.

        Returns:
            The generated image batch. Outside training mode, outputs whose
            spatial size is below 64 are enlarged to 64x64 with
            ``unpooling_2d`` for easier viewing.

        Raises:
            ValueError: If ``stage`` is negative (there is no block to run).
        '''
        if stage < 0:
            raise ValueError("stage must be non-negative, got %r" % (stage,))
        # Keep the stage strictly below max_stage so floor() never indexes
        # past the last block.
        stage = min(stage, self.max_stage - 1e-8)
        alpha = stage - math.floor(stage)
        stage = math.floor(stage)

        h = None
        if stage % 2 == 0:
            # Even stage: run blocks 0..k+1 and emit through out[k+1] only.
            k = (stage - 2) // 2
            # Style mixing: switch to w2 starting at a random block index.
            if w2 is not None and k >= 0:
                lim = np.random.randint(1, k + 2)
            else:
                lim = k + 2  # out of range -> never switch
            for i in range(0, (k + 1) + 1):  # 0 .. k+1
                if i == lim:
                    w = w2
                h = self.blocks[i](w, x=h, add_noise=add_noise)
            h = self.outs[k + 1](h)
        else:
            # Odd stage: blend the upscaled output of block k with the
            # output of the newly faded-in block k+1.
            k = (stage - 1) // 2
            if w2 is not None and k >= 1:
                lim = np.random.randint(1, k + 1)
            else:
                lim = k + 1  # out of range -> never switch
            for i in range(0, k + 1):  # 0 .. k
                if i == lim:
                    w = w2
                h = self.blocks[i](w, x=h, add_noise=add_noise)
            h_0 = self.outs[k](upscale2x(h))
            h_1 = self.outs[k + 1](self.blocks[k + 1](w, x=h, add_noise=add_noise))
            assert 0. <= alpha < 1.
            h = (1.0 - alpha) * h_0 + alpha * h_1

        if chainer.config.train:
            return h

        min_sample_image_size = 64
        if h.shape[2] < min_sample_image_size:  # too small to inspect
            # The original called ``F.unpooling_2d`` without ever importing
            # ``F``, raising NameError on this path; import it explicitly.
            import chainer.functions as F
            scale = int(min_sample_image_size // h.shape[2])
            return F.unpooling_2d(
                h, scale, scale, 0,
                outsize=(min_sample_image_size, min_sample_image_size))
        return h
# Hyper-parameters in command-line form, copied from the shell launcher used
# for training. They are parsed so that FLAGS.ch / FLAGS.enable_blur match the
# checkpoint being exported; the training-only options are not used below.
hps_device = "--gpu 0"
mpi = "--use_mpi=False"
run_iter = "--dynamic_batch_size 256,256,256,128,128,64,64,32,32,8,8,4,4,2,2,1,1 --max_stage 13 --stage_interval 1250000"
eval_iter = "--evaluation_sample_interval 500 --display_interval 10 --snapshot_interval 5000"
# Python does not expand "$var" inside string literals: the original
# "$eval_iter $mpi $run_iter" handed those literal tokens to FLAGS instead of
# the flags above. Join the strings explicitly.
hps_training_dynamics = " ".join([eval_iter, mpi, run_iter])
hps_lr = "--adam_alpha_g 0.001 --adam_alpha_d 0.001 --adam_beta1 0.0 --adam_beta2 0.999"
hps_hyperparameters = "--lambda_gp 5.0 --smoothing 0.999 --keep_smoothed_gen=True"
# NOTE(review): $DATASET_CONFIG, $SCRATCH_ROOT, $SUBPROJ_NAME and $EXPR_ID are
# shell variables and reach FLAGS verbatim. That is harmless here only because
# this script never reads the dataset config or output directory.
hps_dataset = "--dataset_config $DATASET_CONFIG --dataset_worker_num 16"
hps_output = "--out $SCRATCH_ROOT/$SUBPROJ_NAME/$EXPR_ID"
hps_resume = "--auto_resume"
argv = (
    hps_lr.split()
    + hps_training_dynamics.split()
    + hps_hyperparameters.split()
    + hps_dataset.split()
    + hps_device.split()
    + hps_output.split()
    + hps_resume.split()
)
# NOTE(review): if FLAGS is an absl/gflags FlagValues object, FLAGS(argv)
# treats argv[0] ("--adam_alpha_g" here) as the program name and skips it.
# Confirm against config.py before relying on that flag.
FLAGS(argv)

logdir = "."
checkpoint_iteration = 405000
# Stage at which the synthesis network is exported. With MyStyleGenerator's
# stage schedule, stage 11 blends the upscaled 128-px block output with the
# 256-px block, i.e. presumably 256x256 images. The generator's default
# (stage=17) would route through the 1024-px block, which the launcher's
# --max_stage 13 never trains. Keep this within the trained range.
EXPORT_STAGE = 11

smoothed_generator = MyStyleGenerator(FLAGS.ch, enable_blur=FLAGS.enable_blur)
smoothed_mapping = MappingNetwork(FLAGS.ch)
chainer.serializers.load_npz(
    os.path.join(logdir, "SmoothedGenerator_%d.npz" % checkpoint_iteration),
    smoothed_generator)
chainer.serializers.load_npz(
    os.path.join(logdir, "SmoothedMapping_%d.npz" % checkpoint_iteration),
    smoothed_mapping)
chainer.config.train = False
mapping = smoothed_mapping
generator = smoothed_generator


class _FixedStageGenerator(chainer.Chain):
    """Single-input Chain that calls ``generator(w, stage=..., add_noise=...)``.

    ``onnx_chainer.export(model, args)`` only forwards ``args`` to the model,
    so the stage and noise options must be bound here. Registering the
    generator in ``init_scope`` keeps its parameters visible to the exporter.
    """

    def __init__(self, generator, stage, add_noise=True):
        super(_FixedStageGenerator, self).__init__()
        self.stage = stage
        self.add_noise = add_noise
        with self.init_scope():
            self.generator = generator

    def __call__(self, w):
        return self.generator(w, stage=self.stage, add_noise=self.add_noise)


# Compute a concrete example latent without recording a graph, so ``w`` has
# no creator linking it back into the mapping network.
with chainer.no_backprop_mode():
    xp = generator.xp
    z = chainer.Variable(xp.asarray(mapping.make_hidden(1)))
    w = mapping(z)

# Export outside no_backprop_mode so the exporter can trace each graph.
onnx_mapping = onnx_chainer.export(mapping, z, filename="stylegan_map.onnx")
onnx_generator = onnx_chainer.export(
    _FixedStageGenerator(generator, EXPORT_STAGE, add_noise=True), w,
    filename="stylegan_gen.onnx")
Author
I haven't seen such a problem, but it seems to come from the Python version.
I tried onnx_chainer at the HEAD of the git master branch with Python 3.7.3.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
BTW, I cloned the Chainer StyleGAN repo, placed the script in
`src/stylegan`, and launched `python sg-onnx.py`. I got the error below... Am I doing anything wrong?