Created
August 1, 2019 07:56
-
-
Save knok/5b240ac74a901ad16344758bb6b90fd2 to your computer and use it in GitHub Desktop.
Chainer StyleGAN onnx export
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| import sys | |
| import re | |
| import json | |
| import numpy as np | |
| from PIL import Image | |
| import chainer | |
| import chainer.cuda | |
| from chainer import Variable | |
| from chainer import training | |
| from chainer.training import extension | |
| from chainer.training import extensions | |
| from net import Discriminator, StyleGenerator, MappingNetwork, SynthesisBlock, EqualizedConv2d | |
| import common | |
| from common.utils.record import record_setting | |
| from common.datasets.base.base_dataset import BaseDataset | |
| from config import FLAGS | |
| from common.utils.save_images import convert_batch_images | |
| from common.evaluation.fid import API as FIDAPI, fid_extension | |
| import math | |
| from common.networks.component.rescale import upscale2x | |
| import onnx_chainer | |
class MyStyleGenerator(chainer.Chain):
    """StyleGAN synthesis network with nine resolution stages (4 px .. 1024 px).

    Holds nine ``SynthesisBlock``s (4, 8, ..., 1024 px according to the inline
    comments) and one ``EqualizedConv2d(<channels>, 3, 1, 1, 0)`` output layer
    per block (presumably a 1x1 "to-RGB" conv -- confirm against ``net.py``).
    Only the first four blocks run at full width ``ch``; deeper blocks halve
    the channel count at each step.
    """

    def __init__(self, ch=512, enable_blur=False):
        """Build the synthesis blocks and per-resolution output layers.

        Args:
            ch: channel width of the low-resolution blocks; higher-resolution
                blocks use ``ch // 2`` down to ``ch // 32``.
            enable_blur: forwarded to every upsampling ``SynthesisBlock``.
        """
        super(MyStyleGenerator, self).__init__()
        self.max_stage = 17
        with self.init_scope():
            self.blocks = chainer.ChainList(
                SynthesisBlock(ch, ch, upsample=False),  # 4
                SynthesisBlock(ch, ch, upsample=True, enable_blur=enable_blur),  # 8
                SynthesisBlock(ch, ch, upsample=True, enable_blur=enable_blur),  # 16
                SynthesisBlock(ch, ch, upsample=True, enable_blur=enable_blur),  # 32
                SynthesisBlock(ch // 2, ch, upsample=True, enable_blur=enable_blur),  # 64
                SynthesisBlock(ch // 4, ch // 2, upsample=True, enable_blur=enable_blur),  # 128
                SynthesisBlock(ch // 8, ch // 4, upsample=True, enable_blur=enable_blur),  # 256
                SynthesisBlock(ch // 16, ch // 8, upsample=True, enable_blur=enable_blur),  # 512
                SynthesisBlock(ch // 32, ch // 16, upsample=True, enable_blur=enable_blur),  # 1024
            )
            # outs[i] consumes the channel count produced by blocks[i].
            self.outs = chainer.ChainList(
                EqualizedConv2d(ch, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch // 2, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch // 4, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch // 8, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch // 16, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch // 32, 3, 1, 1, 0, gain=1),
            )
        self.n_blocks = len(self.blocks)
        self.image_size = 1024
        self.enable_blur = enable_blur

    def __call__(self, w, stage=17, add_noise=True, w2=None):
        """Synthesize images from style input(s) at a progressive-growing stage.

        For alpha in [0, 1) and 2*k+2 + alpha < self.max_stage (k >= -1):

        * stage 0 + alpha:     w -> block[0] -> out[0] * 1
        * stage 2*k+1 + alpha: w -> ... -> block[k] -> (up -> out[k]) * (1 - alpha)
                               ............... -> (block[k+1] -> out[k+1]) * alpha
        * stage 2*k+2 + alpha: w -> ............ -> (block[k+1] -> out[k+1]) * 1

        Stages at or beyond ``max_stage`` are clamped to just below it.

        Args:
            w: style input fed to every synthesis block.
            stage: progressive-growing stage (may be fractional); on odd
                stages the fractional part is the fade-in weight ``alpha``.
            add_noise: forwarded to each ``SynthesisBlock``.
            w2: optional second style input. When given (and the stage is deep
                enough), blocks from a randomly drawn index onward use ``w2``
                instead of ``w`` (style mixing).

        Returns:
            The synthesized batch. Outside training mode, outputs whose
            ``shape[2]`` is below 64 are enlarged to 64x64 via
            ``unpooling_2d`` (assumes NCHW layout -- confirm).
        """
        stage = min(stage, self.max_stage - 1e-8)
        alpha = stage - math.floor(stage)
        stage = math.floor(stage)

        h = None
        if stage % 2 == 0:
            k = (stage - 2) // 2
            # Style mixing: switch to w2 at a random block index; a value of
            # k+2 is past the loop end, i.e. "never switch".
            if w2 is not None and k >= 0:
                lim = np.random.randint(1, k + 2)
            else:
                lim = k + 2
            for i in range(0, (k + 1) + 1):  # 0 .. k+1
                if i == lim:
                    w = w2
                h = self.blocks[i](w, x=h, add_noise=add_noise)
            h = self.outs[k + 1](h)
        else:
            k = (stage - 1) // 2
            if w2 is not None and k >= 1:
                lim = np.random.randint(1, k + 1)
            else:
                lim = k + 1
            for i in range(0, k + 1):  # 0 .. k
                if i == lim:
                    w = w2
                h = self.blocks[i](w, x=h, add_noise=add_noise)
            # Fade between the upsampled lower-resolution output and the
            # output of the next (higher-resolution) block.
            h_0 = self.outs[k](upscale2x(h))
            h_1 = self.outs[k + 1](self.blocks[k + 1](w, x=h, add_noise=add_noise))
            assert 0. <= alpha < 1.
            h = (1.0 - alpha) * h_0 + alpha * h_1

        if chainer.configuration.config.train:
            return h

        min_sample_image_size = 64
        if h.data.shape[2] < min_sample_image_size:  # too small to inspect
            # Was a NameError: `F` is not imported at module level.
            import chainer.functions as F
            scale = int(min_sample_image_size // h.data.shape[2])
            return F.unpooling_2d(
                h, scale, scale, 0,
                outsize=(min_sample_image_size, min_sample_image_size))
        return h
# Hyper-parameter strings copied from chainer-stylegan's shell launcher.
# Python does not interpolate shell-style "$name" references, so the
# composite flag string is built explicitly. Environment variables are
# expanded with os.path.expandvars; unset ones are left verbatim.
hps_device = "--gpu 0"
mpi = "--use_mpi=False"
run_iter = "--dynamic_batch_size 256,256,256,128,128,64,64,32,32,8,8,4,4,2,2,1,1 --max_stage 13 --stage_interval 1250000"
eval_iter = "--evaluation_sample_interval 500 --display_interval 10 --snapshot_interval 5000"
hps_training_dynamics = " ".join([eval_iter, mpi, run_iter])
hps_lr = "--adam_alpha_g 0.001 --adam_alpha_d 0.001 --adam_beta1 0.0 --adam_beta2 0.999"
hps_hyperparameters = "--lambda_gp 5.0 --smoothing 0.999 --keep_smoothed_gen=True"
hps_dataset = os.path.expandvars("--dataset_config $DATASET_CONFIG --dataset_worker_num 16")
hps_output = os.path.expandvars("--out $SCRATCH_ROOT/$SUBPROJ_NAME/$EXPR_ID")
hps_resume = "--auto_resume"

# NOTE(review): FLAGS looks absl-style (`--flag=False` booleans, bare
# `--auto_resume`). absl's FLAGS(argv) treats argv[0] as the program name,
# so one is prepended; otherwise "--adam_alpha_g" would be swallowed.
# Confirm against config.py.
argv = ([sys.argv[0]]
        + hps_lr.split()
        + hps_training_dynamics.split()
        + hps_hyperparameters.split()
        + hps_dataset.split()
        + hps_device.split()
        + hps_output.split()
        + hps_resume.split())
FLAGS(argv)

# Load the smoothed (exponential-moving-average) snapshots from logdir.
logdir = "."
snapshot_iter = 405000
smoothed_generator = MyStyleGenerator(FLAGS.ch, enable_blur=FLAGS.enable_blur)
smoothed_mapping = MappingNetwork(FLAGS.ch)
chainer.serializers.load_npz(
    os.path.join(logdir, "SmoothedGenerator_%d.npz" % snapshot_iter),
    smoothed_generator)
chainer.serializers.load_npz(
    os.path.join(logdir, "SmoothedMapping_%d.npz" % snapshot_iter),
    smoothed_mapping)

chainer.config.train = False
mapping = smoothed_mapping
generator = smoothed_generator

# Run a forward pass, then export both networks to ONNX.
with chainer.no_backprop_mode():
    xp = generator.xp
    example_input = xp.asarray(mapping.make_hidden(1))
    z = chainer.Variable(example_input)
    w = mapping(z)
    y = generator(w, stage=11, add_noise=True)
    onnx_mapping = onnx_chainer.export(mapping, z, filename="stylegan_map.onnx")
    # NOTE(review): export calls generator(w) with its defaults
    # (stage=17, add_noise=True), not the stage=11 used above; the noise
    # path needs RandomNormal support in onnx-chainer.
    onnx_generator = onnx_chainer.export(generator, w, filename="stylegan_gen.onnx")
Author
You need the chainer-stylegan code placed at the correct path — usually the same directory as sg-onnx.py.
BTW, I still can't get it to work because onnx-chainer doesn't support RandomNormal yet.
chainer/onnx-chainer#214
Thanks for your reply and for the information! I need to export the entire model to .onnx, so I guess I'll try to reimplement the code to have control over what's going on :p
BTW, I cloned the Chainer StyleGAN repo, placed the script in src/stylegan and launched `python sg-onnx.py`. I got the error below... Am I doing anything wrong?
Traceback (most recent call last):
File "sg-onnx.py", line 27, in <module>
import onnx_chainer
File "/home/alberto/anaconda3/envs/chainer-stylegan/lib/python2.7/site-packages/onnx_chainer/__init__.py", line 3, in <module>
from onnx_chainer.export import convert_parameter # NOQA
File "/home/alberto/anaconda3/envs/chainer-stylegan/lib/python2.7/site-packages/onnx_chainer/export.py", line 12, in <module>
from onnx_chainer.graph import Graph
File "/home/alberto/anaconda3/envs/chainer-stylegan/lib/python2.7/site-packages/onnx_chainer/graph.py", line 7, in <module>
from onnx_chainer.functions.converter import FunctionConverterParams
File "/home/alberto/anaconda3/envs/chainer-stylegan/lib/python2.7/site-packages/onnx_chainer/functions/__init__.py", line 46, in <module>
from onnx_chainer.functions.math import convert_Absolute # NOQA
File "/home/alberto/anaconda3/envs/chainer-stylegan/lib/python2.7/site-packages/onnx_chainer/functions/math.py", line 275
n1 = gb.op('Sub', [one_name, p], **kwargs, **kwargs2)
^
SyntaxError: invalid syntax
Author
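I haven't seen that problem, but it seems to come from the Python version. Your traceback shows Python 2.7, which does not accept more than one `**` unpacking in a single call (that syntax requires Python 3.5+).
I tried onnx_chainer at git master's HEAD with Python 3.7.3.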
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi @knok, thanks for sharing your script! I'm looking for a way to export a StyleGAN model into onnx, and if I understood correctly that's what sg-onnx.py is doing.
If I try to run it, I get the error below (apparently I'm missing a module). Do you have any suggestion on how to fix it? Thanks heaps
File "sg-onnx.py", line 16, in <module> from net import Discriminator, StyleGenerator, MappingNetwork, SynthesisBlock, EqualizedConv2d ModuleNotFoundError: No module named 'net'