Created
March 5, 2019 22:57
-
-
Save ZhangMenghe/a7807b1cc4503a94b9107e018aae9c99 to your computer and use it in GitHub Desktop.
Restore tensorflow checkpoints and add new variables to train
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python | |
| import sys | |
| import tensorflow as tf | |
| import numpy as np | |
| from tensorflow.examples.tutorials.mnist import input_data | |
class Model(object):
    """Small stack of 5x5 single-filter conv layers reduced to one scalar.

    The graph is built inside the ``'model'`` variable scope (with
    ``tf.AUTO_REUSE``), so the layer variables are named
    ``model/first_conv/...``, ``model/last_conv/...`` and, when
    ``with_new_layer`` is true, ``model/middle_conv/...``.

    After construction, ``self.result`` holds the mean over all elements
    of the final convolution output.
    """

    def __init__(self, x, with_new_layer=False):
        """Build the network on top of ``x``.

        Args:
            x: input tensor; it is reshaped to ``[-1, 28, 28, 1]``.
            with_new_layer: if true, insert an extra ``middle_conv`` layer
                between ``first_conv`` and ``last_conv``.
        """
        self.with_new_layer = with_new_layer
        with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
            self.build(x)

    def build(self, x):
        """Create the convolution stack and store its mean in ``self.result``."""
        # Layers are created strictly in this order; the optional layer
        # sits between the first and the last one.
        layer_names = ["first_conv"]
        if self.with_new_layer:
            layer_names.append("middle_conv")
        layer_names.append("last_conv")

        # Every layer maps [batch_size, 28, 28, 1] -> [batch_size, 28, 28, 1]
        # (one filter, "same" padding).
        net = tf.reshape(x, [-1, 28, 28, 1])
        for layer_name in layer_names:
            net = tf.layers.conv2d(
                inputs=net,
                filters=1,
                kernel_size=[5, 5],
                padding="same",
                activation=tf.nn.relu,
                name=layer_name)

        self.result = tf.reduce_mean(net)
def optimistic_restore(session, save_file):
    """Restore only the graph variables that the checkpoint can supply.

    A global variable is restored when the checkpoint contains an entry
    with the same name (the variable name without its ``:0`` output
    suffix) and the same shape. All other variables are left as they
    currently are in ``session``, so the caller is expected to have
    initialized them already.

    If no variable matches, the function returns without touching the
    session.

    Args:
        session: the ``tf.Session`` in which the restore ops are run.
        save_file: checkpoint path, passed to
            ``tf.train.NewCheckpointReader`` and ``Saver.restore``.
    """
    reader = tf.train.NewCheckpointReader(save_file)
    saved_shapes = reader.get_variable_to_shape_map()

    restore_vars = []
    # Sort by name so the list handed to the Saver is deterministic.
    for var in sorted(tf.global_variables(), key=lambda v: v.name):
        saved_name = var.name.split(':')[0]
        if saved_name not in saved_shapes:
            continue
        # Skip variables whose shape changed since the checkpoint was
        # written; restoring them would fail.
        if var.get_shape().as_list() == saved_shapes[saved_name]:
            restore_vars.append(var)

    if not restore_vars:
        # tf.train.Saver refuses an empty variable list
        # ("No variables to save"); nothing to restore is not an error here.
        return

    saver = tf.train.Saver(restore_vars)
    saver.restore(session, save_file)
def main():
    """Save a model, restore it, then restore it into a graph with an extra layer.

    Runs three independent graphs on the same random input and prints the
    model output after each step:

    1. build the base model, initialize it and save a checkpoint;
    2. rebuild the base model and restore it with a plain ``Saver``;
    3. build the model with the additional ``middle_conv`` layer,
       initialize everything, then restore the variables the checkpoint
       has via ``optimistic_restore``.
    """
    ckpt_file = '/tmp/model.ckpt'
    input_val = np.random.normal(size=(10, 784))

    def show(value):
        # A single parenthesized argument prints identically under the
        # Python 2 print statement and the Python 3 print function; the
        # "%s %s" format reproduces the space the old two-item
        # print statement inserted between its items.
        print("%s %s" % ("y = ", value))

    # Create a model and save it
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, [None, 784])
        m = Model(x).result
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            y = sess.run([m], feed_dict={x: input_val})
            show(y)
            saver.save(sess, ckpt_file)

    # Restore the model
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, [None, 784])
        m = Model(x).result
        saver = tf.train.Saver()
        with tf.Session() as sess:
            saver.restore(sess, ckpt_file)
            y = sess.run([m], feed_dict={x: input_val})
            show(y)

    # Restore the model with an extra layer
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, [None, 784])
        m = Model(x, with_new_layer=True).result
        with tf.Session() as sess:
            # Initialize all variables first, because restoring from the
            # checkpoint will not initialize the middle_conv layer
            sess.run(tf.global_variables_initializer())
            optimistic_restore(sess, ckpt_file)
            y = sess.run([m], feed_dict={x: input_val})
            show(y)
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment