Forked from vvolhejn/torch_conv_layer_to_fully_connected.py
Convert PyTorch convolutional layer to fully connected layer
| """ | |
| The function `torch_conv_layer_to_affine` takes a `torch.nn.Conv2d` layer `conv` | |
| and produces an equivalent `torch.nn.Linear` layer `fc`. | |
| Specifically, this means that the following holds for `x` of a valid shape: | |
| torch.flatten(conv(x)) == fc(torch.flatten(x)) | |
| Or equivalently: | |
| conv(x) == fc(torch.flatten(x)).reshape(conv(x).shape) | |
| allowing of course for some floating-point error. | |
| """ | |
from typing import Tuple

import torch
import torch.nn as nn
import numpy as np
def torch_conv_layer_to_affine(
    conv: torch.nn.Conv2d, input_size: Tuple[int, int]
) -> torch.nn.Linear:
    # input_size is (H, W), matching the spatial dimensions of an (N, C, H, W) tensor.
    h, w = input_size
    with torch.no_grad():
        # Standard output-size formula for a strided, padded convolution.
        output_size = [
            (input_size[i] + 2 * conv.padding[i] - conv.kernel_size[i]) // conv.stride[i]
            + 1
            for i in [0, 1]
        ]

        in_shape = (conv.in_channels, h, w)
        out_shape = (conv.out_channels, output_size[0], output_size[1])

        fc = nn.Linear(in_features=int(np.prod(in_shape)), out_features=int(np.prod(out_shape)))
        fc.weight.fill_(0.0)
        fc.bias.fill_(0.0)
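
        # Note: the all-zero initialisation matters. Input positions that fall in
        # the padding region are never written below, so their weights stay 0,
        # which reproduces zero-padding exactly.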
        # Fill in weights and biases.
        for xo, yo in range2d(output_size[0], output_size[1]):
            # Top-left input coordinate covered by the kernel at output (xo, yo).
            xi0 = -conv.padding[0] + conv.stride[0] * xo
            yi0 = -conv.padding[1] + conv.stride[1] * yo
            for co in range(conv.out_channels):
                # Set the bias exactly once per output unit.
                if conv.bias is not None:
                    fc.bias[enc_tuple((co, xo, yo), out_shape)] = conv.bias[co]
            for xd, yd in range2d(conv.kernel_size[0], conv.kernel_size[1]):
                for co in range(conv.out_channels):
                    for ci in range(conv.in_channels):
                        # Skip positions that fall outside the (zero-padded) input.
                        if 0 <= xi0 + xd < h and 0 <= yi0 + yd < w:
                            cw = conv.weight[co, ci, xd, yd]
                            fc.weight[
                                enc_tuple((co, xo, yo), out_shape),
                                enc_tuple((ci, xi0 + xd, yi0 + yd), in_shape),
                            ] = cw

    return fc
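
# Illustrative sanity check (assumed example values): a Conv2d(2, 5, kernel_size=3)
# applied to a 6x7 input with stride 1 and no padding produces a 4x5 output, so the
# returned Linear layer maps 2*6*7 = 84 input features to 5*4*5 = 100 output features.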
def range2d(to_a, to_b):
    # Iterate over the Cartesian product range(to_a) x range(to_b).
    for a in range(to_a):
        for b in range(to_b):
            yield a, b
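
# For example, range2d(2, 3) yields (0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2).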
def enc_tuple(tup: Tuple, shape: Tuple) -> int:
    # Converts an n-d index `tup` into a row-major linear index for the given shape.
    res = 0
    coef = 1
    for i in reversed(range(len(shape))):
        assert tup[i] < shape[i]
        res += coef * tup[i]
        coef *= shape[i]
    return res
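
# For example, with shape (5, 6, 7) the row-major strides are (42, 7, 1), so
# enc_tuple((3, 2, 1), (5, 6, 7)) == 3*42 + 2*7 + 1 == 141.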
def dec_tuple(x: int, shape: Tuple) -> Tuple:
    # Converts a row-major linear index `x` back into an n-d index for the given shape.
    res = []
    for i in reversed(range(len(shape))):
        res.append(x % shape[i])
        x //= shape[i]
    return tuple(reversed(res))
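
# dec_tuple inverts enc_tuple: dec_tuple(141, (5, 6, 7)) == (3, 2, 1),
# which is what test_tuple_encoding below checks.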
def convert_network(net, input_channels=1, input_height=28, input_width=28):
    # Put a Flatten layer at the top, since the converted network operates on
    # flattened inputs.
    new_layers = [nn.Flatten()]
    # Keep track of what the activation shape should be. It starts as (C, H, W).
    C, H, W = input_channels, input_height, input_width
    for layer in net:
        if isinstance(layer, nn.Conv2d):
            # Convert the Conv2d layer to an equivalent Linear layer.
            fc = torch_conv_layer_to_affine(layer, (H, W))
            new_layers.append(fc)
            H_out = (H + 2 * layer.padding[0] - layer.kernel_size[0]) // layer.stride[0] + 1
            W_out = (W + 2 * layer.padding[1] - layer.kernel_size[1]) // layer.stride[1] + 1
            C = layer.out_channels
            H, W = H_out, W_out
        elif isinstance(layer, nn.ReLU):
            new_layers.append(nn.ReLU())
        elif isinstance(layer, nn.Flatten):
            # Drop it: the input is already flattened by the leading Flatten.
            pass
        elif isinstance(layer, nn.Linear):
            new_fc = nn.Linear(layer.in_features, layer.out_features, bias=(layer.bias is not None))
            with torch.no_grad():
                new_fc.weight.copy_(layer.weight.data)
                if layer.bias is not None:
                    new_fc.bias.copy_(layer.bias.data)
            new_layers.append(new_fc)
        else:
            # Any other kind of layer is passed through unchanged; these tests
            # only use Conv2d, ReLU, Flatten, and Linear.
            new_layers.append(layer)
    return nn.Sequential(*new_layers)
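
# Illustrative usage (mirrored by test_network_conversion below): for a Sequential of
# Conv2d / ReLU / Flatten / Linear layers operating on 1x28x28 inputs,
#     net_fc = convert_network(net, input_channels=1, input_height=28, input_width=28)
# yields an all-affine network whose outputs match `net` up to floating-point error.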
def test_tuple_encoding():
    x = enc_tuple((3, 2, 1), (5, 6, 7))
    assert dec_tuple(x, (5, 6, 7)) == (3, 2, 1)
    print("Tuple encoding test passed.")
def test_layer_conversion():
    for stride in [1, 2]:
        for padding in [0, 1, 2]:
            for filter_size in [3, 4]:
                img = torch.rand((1, 2, 6, 7))
                conv = nn.Conv2d(2, 5, filter_size, stride=stride, padding=padding)
                fc = torch_conv_layer_to_affine(conv, img.shape[2:])

                # Compare the outputs of the two layers.
                res1 = fc(img.reshape(-1)).reshape(conv(img).shape)
                res2 = conv(img)
                worst_error = (res1 - res2).abs().max()
                print(
                    f"stride={stride}, padding={padding}, filter_size={filter_size}, "
                    f"shape={res2.shape}, worst_error={worst_error.item()}"
                )
                assert worst_error <= 1.0e-6
    print("Layer conversion test passed.")
def test_network_conversion():
    net = nn.Sequential(
        nn.Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
        nn.Conv2d(16, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
        nn.ReLU(),
        nn.Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
        nn.ReLU(),
        nn.Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
        nn.ReLU(),
        nn.Flatten(start_dim=1, end_dim=-1),
        nn.Linear(in_features=392, out_features=50, bias=True),
        nn.ReLU(),
        nn.Linear(in_features=50, out_features=10, bias=True),
    )
    in_ch, in_dim = 1, 28
    net_converted = convert_network(net, input_channels=in_ch, input_height=in_dim, input_width=in_dim)

    # Test with a random batch input of shape (N, C, H, W);
    # a batch size of 4 gives a slightly more robust check.
    inp = torch.rand((4, in_ch, in_dim, in_dim))
    out_original = net(inp)
    out_converted = net_converted(inp)
    diff = (out_original - out_converted).abs().max()
    print("Network conversion worst error:", diff.item())
    assert diff < 1e-5, "Network conversion does not match closely enough."
    print("Network conversion test passed.")
| if __name__ == "__main__": | |
| test_tuple_encoding() | |
| test_layer_conversion() | |
| test_network_conversion() |