Created
January 19, 2017 15:59
-
-
Save prhbrt/8265619db9b05fd7093561e21daf8a28 to your computer and use it in GitHub Desktop.
V-Net in Keras and TensorFlow
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy\n", | |
| "import warnings\n", | |
| "from keras.layers import Convolution3D, Input, merge, RepeatVector, Activation\n", | |
| "from keras.models import Model\n", | |
| "from keras.layers.advanced_activations import PReLU\n", | |
| "from keras import activations, initializations, regularizers\n", | |
| "from keras.engine import Layer, InputSpec\n", | |
| "from keras.utils.np_utils import conv_output_length\n", | |
| "from keras.optimizers import Adam\n", | |
| "from keras.callbacks import ModelCheckpoint\n", | |
| "import keras.backend as K\n", | |
| "from keras.engine.topology import Layer\n", | |
| "import functools\n", | |
| "import tensorflow as tf\n", | |
| "import pickle\n", | |
| "import time" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "with open('../data/PROMISE2012/train_data.p3', 'rb') as f:\n", | |
| " X, y = pickle.load(f)\n", | |
| " \n", | |
| "X = X.reshape(X.shape + (1,)).astype(numpy.float32)\n", | |
| "y = y.reshape(y.shape + (1,))\n", | |
| "y = numpy.concatenate([y, ~y], axis=4)\n", | |
| "y=y.astype(numpy.float32)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "class Deconvolution3D(Layer):\n", | |
| " def __init__(self, nb_filter, kernel_dims, output_shape, subsample):\n", | |
| " self.nb_filter = nb_filter\n", | |
| " self.kernel_dims = kernel_dims\n", | |
| " self.strides = (1,) + subsample + (1,)\n", | |
| " self.output_shape_ = output_shape\n", | |
| " assert K.backend() == 'tensorflow'\n", | |
| " super(Deconvolution3D, self).__init__()\n", | |
| " \n", | |
| " def build(self, input_shape):\n", | |
| " assert len(input_shape) == 5\n", | |
| " self.input_shape_ = input_shape\n", | |
| " W_shape = self.kernel_dims + (self.nb_filter, input_shape[4], )\n", | |
| " self.W = self.add_weight(W_shape,\n", | |
| " initializer=functools.partial(initializations.glorot_uniform,dim_ordering='tf'),\n", | |
| " name='{}_W'.format(self.name))\n", | |
| " self.b = self.add_weight((1,1,1,self.nb_filter,), initializer='zero', name='{}_b'.format(self.name))\n", | |
| " self.built = True\n", | |
| "\n", | |
| " def get_output_shape_for(self, input_shape):\n", | |
| " return (None, ) + self.output_shape_[1:]\n", | |
| "\n", | |
| " def call(self, x, mask=None):\n", | |
| " return tf.nn.conv3d_transpose(x, self.W, output_shape=self.output_shape_,\n", | |
| " strides=self.strides, padding='SAME', name=self.name) + self.b\n", | |
| "\n", | |
| " def get_config(self):\n", | |
| " base_config = super(Deconvolution3D, self).get_config().copy()\n", | |
| " base_config['output_shape'] = self.output_shape_\n", | |
| " return base_config" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from keras import backend as K\n", | |
| "from keras.engine import Layer\n", | |
| "\n", | |
| "class Softmax(Layer):\n", | |
| " def __init__(self, axis=-1,**kwargs):\n", | |
| " self.axis=axis\n", | |
| " super(Softmax, self).__init__(**kwargs)\n", | |
| "\n", | |
| " def build(self,input_shape):\n", | |
| " pass\n", | |
| "\n", | |
| " def call(self, x,mask=None):\n", | |
| " e = K.exp(x - K.max(x, axis=self.axis, keepdims=True))\n", | |
| " s = K.sum(e, axis=self.axis, keepdims=True)\n", | |
| " return e / s\n", | |
| "\n", | |
| " def get_output_shape_for(self, input_shape):\n", | |
| " return input_shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def downward_layer(input_layer, n_convolutions, n_output_channels):\n", | |
| " inl = input_layer\n", | |
| " for _ in range(n_convolutions-1):\n", | |
| " inl = PReLU()(\n", | |
| " Convolution3D(n_output_channels // 2, 5, 5, 5, border_mode='same', dim_ordering='tf')(inl)\n", | |
| " )\n", | |
| " conv = Convolution3D(n_output_channels // 2, 5, 5, 5, border_mode='same', dim_ordering='tf')(inl)\n", | |
| " add = merge([conv, input_layer], mode='sum')\n", | |
| " downsample = Convolution3D(n_output_channels, 2,2,2, subsample=(2,2,2))(add)\n", | |
| " prelu = PReLU()(downsample)\n", | |
| " return prelu, add" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def upward_layer(input0 ,input1, n_convolutions, n_output_channels):\n", | |
| " merged = merge([input0, input1], mode='concat', concat_axis=4)\n", | |
| " inl = merged\n", | |
| " for _ in range(n_convolutions-1):\n", | |
| " inl = PReLU()(\n", | |
| " Convolution3D(n_output_channels * 4, 5, 5, 5, border_mode='same', dim_ordering='tf')(inl)\n", | |
| " )\n", | |
| " conv = Convolution3D(n_output_channels * 4, 5, 5, 5, border_mode='same', dim_ordering='tf')(inl)\n", | |
| " add = merge([conv, merged], mode='sum')\n", | |
| " shape = add.get_shape().as_list()\n", | |
| " new_shape = (1, shape[1] * 2, shape[2] * 2, shape[3] * 2, n_output_channels)\n", | |
| " upsample = Deconvolution3D(n_output_channels, (4,4,4), new_shape, subsample=(2,2,2))(add)\n", | |
| " return PReLU()(upsample)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": { | |
| "collapsed": false, | |
| "scrolled": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Layer 1\n", | |
| "input_layer = Input(shape=(128, 128, 64, 1), name='data')\n", | |
| "conv_1 = Convolution3D(16, 5, 5, 5, border_mode='same', dim_ordering='tf')(input_layer)\n", | |
| "repeat_1 = merge([input_layer] * 16, mode='concat')\n", | |
| "add_1 = merge([conv_1, repeat_1], mode='sum')\n", | |
| "prelu_1_1 = PReLU()(add_1)\n", | |
| "downsample_1 = Convolution3D(32, 2,2,2, subsample=(2,2,2))(prelu_1_1)\n", | |
| "prelu_1_2 = PReLU()(downsample_1)\n", | |
| "\n", | |
| "# Layer 2,3,4\n", | |
| "out2, left2 = downward_layer(prelu_1_2, 2, 64)\n", | |
| "out3, left3 = downward_layer(out2, 2, 128)\n", | |
| "out4, left4 = downward_layer(out3, 2, 256)\n", | |
| "\n", | |
| "# Layer 5\n", | |
| "conv_5_1 = Convolution3D(256, 5, 5, 4, border_mode='same', dim_ordering='tf')(out4)\n", | |
| "prelu_5_1 = PReLU()(conv_5_1)\n", | |
| "conv_5_2 = Convolution3D(256, 5, 5, 4, border_mode='same', dim_ordering='tf')(prelu_5_1)\n", | |
| "prelu_5_2 = PReLU()(conv_5_2)\n", | |
| "conv_5_3 = Convolution3D(256, 5, 5, 4, border_mode='same', dim_ordering='tf')(prelu_5_2)\n", | |
| "add_5 = merge([conv_5_3, out4], mode='sum')\n", | |
| "prelu_5_1 = PReLU()(add_5)\n", | |
| "downsample_5 = Deconvolution3D(128, (2,2,2), (1, 16, 16, 8, 128), subsample=(2,2,2))(prelu_5_1)\n", | |
| "prelu_5_2 = PReLU()(downsample_5)\n", | |
| "\n", | |
| "#Layer 6,7,8\n", | |
| "out6 = upward_layer(prelu_5_2, left4, 3, 64)\n", | |
| "out7 = upward_layer(out6, left3, 3, 32)\n", | |
| "out8 = upward_layer(out7, left2, 2, 16)\n", | |
| "\n", | |
| "#Layer 9\n", | |
| "merged_9 = merge([out8, add_1], mode='concat', concat_axis=4)\n", | |
| "conv_9_1 = Convolution3D(32, 5, 5, 5, border_mode='same', dim_ordering='tf')(merged_9)\n", | |
| "add_9 = merge([conv_9_1, merged_9], mode='sum')\n", | |
| "conv_9_2 = Convolution3D(2, 1, 1, 1, border_mode='same', dim_ordering='tf')(add_9)\n", | |
| "\n", | |
| "softmax = Softmax()(conv_9_2)\n", | |
| "\n", | |
| "model = Model(input_layer, softmax)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": { | |
| "collapsed": false, | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "_________________________________________________________________________________________________________________\n", | |
| "Layer (type) Output Shape Param # Connected to \n", | |
| "=================================================================================================================\n", | |
| "data (InputLayer) (None, 128, 128, 64, 1) 0 \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "convolution3d_1 (Convolution3D) (None, 128, 128, 64, 16) 2016 data[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "merge_1 (Merge) (None, 128, 128, 64, 16) 0 data[0][0] \n", | |
| " data[0][0] \n", | |
| " data[0][0] \n", | |
| " data[0][0] \n", | |
| " data[0][0] \n", | |
| " data[0][0] \n", | |
| " data[0][0] \n", | |
| " data[0][0] \n", | |
| " data[0][0] \n", | |
| " data[0][0] \n", | |
| " data[0][0] \n", | |
| " data[0][0] \n", | |
| " data[0][0] \n", | |
| " data[0][0] \n", | |
| " data[0][0] \n", | |
| " data[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "merge_2 (Merge) (None, 128, 128, 64, 16) 0 convolution3d_1[0][0] \n", | |
| " merge_1[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "prelu_1 (PReLU) (None, 128, 128, 64, 16) 16777216 merge_2[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "convolution3d_2 (Convolution3D) (None, 64, 64, 32, 32) 4128 prelu_1[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "prelu_2 (PReLU) (None, 64, 64, 32, 32) 4194304 convolution3d_2[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "convolution3d_3 (Convolution3D) (None, 64, 64, 32, 32) 128032 prelu_2[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "prelu_3 (PReLU) (None, 64, 64, 32, 32) 4194304 convolution3d_3[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "convolution3d_4 (Convolution3D) (None, 64, 64, 32, 32) 128032 prelu_3[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "merge_3 (Merge) (None, 64, 64, 32, 32) 0 convolution3d_4[0][0] \n", | |
| " prelu_2[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "convolution3d_5 (Convolution3D) (None, 32, 32, 16, 64) 16448 merge_3[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "prelu_4 (PReLU) (None, 32, 32, 16, 64) 1048576 convolution3d_5[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "convolution3d_6 (Convolution3D) (None, 32, 32, 16, 64) 512064 prelu_4[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "prelu_5 (PReLU) (None, 32, 32, 16, 64) 1048576 convolution3d_6[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "convolution3d_7 (Convolution3D) (None, 32, 32, 16, 64) 512064 prelu_5[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "merge_4 (Merge) (None, 32, 32, 16, 64) 0 convolution3d_7[0][0] \n", | |
| " prelu_4[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "convolution3d_8 (Convolution3D) (None, 16, 16, 8, 128) 65664 merge_4[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "prelu_6 (PReLU) (None, 16, 16, 8, 128) 262144 convolution3d_8[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "convolution3d_9 (Convolution3D) (None, 16, 16, 8, 128) 2048128 prelu_6[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "prelu_7 (PReLU) (None, 16, 16, 8, 128) 262144 convolution3d_9[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "convolution3d_10 (Convolution3D) (None, 16, 16, 8, 128) 2048128 prelu_7[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "merge_5 (Merge) (None, 16, 16, 8, 128) 0 convolution3d_10[0][0] \n", | |
| " prelu_6[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "convolution3d_11 (Convolution3D) (None, 8, 8, 4, 256) 262400 merge_5[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "prelu_8 (PReLU) (None, 8, 8, 4, 256) 65536 convolution3d_11[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "convolution3d_12 (Convolution3D) (None, 8, 8, 4, 256) 6553856 prelu_8[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "prelu_9 (PReLU) (None, 8, 8, 4, 256) 65536 convolution3d_12[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "convolution3d_13 (Convolution3D) (None, 8, 8, 4, 256) 6553856 prelu_9[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "prelu_10 (PReLU) (None, 8, 8, 4, 256) 65536 convolution3d_13[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "convolution3d_14 (Convolution3D) (None, 8, 8, 4, 256) 6553856 prelu_10[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "merge_6 (Merge) (None, 8, 8, 4, 256) 0 convolution3d_14[0][0] \n", | |
| " prelu_8[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "prelu_11 (PReLU) (None, 8, 8, 4, 256) 65536 merge_6[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "deconvolution3d_1 (Deconvolution3D) (None, 16, 16, 8, 128) 262272 prelu_11[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "prelu_12 (PReLU) (None, 16, 16, 8, 128) 262144 deconvolution3d_1[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "merge_7 (Merge) (None, 16, 16, 8, 256) 0 prelu_12[0][0] \n", | |
| " merge_5[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "convolution3d_15 (Convolution3D) (None, 16, 16, 8, 256) 8192256 merge_7[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "prelu_13 (PReLU) (None, 16, 16, 8, 256) 524288 convolution3d_15[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "convolution3d_16 (Convolution3D) (None, 16, 16, 8, 256) 8192256 prelu_13[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "prelu_14 (PReLU) (None, 16, 16, 8, 256) 524288 convolution3d_16[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "convolution3d_17 (Convolution3D) (None, 16, 16, 8, 256) 8192256 prelu_14[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "merge_8 (Merge) (None, 16, 16, 8, 256) 0 convolution3d_17[0][0] \n", | |
| " merge_7[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "deconvolution3d_2 (Deconvolution3D) (None, 32, 32, 16, 64) 1048640 merge_8[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "prelu_15 (PReLU) (None, 32, 32, 16, 64) 1048576 deconvolution3d_2[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "merge_9 (Merge) (None, 32, 32, 16, 128) 0 prelu_15[0][0] \n", | |
| " merge_4[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "convolution3d_18 (Convolution3D) (None, 32, 32, 16, 128) 2048128 merge_9[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "prelu_16 (PReLU) (None, 32, 32, 16, 128) 2097152 convolution3d_18[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "convolution3d_19 (Convolution3D) (None, 32, 32, 16, 128) 2048128 prelu_16[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "prelu_17 (PReLU) (None, 32, 32, 16, 128) 2097152 convolution3d_19[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "convolution3d_20 (Convolution3D) (None, 32, 32, 16, 128) 2048128 prelu_17[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "merge_10 (Merge) (None, 32, 32, 16, 128) 0 convolution3d_20[0][0] \n", | |
| " merge_9[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "deconvolution3d_3 (Deconvolution3D) (None, 64, 64, 32, 32) 262176 merge_10[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "prelu_18 (PReLU) (None, 64, 64, 32, 32) 4194304 deconvolution3d_3[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "merge_11 (Merge) (None, 64, 64, 32, 64) 0 prelu_18[0][0] \n", | |
| " merge_3[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "convolution3d_21 (Convolution3D) (None, 64, 64, 32, 64) 512064 merge_11[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "prelu_19 (PReLU) (None, 64, 64, 32, 64) 8388608 convolution3d_21[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "convolution3d_22 (Convolution3D) (None, 64, 64, 32, 64) 512064 prelu_19[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "merge_12 (Merge) (None, 64, 64, 32, 64) 0 convolution3d_22[0][0] \n", | |
| " merge_11[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "deconvolution3d_4 (Deconvolution3D) (None, 128, 128, 64, 16) 65552 merge_12[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "prelu_20 (PReLU) (None, 128, 128, 64, 16) 16777216 deconvolution3d_4[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "merge_13 (Merge) (None, 128, 128, 64, 32) 0 prelu_20[0][0] \n", | |
| " merge_2[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "convolution3d_23 (Convolution3D) (None, 128, 128, 64, 32) 128032 merge_13[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "merge_14 (Merge) (None, 128, 128, 64, 32) 0 convolution3d_23[0][0] \n", | |
| " merge_13[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "convolution3d_24 (Convolution3D) (None, 128, 128, 64, 2) 66 merge_14[0][0] \n", | |
| "_________________________________________________________________________________________________________________\n", | |
| "softmax_1 (Softmax) (None, 128, 128, 64, 2) 0 convolution3d_24[0][0] \n", | |
| "=================================================================================================================\n", | |
| "Total params: 122,863,826\n", | |
| "Trainable params: 122,863,826\n", | |
| "Non-trainable params: 0\n", | |
| "_________________________________________________________________________________________________________________\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Print the layer-by-layer model summary using 113-character-wide rows.\n", | |
| "model.summary(line_length=113)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def dice_coef(y_true, y_pred):\n", | |
| " y_true_f = K.flatten(y_true)\n", | |
| " y_pred_f = K.reshape(y_pred, (-1, 2))\n", | |
| " intersection = K.mean(y_true_f * y_pred_f[:,0]) + K.mean((1.0 - y_true_f) * y_pred_f[:,1])\n", | |
| " \n", | |
| " return 2. * intersection;\n", | |
| "\n", | |
| "def dice_coef_loss(y_true, y_pred):\n", | |
| " return -dice_coef(y_true, y_pred)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Train with Adam (learning rate 1e-5) on the negated Dice coefficient and report the Dice coefficient as a metric.\n", | |
| "model.compile(optimizer=Adam(lr=1e-5), loss=dice_coef_loss, metrics=[dice_coef])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": { | |
| "collapsed": false, | |
| "scrolled": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "t=time.time()\n", | |
| "y_pred = model.predict(X[:1,:,:,:,:])\n", | |
| "print(time.time() - t)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": false, | |
| "scrolled": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 1/20\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "model_checkpoint = ModelCheckpoint('unet.hdf5', monitor='loss', save_best_only=True)\n", | |
| "\n", | |
| "model.fit(X, y, batch_size=50, nb_epoch=20, verbose=1)" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.5.1" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 1 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi, can you tell me which data set your model was used on?