Skip to content

Instantly share code, notes, and snippets.

@zshn25
Last active July 31, 2023 14:05
Show Gist options
  • Select an option

  • Save zshn25/0b7fdab97c3fa06c0bfd1e528c861041 to your computer and use it in GitHub Desktop.

Select an option

Save zshn25/0b7fdab97c3fa06c0bfd1e528c861041 to your computer and use it in GitHub Desktop.
Conv2d and BatchNorm2d fusion
from torch import nn
from torch.nn.utils.fusion import fuse_conv_bn_eval
def fuse_all_conv_bn(model):
    """Fold each ``nn.BatchNorm2d`` into the ``nn.Conv2d`` registered right before it.

    Walks ``model``'s children recursively, in registration order (the order
    ``named_children()`` yields). When a ``BatchNorm2d`` child immediately
    follows a ``Conv2d`` sibling, two replacements happen:

    - the conv is replaced by the fused conv from ``fuse_conv_bn_eval``;
    - the BatchNorm is replaced by ``nn.Identity()``.

    The model is modified in place.

    BatchNorms are skipped when they cannot be folded:

    - a BN with no preceding conv sibling (e.g. the first child of a container);
    - a BN that follows a BN;
    - a BN without running statistics (``track_running_stats=False``).

    NOTE(review): fusion is based on registration order, not on the order
    ``forward()`` actually applies modules. For custom modules whose
    ``forward()`` does not apply a BN directly to the preceding conv's
    output, verify before fusing.

    Args:
        model: the ``nn.Module`` to fuse. It must be in eval mode, because
            ``fuse_conv_bn_eval`` asserts that neither module is training.

    Returns:
        The same ``model`` object, to allow chaining.
    """
    # Sibling registered immediately before the current child. Only a Conv2d
    # stored here is a fusion candidate for a following BatchNorm2d.
    prev_name, prev_module = None, None
    for name, module in model.named_children():  # immediate children only
        # Recurse into containers (modules that have children of their own).
        if next(module.children(), None) is not None:
            fuse_all_conv_bn(module)

        if (
            isinstance(module, nn.BatchNorm2d)
            and isinstance(prev_module, nn.Conv2d)
            # A BN without running statistics normalizes with batch stats even
            # in eval mode, so it cannot be folded into fixed conv weights.
            and module.running_mean is not None
        ):
            setattr(model, prev_name, fuse_conv_bn_eval(prev_module, module))
            setattr(model, name, nn.Identity())
            # The conv has been consumed. Clear it so a later BN cannot be
            # fused into the original (unfused) conv and undo this fusion.
            prev_name, prev_module = None, None
        else:
            prev_name, prev_module = name, module
    return model
torch
torchvision
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torchvision.models.resnet import resnet101\n",
"from fusion import fuse_all_conv_bn\n",
"\n",
"model=resnet101(pretrained=True)\n",
"model.eval()\n",
"rand_input = torch.randn((1,3,256,256))\n",
"\n",
"# Forward pass\n",
"output = model(rand_input)\n",
"print(\"Inference time before fusion:\")\n",
"%timeit model(rand_input)\n",
"\n",
"# Fuse Conv BN\n",
"fuse_all_conv_bn(model)\n",
"print(\"\\nInference time after fusion:\")\n",
"%timeit model(rand_input)\n",
"# compare result\n",
"print(\"\\nError between outputs before and after fusion: \\n\", torch.norm(torch.abs(output - model(rand_input))).data)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.7.7 64-bit ('e2r': conda)",
"language": "python",
"name": "python37764bite2rcondad9123364716e49bdbcac99e0b1d6c310"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment