Skip to content

Instantly share code, notes, and snippets.

@ab2005
Last active January 6, 2017 22:53
Show Gist options
  • Select an option

  • Save ab2005/86d126eb27a50b9203082f4aa76300d4 to your computer and use it in GitHub Desktop.

Select an option

Save ab2005/86d126eb27a50b9203082f4aa76300d4 to your computer and use it in GitHub Desktop.
TensorFlow Convolution Layer
# Number of output channels (feature maps / units) produced by each layer.
layer_depth = {
    'layer_1': 32,
    'layer_2': 64,
    'layer_3': 128,
    'fully_connected': 512,
}
n_classes = 10  # MNIST total classes (0-9 digits)

# Convolution kernels are laid out [filter_height, filter_width,
# in_channels, out_channels]; each layer's in_channels is the previous
# layer's depth.
# NOTE(review): tf.truncated_normal defaults to stddev=1.0, which is large
# for a network of this depth; consider a smaller stddev before training.
weights = {
    'layer_1': tf.Variable(tf.truncated_normal(
        [5, 5, 1, layer_depth['layer_1']])),
    'layer_2': tf.Variable(tf.truncated_normal(
        [5, 5, layer_depth['layer_1'], layer_depth['layer_2']])),
    'layer_3': tf.Variable(tf.truncated_normal(
        [5, 5, layer_depth['layer_2'], layer_depth['layer_3']])),
    # Three stride-2 'SAME' max pools shrink a 28x28 input to 4x4
    # (28 -> 14 -> 7 -> 4), so the flattened conv3 output has
    # 4*4*layer_3 features. Derive the depth from layer_depth so the two
    # definitions cannot drift apart.
    'fully_connected': tf.Variable(tf.truncated_normal(
        [4 * 4 * layer_depth['layer_3'], layer_depth['fully_connected']])),
    'out': tf.Variable(tf.truncated_normal(
        [layer_depth['fully_connected'], n_classes])),
}

# Biases start at zero: one entry per output channel / unit of each layer.
biases = {
    'layer_1': tf.Variable(tf.zeros(layer_depth['layer_1'])),
    'layer_2': tf.Variable(tf.zeros(layer_depth['layer_2'])),
    'layer_3': tf.Variable(tf.zeros(layer_depth['layer_3'])),
    'fully_connected': tf.Variable(tf.zeros(layer_depth['fully_connected'])),
    'out': tf.Variable(tf.zeros(n_classes)),
}
# Conv2D wrapper, with bias and relu activation
def conv2d(x, W, b, strides=1):
    """Convolve ``x`` with kernel ``W``, add bias ``b``, then apply ReLU.

    Uses 'SAME' padding. ``strides`` is applied equally to the two spatial
    axes (positions 1 and 2 of the stride vector); the batch and channel
    strides stay at 1.
    """
    stride_spec = [1, strides, strides, 1]
    convolved = tf.nn.conv2d(x, W, strides=stride_spec, padding='SAME')
    return tf.nn.relu(tf.nn.bias_add(convolved, b))
# Maxpool
def maxpool2d(x, k=2):
    """Max-pool ``x`` with a k x k window moved k steps at a time ('SAME' padding)."""
    window = [1, k, k, 1]
    return tf.nn.max_pool(x, ksize=window, strides=window, padding='SAME')
# Create model
def conv_net(x, weights, biases):
    """Build the three-conv-layer + one-dense-layer classifier; return logits.

    Each conv stage is ``conv2d`` (with ReLU) followed by ``maxpool2d``.
    The last conv output is flattened to the first dimension of
    ``weights['fully_connected']``, passed through a tanh dense layer, and
    projected by ``weights['out']``. No softmax is applied.

    Shape comments assume a 28x28x1 input batch (MNIST) -- TODO confirm.
    """
    features = x
    # 28*28*1 -> 14*14*32 -> 7*7*64 -> 4*4*128
    for layer_name in ('layer_1', 'layer_2', 'layer_3'):
        activated = conv2d(features, weights[layer_name], biases[layer_name])
        features = maxpool2d(activated)

    # Fully connected layer - 4*4*128 to 512
    flat_size = weights['fully_connected'].get_shape().as_list()[0]
    flattened = tf.reshape(features, [-1, flat_size])
    dense = tf.add(tf.matmul(flattened, weights['fully_connected']),
                   biases['fully_connected'])
    hidden = tf.nn.tanh(dense)

    # Output Layer - class prediction - 512 to 10
    return tf.add(tf.matmul(hidden, weights['out']), biases['out'])
# Single convolution layer example: conv -> bias -> ReLU -> max pool.
# Image Properties
image_width = 10
image_height = 10
color_channels = 3
# Convolution filter
filter_size_width = 5
filter_size_height = 5
# Output depth
k_output = 64
# Stride of 2 along height and width ([batch, height, width, channels]).
sampling = [1, 2, 2, 1]
# tf.nn.conv2d defaults to NHWC input: [batch, height, width, channels].
# (The original listed width before height; harmless only while they match.)
image_shape = [None, image_height, image_width, color_channels]
# Filters are laid out [filter_height, filter_width, in_channels, out_channels].
conv_filter = [filter_size_height, filter_size_width, color_channels, k_output]
# Input, Weight and bias
# NOTE: `input` shadows the Python builtin; the name is kept for compatibility.
input = tf.placeholder(tf.float32, shape=image_shape)
weight = tf.Variable(tf.truncated_normal(conv_filter))
bias = tf.Variable(tf.zeros(k_output))
# Apply Convolution, Add bias, Apply activation function
# Stride-2 'SAME' conv: 10x10 -> 5x5 spatially, depth k_output.
conv_layer = tf.nn.conv2d(input, weight, strides=sampling, padding='SAME')
conv_layer = tf.nn.bias_add(conv_layer, bias)
conv_layer = tf.nn.relu(conv_layer)
# Apply Max Pooling (2x2 window, stride 2, 'SAME'): 5x5 -> 3x3.
conv_layer = tf.nn.max_pool(
    conv_layer, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
# Single convolution layer example without pooling: conv -> bias -> ReLU.
# Image Properties
image_width = 10
image_height = 10
color_channels = 3
# Convolution filter
filter_size_width = 5
filter_size_height = 5
# Output depth
k_output = 64
# Stride of 2 along height and width ([batch, height, width, channels]).
sampling = [1, 2, 2, 1]
# tf.nn.conv2d defaults to NHWC input: [batch, height, width, channels].
# (The original listed width before height; harmless only while they match.)
image_shape = [None, image_height, image_width, color_channels]
# Filters are laid out [filter_height, filter_width, in_channels, out_channels].
conv_filter = [filter_size_height, filter_size_width, color_channels, k_output]
# Input, Weight and bias
# NOTE: `input` shadows the Python builtin; the name is kept for compatibility.
input = tf.placeholder(tf.float32, shape=image_shape)
weight = tf.Variable(tf.truncated_normal(conv_filter))
bias = tf.Variable(tf.zeros(k_output))
# Apply Convolution, Add bias, Apply activation function
# Stride-2 'SAME' conv: 10x10 -> 5x5 spatially, depth k_output.
conv_layer = tf.nn.conv2d(input, weight, strides=sampling, padding='SAME')
conv_layer = tf.nn.bias_add(conv_layer, bias)
conv_layer = tf.nn.relu(conv_layer)
# ====================== CONV NET ===========================
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment