Convert PyTorch convolutional layer to fully connected layer
"""
The function `torch_conv_layer_to_affine` takes a `torch.nn.Conv2d` layer `conv`
and produces an equivalent `torch.nn.Linear` layer `fc`.
Specifically, this means that the following holds for `x` of a valid shape:
torch.flatten(conv(x)) == fc(torch.flatten(x))
Or equivalently:
conv(x) == fc(torch.flatten(x)).reshape(conv(x).shape)
allowing of course for some floating-point error.
"""
from typing import Tuple
import torch
import torch.nn as nn
import numpy as np

def torch_conv_layer_to_affine(
    conv: torch.nn.Conv2d, input_size: Tuple[int, int]
) -> torch.nn.Linear:
    w, h = input_size
    with torch.no_grad():
        # Standard conv output-size formula: (in + 2*pad - kernel) // stride + 1.
        # E.g. a 6-pixel axis with kernel 3, stride 2, padding 1 gives
        # (6 + 2*1 - 3) // 2 + 1 = 3 output positions.
        output_size = [
            (input_size[i] + 2 * conv.padding[i] - conv.kernel_size[i]) // conv.stride[i]
            + 1
            for i in [0, 1]
        ]
        in_shape = (conv.in_channels, w, h)
        out_shape = (conv.out_channels, output_size[0], output_size[1])
        fc = nn.Linear(in_features=int(np.prod(in_shape)), out_features=int(np.prod(out_shape)))
        fc.weight.fill_(0.0)
        fc.bias.fill_(0.0)

        # Fill in weights and biases
        for xo, yo in range2d(output_size[0], output_size[1]):
            # Top-left input coordinate covered by this output position
            xi0 = -conv.padding[0] + conv.stride[0] * xo
            yi0 = -conv.padding[1] + conv.stride[1] * yo
            # The bias does not depend on the kernel position, so set it
            # once per output unit, outside the kernel loop
            if conv.bias is not None:
                for co in range(conv.out_channels):
                    fc.bias[enc_tuple((co, xo, yo), out_shape)] = conv.bias[co]
            for xd, yd in range2d(conv.kernel_size[0], conv.kernel_size[1]):
                # Skip kernel taps that fall into the zero-padding region
                if not (0 <= xi0 + xd < w and 0 <= yi0 + yd < h):
                    continue
                for co in range(conv.out_channels):
                    for ci in range(conv.in_channels):
                        cw = conv.weight[co, ci, xd, yd]
                        fc.weight[
                            enc_tuple((co, xo, yo), out_shape),
                            enc_tuple((ci, xi0 + xd, yi0 + yd), in_shape),
                        ] = cw
    return fc
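
# A back-of-the-envelope sizing note, implied by the construction above:
# `fc.weight` is a dense matrix of shape (C_out*H_out*W_out, C_in*H*W).
# For example, Conv2d(1, 16, 3, stride=2, padding=1) on a 28x28 input yields
# a (16*14*14) x (1*28*28) = 3136 x 784 matrix (~2.5M entries), even though
# the conv itself has only 16*1*3*3 = 144 weights.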

def range2d(to_a, to_b):
    for a in range(to_a):
        for b in range(to_b):
            yield a, b

def enc_tuple(tup: Tuple, shape: Tuple) -> int:
    # Converts an n-d index `tup` into a flat row-major (C-order) index,
    # matching the element layout used by torch.flatten
    res = 0
    coef = 1
    for i in reversed(range(len(shape))):
        assert tup[i] < shape[i]
        res += coef * tup[i]
        coef *= shape[i]
    return res

def dec_tuple(x: int, shape: Tuple) -> Tuple:
    # Converts a flat row-major index back into an n-d index (inverse of enc_tuple)
    res = []
    for i in reversed(range(len(shape))):
        res.append(x % shape[i])
        x //= shape[i]
    return tuple(reversed(res))
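
# Worked example of the row-major encoding: for shape (5, 6, 7),
# enc_tuple((3, 2, 1), (5, 6, 7)) == 3 * (6 * 7) + 2 * 7 + 1 == 141,
# and dec_tuple(141, (5, 6, 7)) == (3, 2, 1).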

def convert_network(net, input_channels=1, input_height=28, input_width=28):
    # Put a single Flatten at the top: every converted layer then operates on
    # flat vectors
    new_layers = [nn.Flatten()]
    # Track the current activation shape as we walk the network; initially (C, H, W)
    C, H, W = input_channels, input_height, input_width
    for layer in net:
        if isinstance(layer, nn.Conv2d):
            # Convert the Conv2d layer to an equivalent Linear layer
            fc = torch_conv_layer_to_affine(layer, (H, W))
            new_layers.append(fc)
            H_out = (H + 2 * layer.padding[0] - layer.kernel_size[0]) // layer.stride[0] + 1
            W_out = (W + 2 * layer.padding[1] - layer.kernel_size[1]) // layer.stride[1] + 1
            C = layer.out_channels
            H, W = H_out, W_out
        elif isinstance(layer, nn.ReLU):
            new_layers.append(nn.ReLU())
        elif isinstance(layer, nn.Flatten):
            # The input is already flattened at the top, so a later Flatten is a no-op
            pass
        elif isinstance(layer, nn.Linear):
            # Copy existing Linear layers so the converted network does not
            # share parameters with the original
            new_fc = nn.Linear(layer.in_features, layer.out_features, bias=(layer.bias is not None))
            with torch.no_grad():
                new_fc.weight.copy_(layer.weight.data)
                if layer.bias is not None:
                    new_fc.bias.copy_(layer.bias.data)
            new_layers.append(new_fc)
        else:
            # Any other layer type is passed through unchanged; these tests
            # only use Conv2d, ReLU, Flatten, and Linear
            new_layers.append(layer)
    return nn.Sequential(*new_layers)
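
# Minimal usage sketch (`cnn` is a hypothetical example, built only from the
# layer types handled above):
#
#     cnn = nn.Sequential(nn.Conv2d(1, 4, 3, padding=1), nn.ReLU(),
#                         nn.Flatten(), nn.Linear(4 * 28 * 28, 10))
#     mlp = convert_network(cnn, input_channels=1, input_height=28, input_width=28)
#     # mlp(x) then matches cnn(x) for x of shape (N, 1, 28, 28), up to
#     # floating-point error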

def test_tuple_encoding():
    x = enc_tuple((3, 2, 1), (5, 6, 7))
    assert dec_tuple(x, (5, 6, 7)) == (3, 2, 1)
    print("Tuple encoding test passed.")

def test_layer_conversion():
    for stride in [1, 2]:
        for padding in [0, 1, 2]:
            for filter_size in [3, 4]:
                img = torch.rand((1, 2, 6, 7))
                conv = nn.Conv2d(2, 5, filter_size, stride=stride, padding=padding)
                fc = torch_conv_layer_to_affine(conv, img.shape[2:])
                # Compare the conv output with the affine output reshaped back
                res1 = fc(img.reshape((-1))).reshape(conv(img).shape)
                res2 = conv(img)
                worst_error = (res1 - res2).abs().max()
                print(f"stride={stride}, padding={padding}, filter_size={filter_size}, "
                      f"shape={res2.shape}, worst_error={worst_error.item()}")
                assert worst_error <= 1.0e-6
    print("Layer conversion test passed.")

def test_network_conversion():
    net = nn.Sequential(
        nn.Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
        nn.Conv2d(16, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
        nn.ReLU(),
        nn.Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
        nn.ReLU(),
        nn.Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
        nn.ReLU(),
        nn.Flatten(start_dim=1, end_dim=-1),
        nn.Linear(in_features=392, out_features=50, bias=True),
        nn.ReLU(),
        nn.Linear(in_features=50, out_features=10, bias=True),
    )
    in_ch, in_dim = 1, 28
    net_converted = convert_network(net, input_channels=in_ch, input_height=in_dim, input_width=in_dim)
    # Test with a random batch input of shape (N, C, H, W); a batch size of 4
    # exercises the converted network on more than a single sample
    inp = torch.rand((4, in_ch, in_dim, in_dim))
    out_original = net(inp)
    out_converted = net_converted(inp)
    diff = (out_original - out_converted).abs().max()
    print("Network conversion worst error:", diff.item())
    assert diff < 1e-5, "Network conversion does not match closely enough."
    print("Network conversion test passed.")
if __name__ == "__main__":
test_tuple_encoding()
test_layer_conversion()
test_network_conversion()