import os
import sys
import re
import json
import math

import numpy as np
from PIL import Image

import chainer
import chainer.cuda
import chainer.functions as F  # needed for F.unpooling_2d in MyStyleGenerator.__call__
from chainer import Variable
from chainer import training
from chainer.training import extension
from chainer.training import extensions

from net import Discriminator, StyleGenerator, MappingNetwork, SynthesisBlock, EqualizedConv2d
import common
from common.utils.record import record_setting
from common.datasets.base.base_dataset import BaseDataset
from config import FLAGS
from common.utils.save_images import convert_batch_images
from common.evaluation.fid import API as FIDAPI, fid_extension
from common.networks.component.rescale import upscale2x
import onnx_chainer
class MyStyleGenerator(chainer.Chain):

    def __init__(self, ch=512, enable_blur=False):
        super(MyStyleGenerator, self).__init__()
        self.max_stage = 17
        with self.init_scope():
            self.blocks = chainer.ChainList(
                SynthesisBlock(ch, ch, upsample=False),  # 4
                SynthesisBlock(ch, ch, upsample=True, enable_blur=enable_blur),  # 8
                SynthesisBlock(ch, ch, upsample=True, enable_blur=enable_blur),  # 16
                SynthesisBlock(ch, ch, upsample=True, enable_blur=enable_blur),  # 32
                SynthesisBlock(ch // 2, ch, upsample=True, enable_blur=enable_blur),  # 64
                SynthesisBlock(ch // 4, ch // 2, upsample=True, enable_blur=enable_blur),  # 128
                SynthesisBlock(ch // 8, ch // 4, upsample=True, enable_blur=enable_blur),  # 256
                SynthesisBlock(ch // 16, ch // 8, upsample=True, enable_blur=enable_blur),  # 512
                SynthesisBlock(ch // 32, ch // 16, upsample=True, enable_blur=enable_blur)  # 1024
            )
            self.outs = chainer.ChainList(
                EqualizedConv2d(ch, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch // 2, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch // 4, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch // 8, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch // 16, 3, 1, 1, 0, gain=1),
                EqualizedConv2d(ch // 32, 3, 1, 1, 0, gain=1)
            )

        self.n_blocks = len(self.blocks)
        self.image_size = 1024
        self.enable_blur = enable_blur

    def __call__(self, w, stage=17, add_noise=True, w2=None):
        '''
        for alpha in [0, 1), and 2*k+2 + alpha < self.max_stage (-1 <= k <= ...):
            stage 0 + alpha     : z -> block[0] -> out[0] * 1
            stage 2*k+1 + alpha : z -> ... -> block[k] -> (up -> out[k]) * (1 - alpha)
                                  .................... -> (block[k+1] -> out[k+1]) * (alpha)
            stage 2*k+2 + alpha : z -> ............... -> (block[k+1] -> out[k+1]) * 1
        over flow stages continues.
        '''
        stage = min(stage, self.max_stage - 1e-8)
        alpha = stage - math.floor(stage)
        stage = math.floor(stage)

        h = None
        if stage % 2 == 0:
            k = (stage - 2) // 2

            # Enable Style Mixing:
            if w2 is not None and k >= 0:
                lim = np.random.randint(1, k + 2)
            else:
                lim = k + 2

            for i in range(0, (k + 1) + 1):  # 0 .. k+1
                if i == lim:
                    w = w2
                h = self.blocks[i](w, x=h, add_noise=add_noise)

            h = self.outs[k + 1](h)
        else:
            k = (stage - 1) // 2

            if w2 is not None and k >= 1:
                lim = np.random.randint(1, k + 1)
            else:
                lim = k + 1

            for i in range(0, k + 1):  # 0 .. k
                if i == lim:
                    w = w2
                h = self.blocks[i](w, x=h, add_noise=add_noise)

            h_0 = self.outs[k](upscale2x(h))
            h_1 = self.outs[k + 1](self.blocks[k + 1](w, x=h, add_noise=add_noise))
            assert 0. <= alpha < 1.
            h = (1.0 - alpha) * h_0 + alpha * h_1

        if chainer.configuration.config.train:
            return h
        else:
            min_sample_image_size = 64
            if h.data.shape[2] < min_sample_image_size:  # too small
                scale = int(min_sample_image_size // h.data.shape[2])
                return F.unpooling_2d(h, scale, scale, 0,
                                      outsize=(min_sample_image_size, min_sample_image_size))
            else:
                return h
# Hyperparameter strings in the style of the training shell scripts; the
# $-prefixed values (dataset config, output directory) are placeholders to fill in.
hps_device = "--gpu 0"
mpi = "--use_mpi=False"
run_iter = "--dynamic_batch_size 256,256,256,128,128,64,64,32,32,8,8,4,4,2,2,1,1 --max_stage 13 --stage_interval 1250000"
eval_iter = "--evaluation_sample_interval 500 --display_interval 10 --snapshot_interval 5000"
hps_training_dynamics = " ".join([eval_iter, mpi, run_iter])
hps_lr = "--adam_alpha_g 0.001 --adam_alpha_d 0.001 --adam_beta1 0.0 --adam_beta2 0.999"
hps_hyperparameters = "--lambda_gp 5.0 --smoothing 0.999 --keep_smoothed_gen=True"
hps_dataset = "--dataset_config $DATASET_CONFIG --dataset_worker_num 16"
hps_output = "--out $SCRATCH_ROOT/$SUBPROJ_NAME/$EXPR_ID"
hps_resume = "--auto_resume"

argv = (hps_lr.split() + hps_training_dynamics.split() + hps_hyperparameters.split()
        + hps_dataset.split() + hps_device.split() + hps_output.split() + hps_resume.split())
FLAGS(argv)
mapping = MappingNetwork(FLAGS.ch)
generator = StyleGenerator(FLAGS.ch, enable_blur=FLAGS.enable_blur)
discriminator = Discriminator(ch=FLAGS.ch, enable_blur=FLAGS.enable_blur)

# Load the smoothed (exponential moving average) weights from a training snapshot.
logdir = "."
smoothed_generator = MyStyleGenerator(FLAGS.ch, enable_blur=FLAGS.enable_blur)
smoothed_mapping = MappingNetwork(FLAGS.ch)
chainer.serializers.load_npz("%s/SmoothedGenerator_405000.npz" % logdir, smoothed_generator)
chainer.serializers.load_npz("%s/SmoothedMapping_405000.npz" % logdir, smoothed_mapping)

chainer.config.train = False
mapping = smoothed_mapping
generator = smoothed_generator

# gen graph: one sample forward pass (mapping -> generator) to obtain the
# example inputs z and w that are passed to onnx_chainer.export below.
with chainer.no_backprop_mode():
    xp = generator.xp
    example_input = xp.asarray(mapping.make_hidden(1))
    z = chainer.Variable(example_input)
    w = mapping(z)
    y = generator(w, stage=11, add_noise=True)

chainer.config.train = False
onnx_mapping = onnx_chainer.export(mapping, z, filename="stylegan_map.onnx")
onnx_generator = onnx_chainer.export(generator, w, filename="stylegan_gen.onnx")
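If the export succeeds, the two .onnx files can be run without Chainer. Below is a minimal sketch, not part of the gist, that feeds a random latent through the exported mapping network with onnxruntime; it assumes onnxruntime is installed, a (1, 512) float32 latent (FLAGS.ch = 512), and a single model output — check sess.get_inputs() / get_outputs() if the exported shapes differ.

import numpy as np
import onnxruntime as ort

# Run the exported mapping network on a random latent vector.
sess = ort.InferenceSession("stylegan_map.onnx")
input_name = sess.get_inputs()[0].name
z = np.random.randn(1, 512).astype(np.float32)  # assumed latent shape; verify against the model
(w,) = sess.run(None, {input_name: z})          # assumes a single output tensor
print(w.shape)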
Hi @knok, thanks for sharing your script! I'm looking for a way to export a StyleGAN model into onnx, and if I understood correctly that's what sg-onnx.py is doing.
When I try to run it, I get the error below (apparently I'm missing a module). Do you have any suggestions on how to fix it? Thanks heaps!
File "sg-onnx.py", line 16, in <module>
    from net import Discriminator, StyleGenerator, MappingNetwork, SynthesisBlock, EqualizedConv2d
ModuleNotFoundError: No module named 'net'
You need the chainer-stylegan code, placed in the correct path, usually the same directory as sg-onnx.py.
BTW, I still can't get it to work because onnx-chainer doesn't support RandomNormal yet.
chainer/onnx-chainer#214
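One possible workaround, just an untested sketch: if the RandomNormal node only comes from the per-layer noise, exporting the generator through a small wrapper that pins the stage and passes add_noise=False might keep the op out of the traced graph. The wrapper name below is made up for illustration.

class FixedStageGenerator(chainer.Chain):
    # Hypothetical wrapper: fixes the stage and disables per-layer noise so the
    # exported graph should not contain a RandomNormal node (untested).
    def __init__(self, generator, stage=11):
        super(FixedStageGenerator, self).__init__()
        self.stage = stage
        with self.init_scope():
            self.generator = generator

    def __call__(self, w):
        return self.generator(w, stage=self.stage, add_noise=False)

# onnx_chainer.export(FixedStageGenerator(generator), w, filename="stylegan_gen.onnx")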
Thanks for your reply and for the information! I need to export the entire model to .onnx, so I guess I'll try to reimplement the code to have control over what's going on :p
BTW, I cloned the Chainer StyleGAN repo, placed the script in src/stylegan, and launched python sg-onnx.py. I got the error below... Am I doing anything wrong?
Traceback (most recent call last):
File "sg-onnx.py", line 27, in <module>
import onnx_chainer
File "/home/alberto/anaconda3/envs/chainer-stylegan/lib/python2.7/site-packages/onnx_chainer/__init__.py", line 3, in <module>
from onnx_chainer.export import convert_parameter # NOQA
File "/home/alberto/anaconda3/envs/chainer-stylegan/lib/python2.7/site-packages/onnx_chainer/export.py", line 12, in <module>
from onnx_chainer.graph import Graph
File "/home/alberto/anaconda3/envs/chainer-stylegan/lib/python2.7/site-packages/onnx_chainer/graph.py", line 7, in <module>
from onnx_chainer.functions.converter import FunctionConverterParams
File "/home/alberto/anaconda3/envs/chainer-stylegan/lib/python2.7/site-packages/onnx_chainer/functions/__init__.py", line 46, in <module>
from onnx_chainer.functions.math import convert_Absolute # NOQA
File "/home/alberto/anaconda3/envs/chainer-stylegan/lib/python2.7/site-packages/onnx_chainer/functions/math.py", line 275
n1 = gb.op('Sub', [one_name, p], **kwargs, **kwargs2)
^
SyntaxError: invalid syntax
I haven't seen such a problem, but it seems to come from the Python version: the traceback shows a python2.7 environment, while the failing line uses two ** unpackings in one call, which is only valid syntax from Python 3.5 on.
I tried onnx_chainer at git master's HEAD with Python 3.7.3.
Stack trace: