Last active
April 26, 2017 23:25
-
-
Save datakop/9c5ca511526c2fd7b8ccbb4edce734fc to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'nn'

-- Layer definitions for a fully convolutional network that takes a 1-channel
-- input (conv1) and produces a 3-channel output (conv17). Everything is kept
-- global on purpose: later parts of the script refer to these names directly.
net = nn.Sequential()

-- nn.SpatialConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
-- Every kernel is 3x3 with padding 1; conv1, conv3 and conv5 use stride 2,
-- all other convolutions use stride 1.
conv1 = nn.SpatialConvolution(1, 64, 3, 3, 2, 2, 1, 1)
conv2 = nn.SpatialConvolution(64, 128, 3, 3, 1, 1, 1, 1)
conv3 = nn.SpatialConvolution(128, 128, 3, 3, 2, 2, 1, 1)
conv4 = nn.SpatialConvolution(128, 256, 3, 3, 1, 1, 1, 1)
conv5 = nn.SpatialConvolution(256, 256, 3, 3, 2, 2, 1, 1)
conv6 = nn.SpatialConvolution(256, 512, 3, 3, 1, 1, 1, 1)
conv11 = nn.SpatialConvolution(512, 512, 3, 3, 1, 1, 1, 1)
conv12 = nn.SpatialConvolution(512, 256, 3, 3, 1, 1, 1, 1)
conv13 = nn.SpatialConvolution(256, 128, 3, 3, 1, 1, 1, 1)
conv14 = nn.SpatialConvolution(128, 64, 3, 3, 1, 1, 1, 1)
conv15 = nn.SpatialConvolution(64, 64, 3, 3, 1, 1, 1, 1)
conv16 = nn.SpatialConvolution(64, 32, 3, 3, 1, 1, 1, 1)
conv17 = nn.SpatialConvolution(32, 3, 3, 3, 1, 1, 1, 1)

-- nn.SpatialBatchNormalization(nFeature, eps, momentum, affine)
-- The last argument must be the lowercase Lua literal `true`. The capitalised
-- `True` is just an undefined global (nil) in Lua, so the affine flag would
-- never actually be passed, and the script would fail outright under any
-- strict-globals checker.
bn1 = nn.SpatialBatchNormalization(64, 1e-05, 0.1, true)
bn2 = nn.SpatialBatchNormalization(128, 1e-05, 0.1, true)
bn3 = nn.SpatialBatchNormalization(128, 1e-05, 0.1, true)
bn4 = nn.SpatialBatchNormalization(256, 1e-05, 0.1, true)
bn5 = nn.SpatialBatchNormalization(256, 1e-05, 0.1, true)
bn6 = nn.SpatialBatchNormalization(512, 1e-05, 0.1, true)
bn11 = nn.SpatialBatchNormalization(512, 1e-05, 0.1, true)
bn12 = nn.SpatialBatchNormalization(256, 1e-05, 0.1, true)
bn13 = nn.SpatialBatchNormalization(128, 1e-05, 0.1, true)
bn14 = nn.SpatialBatchNormalization(64, 1e-05, 0.1, true)
bn15 = nn.SpatialBatchNormalization(64, 1e-05, 0.1, true)
bn16 = nn.SpatialBatchNormalization(32, 1e-05, 0.1, true)

-- Three x2 bilinear up-sampling stages, matching the three stride-2
-- convolutions above.
ups1 = nn.SpatialUpSamplingBilinear(2)
ups2 = nn.SpatialUpSamplingBilinear(2)
ups3 = nn.SpatialUpSamplingBilinear(2)
-- Assemble the sequential network from the layers defined above.
-- Apart from the final convolution, every convolution is followed by its
-- batch normalisation and a fresh ReLU.
do
   -- Append one convolution -> batch-norm -> ReLU stage to the network.
   local function add_stage(conv, bn)
      net:add(conv)
      net:add(bn)
      net:add(nn.ReLU())
   end

   -- Low-Level Features network
   add_stage(conv1, bn1)
   add_stage(conv2, bn2)
   add_stage(conv3, bn3)
   add_stage(conv4, bn4)
   add_stage(conv5, bn5)
   add_stage(conv6, bn6)

   -- Middle-Level Features network
   add_stage(conv11, bn11)
   add_stage(conv12, bn12)

   -- Colorization network
   add_stage(conv13, bn13)
   net:add(ups1)
   add_stage(conv14, bn14)
   add_stage(conv15, bn15)
   net:add(ups2)
   add_stage(conv16, bn16)

   -- Output head: no batch-norm here; a sigmoid follows the last convolution
   -- and the final up-sampling is applied after it.
   net:add(conv17)
   net:add(nn.Sigmoid())
   net:add(ups3)
end
-- Name -> layer lookup for every layer that carries parameters.
-- The keys are the names substituted into the 'module.<name>.<field>' dataset
-- paths when the weights are read further down in the script.
modules = {
   -- convolution layers
   conv1 = conv1, conv2 = conv2, conv3 = conv3,
   conv4 = conv4, conv5 = conv5, conv6 = conv6,
   conv11 = conv11, conv12 = conv12, conv13 = conv13,
   conv14 = conv14, conv15 = conv15, conv16 = conv16,
   conv17 = conv17,

   -- batch-normalisation layers
   bn1 = bn1, bn2 = bn2, bn3 = bn3,
   bn4 = bn4, bn5 = bn5, bn6 = bn6,
   bn11 = bn11, bn12 = bn12, bn13 = bn13,
   bn14 = bn14, bn15 = bn15, bn16 = bn16,
}
| npy4th = require 'npy4th' | |
--- Report whether a string begins with a given prefix.
-- Installed on the global `string` table so the rest of the script can call
-- it as string.starts(s, prefix).
-- @param String the text to inspect
-- @param Start the prefix to look for
-- @return true when the leading characters of String equal Start
function string.starts(String, Start)
   local prefix_length = string.len(Start)
   local head = string.sub(String, 1, prefix_length)
   return head == Start
end
-- Put the network into evaluation mode before running it.
net:evaluate()

-- Disabled alternative: load each parameter from an individual .npy file via
-- npy4th (kept for reference).
--
-- for k, v in pairs(modules) do
--    local prefix = "/mnt/hdd/b.kopin/model/weights_train/module." .. k
--    if string.starts(k, "conv") or string.starts(k, "bn") then
--       v.weight = npy4th.loadnpy(prefix .. ".weight.npy"):double()
--       v.bias = npy4th.loadnpy(prefix .. ".bias.npy"):double()
--    end
--    if string.starts(k, "bn") then
--       v.running_mean = npy4th.loadnpy(prefix .. ".running_mean.npy"):double()
--       v.running_var = npy4th.loadnpy(prefix .. ".running_var.npy"):double()
--    end
-- end

require 'hdf5'

-- Load all parameters from a single HDF5 file. Datasets are addressed as
-- 'module.<layer name>.<field>'.
model_hdf5 = hdf5.open('/mnt/hdd/b.kopin/model/model.h5', 'r')

do
   -- Fields read for each kind of layer, in the order they are assigned.
   local conv_fields = { 'weight', 'bias' }
   local bn_fields = { 'weight', 'bias', 'running_mean', 'running_var' }

   -- Read the listed fields of `layer_name` from the file and store each one
   -- on `layer` under the same field name.
   local function load_fields(layer_name, layer, fields)
      for _, field in ipairs(fields) do
         local dataset = string.format('module.%s.%s', layer_name, field)
         layer[field] = model_hdf5:read(dataset):all()
      end
   end

   for layer_name, layer in pairs(modules) do
      if string.starts(layer_name, "conv") then
         load_fields(layer_name, layer, conv_fields)
      end
      if string.starts(layer_name, "bn") then
         load_fields(layer_name, layer, bn_fields)
      end
   end
end

model_hdf5:close()
| require "torch" | |
| require "image" | |
--- Convert an RGB image to a single grayscale plane.
-- The result is a weighted sum of the channels (0.21 R + 0.72 G + 0.07 B),
-- not a plain average. image.rgb2y uses a different weight mixture.
-- @param im tensor whose first dimension indexes the colour channels
--           (expected to be 3: R, G, B)
-- @return a new tensor holding the weighted sum of the three channel planes
--         (same size and type as one plane of `im`); if `im` does not have
--         exactly 3 channels an error line is printed and `im` itself is
--         returned unchanged
function rgb2gray(im)
   local channels = im:size()[1]
   if channels ~= 3 then
      print('<error> expected 3 channels')
      return im
   end

   -- Start from a copy of the red plane instead of a separately allocated
   -- zero tensor: the accumulator then always has the input's own tensor type
   -- and plane shape, so nothing depends on the default tensor type.
   local gray = im:select(1, 1):clone():mul(0.21)
   -- add(value, tensor) accumulates value * tensor in place.
   gray:add(0.72, im:select(1, 2))
   gray:add(0.07, im:select(1, 3))
   return gray
end
-- Load the test image as 3-channel RGB and reduce it to one grayscale plane.
input = image.load("/mnt/hdd/b.kopin/tests/test/23.jpg", 3)
input_grey = rgb2gray(input)

-- Add the leading batch and channel dimensions (1 x 1 x rows x cols) that the
-- network's first convolution takes. The last two dimensions must keep their
-- original order: reshaping only reinterprets memory, so swapping them would
-- scramble the pixel rows of any non-square image instead of transposing it.
do
   local dims = input_grey:size()
   local ndim = dims:size()
   local rows, cols = dims[ndim - 1], dims[ndim]
   local batch = torch.reshape(input_grey, torch.LongStorage{1, 1, rows, cols})
   -- [1] drops the batch dimension of the network output.
   out = net:forward(batch)[1]
end

-- Fixed-size variant used during debugging:
-- out = net:forward(torch.reshape(input_grey, torch.LongStorage{1,1,224,224}))
itorch.image(out)
-- Post-processing: keep the colour channels produced by the network but
-- replace its first Lab channel with the one taken from the original image.
input_lab = image.rgb2lab(input)
out_lab = image.rgb2lab(out:clone())

-- Spatial size of the network output, read from its last two dimensions.
do
   local out_dims = out:size()
   local last = out_dims:size()
   h = out_dims[last - 1]
   w = out_dims[last]
end

-- First Lab channel of the original image and of the network output.
input_l = input_lab:select(1, 1):clone()
out_l = out_lab:select(1, 1):clone()

-- Resize the original first channel to the output resolution
-- (image.scale takes width first, then height) and write it over the
-- network's own first channel.
input_l_scaled = image.scale(input_l, w, h, "bilinear")
out_lab:select(1, 1):copy(input_l_scaled)

out_new = image.lab2rgb(out_lab)
itorch.image(out_new)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment