@ZhangMenghe
Created March 5, 2019 22:57
Restore tensorflow checkpoints and add new variables to train
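
The optimistic_restore helper in the script below works by comparing each graph variable's name and shape against the checkpoint's variable-to-shape map and restoring only the matches. As a quick sanity check (a minimal sketch; /tmp/model.ckpt is simply the path that main() below happens to use, adjust it for your own checkpoint), you can list what a checkpoint actually contains with the same reader the helper relies on:

import tensorflow as tf

# Print every variable name and shape stored in an existing checkpoint.
reader = tf.train.NewCheckpointReader('/tmp/model.ckpt')
for name, shape in sorted(reader.get_variable_to_shape_map().items()):
    print(name, shape)

The full script follows.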
#!/usr/bin/env python
import numpy as np
import tensorflow as tf
class Model(object):
    def __init__(self, x, with_new_layer=False):
        self.with_new_layer = with_new_layer
        with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
            self.build(x)

    def build(self, x):
        input_layer = tf.reshape(x, [-1, 28, 28, 1])
        # First convolutional layer
        # Input tensor shape:  [batch_size, 28, 28, 1]
        # Output tensor shape: [batch_size, 28, 28, 1]
        conv = tf.layers.conv2d(
            inputs=input_layer,
            filters=1,
            kernel_size=[5, 5],
            padding="same",
            activation=tf.nn.relu,
            name="first_conv")
        # Optional middle layer; its variables do not exist in the original checkpoint.
        if self.with_new_layer:
            conv = tf.layers.conv2d(
                inputs=conv,
                filters=1,
                kernel_size=[5, 5],
                padding="same",
                activation=tf.nn.relu,
                name="middle_conv")
        # Last convolutional layer
        # Input tensor shape:  [batch_size, 28, 28, 1]
        # Output tensor shape: [batch_size, 28, 28, 1]
        conv = tf.layers.conv2d(
            inputs=conv,
            filters=1,
            kernel_size=[5, 5],
            padding="same",
            activation=tf.nn.relu,
            name="last_conv")
        self.result = tf.reduce_mean(conv)
def optimistic_restore(session, save_file):
    """Restore only the variables that exist in the checkpoint with a matching shape."""
    reader = tf.train.NewCheckpointReader(save_file)
    saved_shapes = reader.get_variable_to_shape_map()
    var_names = sorted([(var.name, var.name.split(':')[0])
                        for var in tf.global_variables()
                        if var.name.split(':')[0] in saved_shapes])
    restore_vars = []
    name2var = dict(zip(map(lambda x: x.name.split(':')[0], tf.global_variables()),
                        tf.global_variables()))
    with tf.variable_scope('', reuse=True):
        for var_name, saved_var_name in var_names:
            curr_var = name2var[saved_var_name]
            var_shape = curr_var.get_shape().as_list()
            if var_shape == saved_shapes[saved_var_name]:
                restore_vars.append(curr_var)
    saver = tf.train.Saver(restore_vars)
    saver.restore(session, save_file)
def main():
    ckpt_file = '/tmp/model.ckpt'
    input_val = np.random.normal(size=(10, 784))

    # Create a model and save it
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, [None, 784])
        m = Model(x).result
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            y = sess.run([m], feed_dict={x: input_val})
            print("y =", y)
            saver.save(sess, ckpt_file)

    # Restore the model
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, [None, 784])
        m = Model(x).result
        saver = tf.train.Saver()
        with tf.Session() as sess:
            saver.restore(sess, ckpt_file)
            y = sess.run([m], feed_dict={x: input_val})
            print("y =", y)

    # Restore the model with an extra layer
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, [None, 784])
        m = Model(x, with_new_layer=True).result
        with tf.Session() as sess:
            # Initialize all variables first, because restoring from the checkpoint
            # will not initialize the middle_conv layer
            sess.run(tf.global_variables_initializer())
            optimistic_restore(sess, ckpt_file)
            y = sess.run([m], feed_dict={x: input_val})
            print("y =", y)


if __name__ == '__main__':
    main()
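
The script only runs a forward pass after the partial restore. To actually train the newly added variables, a common TF 1.x pattern is to pass var_list to the optimizer. The sketch below is not part of the original gist: it assumes the Model class and optimistic_restore defined above, reuses the /tmp/model.ckpt path written by main(), and uses a made-up scalar objective purely for illustration.

import numpy as np
import tensorflow as tf

ckpt_file = '/tmp/model.ckpt'                      # same path main() writes to
input_val = np.random.normal(size=(10, 784))

with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, [None, 784])
    m = Model(x, with_new_layer=True).result
    loss = tf.square(m - 0.5)                      # toy objective, illustration only

    # Optimize only the variables of the new middle_conv layer.
    new_vars = [v for v in tf.trainable_variables() if 'middle_conv' in v.name]
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, var_list=new_vars)

    with tf.Session() as sess:
        # Initialize everything (new layer + optimizer slots), then overwrite
        # the variables that do exist in the checkpoint.
        sess.run(tf.global_variables_initializer())
        optimistic_restore(sess, ckpt_file)
        for _ in range(10):
            _, current_loss = sess.run([train_op, loss], feed_dict={x: input_val})
        print("loss after a few steps =", current_loss)

Because var_list restricts the gradient updates to middle_conv's kernel and bias, the restored first_conv and last_conv weights stay frozen at their checkpointed values.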