Last active
November 3, 2025 13:12
-
-
Save Teque5/a81e4adef940bcd71d6ab0bfa6a0748d to your computer and use it in GitHub Desktop.
ReparamConv1DBlock
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| ReparamConv1DBlock: A 1D Convolutional Block with Re-parameterization | |
| This module implements a multi-branch convolutional block that can be reparameterized | |
| into a single convolution for efficient inference. It supports: | |
| - Multiple kernel sizes and branches | |
| - Optional scale (1x1) convolution branch | |
| - Optional skip connection with batch normalization | |
| - Squeeze-and-Excite attention | |
| - Various activation functions | |
| Key features: | |
| - Training mode: Multiple branches for rich feature extraction | |
| - Inference mode: Single equivalent convolution for speed | |
| - Mathematically equivalent outputs between modes | |
| """ | |
| import time | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
| from torchinfo import summary | |
class SqueezeExcite1D(nn.Module):
    """Channel-wise squeeze-and-excite gate for 1d feature maps.

    The input is average-pooled over its last axis and passed through a
    1x1-conv bottleneck (ReLU) followed by a sigmoid. The resulting
    per-channel weights in (0, 1) rescale the input.

    Parameters
    ----------
    channels : int
        Number of input channels
    reduction : int, default=4
        Channel reduction ratio for bottleneck (the hidden width is never
        smaller than 1)
    """

    def __init__(self, channels, reduction=4):
        super().__init__()
        hidden = max(1, channels // reduction)
        self.pool = nn.AdaptiveAvgPool1d(1)
        # keep the layer order fixed: indices 0 and 2 are the convs in the state_dict
        layers = [
            nn.Conv1d(channels, hidden, 1),
            nn.ReLU(),
            nn.Conv1d(hidden, channels, 1),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        # gate has a length-1 last axis and broadcasts back over the input length
        gate = self.fc(self.pool(x))
        return gate * x
def make_reparam_stage(
    in_channels: int,
    out_channels: int,
    kernel_sizes: tuple = (3,),
    num_blocks: int = 3,
    down_stride: int = 2,
    use_se: bool = False,
    inference_mode: bool = False,
):
    """Create a DW+PW reparameterizable block stage.

    Each block pair is a depthwise ``ReparamConv1DBlock`` (groups equal to its
    channel count) followed by a pointwise (kernel size 1) ``ReparamConv1DBlock``.
    Only the final pointwise block changes the channel count to ``out_channels``.

    Parameters
    ----------
    in_channels : int
        Input channels
    out_channels : int
        Output channels
    kernel_sizes : tuple, default=(3,)
        Depthwise kernel sizes
    num_blocks : int, default=3
        Number of DW+PW block pairs; must be at least 1
    down_stride : int, default=2
        Stride of the first depthwise block; every later block uses stride 1
    use_se : bool, default=False
        Whether to use squeeze-and-excite in pointwise blocks
    inference_mode : bool, default=False
        Whether blocks start in inference mode

    Returns
    -------
    nn.Sequential
        Stage containing ``2 * num_blocks`` DW+PW reparameterizable blocks

    Raises
    ------
    ValueError
        If ``num_blocks`` is less than 1
    """
    if num_blocks < 1:
        # without this check num_blocks <= 0 silently still built one block pair
        raise ValueError(f"num_blocks must be at least 1, got {num_blocks}")
    strides = [down_stride] + [1] * (num_blocks - 1)
    last_index = num_blocks - 1
    blocks = []
    current_channels = in_channels
    for i, stride in enumerate(strides):
        # only the final pair changes the channel count
        block_out_channels = out_channels if i == last_index else current_channels
        # depthwise; the identity skip is only valid when the stride keeps the length
        blocks.append(
            ReparamConv1DBlock(
                current_channels,
                current_channels,
                kernel_sizes=kernel_sizes,
                stride=stride,
                groups=current_channels,  # depthwise
                use_scale_branch=False,
                use_skip=stride == 1,
                use_se=False,
                inference_mode=inference_mode,
            )
        )
        # pointwise; skip only when the channel count is unchanged
        blocks.append(
            ReparamConv1DBlock(
                current_channels,
                block_out_channels,
                kernel_sizes=(1,),
                stride=1,
                groups=1,
                use_scale_branch=False,
                use_skip=current_channels == block_out_channels,
                use_se=use_se,
                inference_mode=inference_mode,
            )
        )
        current_channels = block_out_channels
    return nn.Sequential(*blocks)
class ReparamConv1DBlock(nn.Module):
    """
    1D convolutional block with re-parameterization for inference.

    Multi-branch training -> single-conv inference for efficiency. In training
    mode the output is ``act(se(sum of branches))``. The branches are one
    conv+BN per entry of ``kernel_sizes``, an optional 1x1 conv+BN (scale
    branch) and an optional BN-only identity (skip branch).
    :meth:`reparameterize` folds all branches into one ``nn.Conv1d`` using the
    batchnorm running statistics.

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    kernel_sizes : tuple, default=(3,)
        Kernel sizes for conv branches, e.g., (3, 5, 7) creates 3 branches.
        For duplicate kernels use (3, 3, 5) to create two 3x1 and one 5x1 branch.
        Use odd sizes: branch kernels are zero-padded symmetrically to a common
        width during fusion, which keeps them center-aligned only for odd sizes.
    stride : int, default=1
        Convolution stride
    dilation : int, default=1
        Convolution dilation
    groups : int, default=1
        Convolution groups
    bias : bool, default=False
        Whether to use bias in the branch convolutions; the bias is folded
        into the fused bias on reparameterization
    use_scale_branch : bool, default=True
        Whether to add a 1x1 conv branch
    use_skip : bool, default=True
        Whether to add identity connection (only when in_channels == out_channels and stride == 1)
    inference_mode : bool, default=False
        Whether to start in inference mode. The fused conv is built with the
        same kernel width and padding that :meth:`reparameterize` produces for
        the same arguments, so a reparameterized ``state_dict`` loads into it.
    act_layer : nn.Module, default=nn.GELU
        Activation function class
    use_se : bool, default=False
        Whether to use squeeze-and-excite
    use_act : bool, default=True
        Whether to apply activation
    """

    # smallest fused kernel width; kept at 3 so previously reparameterized
    # checkpoints keep the same weight shape
    _MIN_FUSED_KERNEL = 3

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_sizes=(3,),
        stride=1,
        dilation=1,
        groups=1,
        bias=False,
        use_scale_branch=True,
        use_skip=True,
        inference_mode=False,
        act_layer: nn.Module = nn.GELU,
        use_se: bool = False,
        use_act: bool = True,
    ):
        super().__init__()
        self.inference_mode = inference_mode
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.dilation = dilation
        self.groups = groups
        self.bias = bias
        self.act = act_layer() if use_act else nn.Identity()
        self.se = SqueezeExcite1D(out_channels) if use_se else nn.Identity()
        if not inference_mode:
            self.branches = nn.ModuleList()
            # create one branch per kernel size specified
            for k in kernel_sizes:
                # "same" padding for the dilated kernel (effective width (k - 1) * dilation + 1)
                padding = (k - 1) * dilation // 2
                self.branches.append(
                    nn.Sequential(
                        nn.Conv1d(in_channels, out_channels, k, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias),
                        nn.BatchNorm1d(out_channels),
                    )
                )
            # optional scale (1x1) conv branch
            self.scale_branch = (
                nn.Sequential(nn.Conv1d(in_channels, out_channels, 1, stride=stride, padding=0, groups=groups, bias=bias), nn.BatchNorm1d(out_channels))
                if use_scale_branch
                else None
            )
            # optional skip connection (shape-preserving only)
            self.skip_branch = nn.BatchNorm1d(in_channels) if out_channels == in_channels and stride == 1 and use_skip else None
        else:
            # single equivalent conv for inference, shaped like reparameterize() output
            self.reparam_conv = self._make_reparam_conv(self._fused_kernel_size(kernel_sizes))

    @classmethod
    def _fused_kernel_size(cls, kernel_sizes):
        """Return the fused kernel width: the largest branch size, at least ``_MIN_FUSED_KERNEL``."""
        return max((cls._MIN_FUSED_KERNEL, *kernel_sizes))

    def _make_reparam_conv(self, kernel_size, device=None, dtype=None):
        """Build the single inference conv for a fused kernel of width ``kernel_size``."""
        return nn.Conv1d(
            self.in_channels,
            self.out_channels,
            kernel_size,
            stride=self.stride,
            padding=(kernel_size - 1) * self.dilation // 2,
            dilation=self.dilation,
            groups=self.groups,
            bias=True,
            device=device,
            dtype=dtype,
        )

    def forward(self, x):
        if self.inference_mode:
            return self.act(self.se(self.reparam_conv(x)))
        # sum all branches
        out = sum(branch(x) for branch in self.branches)
        if self.scale_branch is not None:
            out += self.scale_branch(x)
        if self.skip_branch is not None:
            out += self.skip_branch(x)
        return self.act(self.se(out))

    def get_equivalent_kernel_bias(self):
        """Fuse all branches into equivalent kernel and bias.

        Returns
        -------
        tuple of torch.Tensor
            ``(kernel, bias)`` with shapes ``(out_channels, in_channels // groups, K)``
            and ``(out_channels,)``, where ``K`` is the largest branch kernel size
            (at least 3). Device and dtype follow the block's parameters.
        """
        ref = next(self.parameters())
        kernel_size = self._fused_kernel_size(branch[0].kernel_size[0] for branch in self.branches)
        kernel_sum = torch.zeros((self.out_channels, self.in_channels // self.groups, kernel_size), device=ref.device, dtype=ref.dtype)
        bias_sum = torch.zeros(self.out_channels, device=ref.device, dtype=ref.dtype)
        # fuse all convolutional branches
        for branch in self.branches:
            k, b = self._fuse_conv_bn(branch[0], branch[1])
            kernel_sum += self._pad_kernel(k, kernel_size)
            bias_sum += b
        # fuse scale branch
        if self.scale_branch is not None:
            k, b = self._fuse_conv_bn(self.scale_branch[0], self.scale_branch[1])
            kernel_sum += self._pad_kernel(k, kernel_size)
            bias_sum += b
        # fuse skip branch
        if self.skip_branch is not None:
            k, b = self._fuse_bn_identity(kernel_size)
            kernel_sum += k
            bias_sum += b
        return kernel_sum, bias_sum

    @staticmethod
    def _pad_kernel(kernel, kernel_size):
        """Zero-pad the last axis of ``kernel`` symmetrically to width ``kernel_size``."""
        pad = kernel_size - kernel.size(2)
        return F.pad(kernel, (pad // 2, pad - pad // 2))

    def _fuse_kernel_bn(self, weight, conv_bias, bn):
        """Fold eval-mode ``bn`` applied after a conv with ``weight``/``conv_bias`` into one kernel and bias.

        ``bn(W x + c) = (W x + c - mean) * gamma / std + beta``, so the fused
        kernel is ``W * gamma / std`` and the fused bias is
        ``beta - (mean - c) * gamma / std``.
        """
        std = (bn.running_var + bn.eps).sqrt()
        gamma = bn.weight if bn.weight is not None else torch.ones_like(std)
        beta = bn.bias if bn.bias is not None else torch.zeros_like(std)
        scale = gamma / std
        shift = bn.running_mean if conv_bias is None else bn.running_mean - conv_bias
        w_fused = weight * scale.reshape(-1, 1, 1)
        b_fused = beta - shift * scale
        return w_fused, b_fused

    def _fuse_conv_bn(self, conv, bn):
        """Fuse conv1d + batchnorm1d into equivalent conv, including the conv's own bias."""
        return self._fuse_kernel_bn(conv.weight, conv.bias, bn)

    def _fuse_bn_identity(self, kernel_size=3):
        """Fuse identity + batchnorm into equivalent conv."""
        assert self.in_channels == self.out_channels and self.stride == 1
        ref = next(self.parameters())
        channels_per_group = self.in_channels // self.groups
        kernel = torch.zeros((self.out_channels, channels_per_group, kernel_size), device=ref.device, dtype=ref.dtype)
        # output channel o reads input channel o, which is index o % channels_per_group inside its group
        out_idx = torch.arange(self.out_channels, device=ref.device)
        kernel[out_idx, out_idx % channels_per_group, kernel_size // 2] = 1.0
        return self._fuse_kernel_bn(kernel, None, self.skip_branch)

    def reparameterize(self):
        """Convert the training-time multi-branch block into a single conv layer.

        This method fuses all branches (conv, scale, skip) into a single equivalent
        convolution for faster inference. It uses the batchnorm running
        statistics, so it matches the block's eval-mode output. Calling it again
        is a no-op.
        """
        if self.inference_mode:
            return
        with torch.no_grad():
            kernel, bias = self.get_equivalent_kernel_bias()
            self.reparam_conv = self._make_reparam_conv(kernel.size(2), device=kernel.device, dtype=kernel.dtype)
            self.reparam_conv.weight.copy_(kernel)
            self.reparam_conv.bias.copy_(bias)
        # cleanup training branches
        for attr in ("branches", "scale_branch", "skip_branch"):
            if hasattr(self, attr):
                delattr(self, attr)
        self.inference_mode = True
def demo_reparameterization():
    """Demonstrate the reparameterization functionality with various configurations.

    Tests multiple block configurations to verify mathematical equivalence
    between training and inference modes. The closing summary is built from
    the actual comparison results rather than a fixed message.

    Returns
    -------
    bool
        True if every configuration, including the DW+PW stage, matched
        within tolerance.
    """
    # case name -> whether outputs matched before/after reparameterization
    results = {}

    def test_block(name, block, x):
        """Test a block before and after reparameterization.

        Parameters
        ----------
        name : str
            Test case name for display
        block : ReparamConv1DBlock
            Block to test
        x : torch.Tensor
            Input tensor

        Returns
        -------
        bool
            True if outputs match within tolerance (also recorded in ``results``)
        """
        print(f"\n=== {name} ===")
        # initialize batch norm statistics
        block.train()
        with torch.no_grad():
            _ = block(x.repeat(10, 1, 1))  # use larger batch for bn stats
        block.eval()
        # test equivalence
        with torch.no_grad():
            out_before = block(x)
            block.reparameterize()
            out_after = block(x)
        match = torch.allclose(out_before, out_after, atol=1e-5)
        max_diff = (out_before - out_after).abs().max().item()
        print(f"Before reparam: {out_before[0, :3, 0]}")
        print(f"After reparam: {out_after[0, :3, 0]}")
        print(f"Outputs match: {match} (max diff: {max_diff:.2e})")
        results[name] = match
        return match

    torch.manual_seed(42)
    x = torch.randn(1, 16, 32)  # test input
    # test 1: simple single-branch block
    block1 = ReparamConv1DBlock(16, 32, kernel_sizes=(3,), use_scale_branch=False, use_skip=False, use_se=False)
    test_block("simple block (single 3x1 conv)", block1, x)
    # test 2: multi-kernel block
    block2 = ReparamConv1DBlock(16, 32, kernel_sizes=(3, 5, 7), use_scale_branch=False, use_skip=False, use_se=False)
    test_block("multi-kernel block (3x1, 5x1, 7x1)", block2, x)
    # test 3: block with scale branch and multiple 3x1 convs
    block3 = ReparamConv1DBlock(16, 32, kernel_sizes=(3, 3, 5), use_scale_branch=True, use_skip=False, use_se=False)
    test_block("block with scale branch (two 3x1, one 5x1)", block3, x)
    # test 4: block with skip connection (same channels)
    x_same = torch.randn(1, 32, 32)
    block4 = ReparamConv1DBlock(32, 32, kernel_sizes=(3, 3), use_scale_branch=True, use_skip=True, use_se=False)
    test_block("block with skip connection", block4, x_same)
    # test 5: full-featured block
    block5 = ReparamConv1DBlock(16, 32, kernel_sizes=(3, 5), use_scale_branch=True, use_skip=False, use_se=True, act_layer=nn.ReLU)
    test_block("full-featured block (se + relu)", block5, x)
    # test 6: block with advanced parameters (no groups first)
    block6 = ReparamConv1DBlock(16, 32, kernel_sizes=(3, 5), stride=2, dilation=2, bias=True, use_skip=False)
    x_advanced = torch.randn(1, 16, 64)  # larger input for stride=2
    test_block("advanced parameters (stride=2, dilation=2, bias=True)", block6, x_advanced)
    # test 7: block with groups parameter
    block7 = ReparamConv1DBlock(16, 16, kernel_sizes=(3, 5), groups=2, bias=True, use_skip=True)
    x_groups = torch.randn(1, 16, 32)
    test_block("groups parameter (groups=2, with skip)", block7, x_groups)

    # test 8: depthwise + pointwise stage
    stage_name = "depthwise + pointwise stage"
    print(f"\n=== {stage_name} ===")
    stage = make_reparam_stage(in_channels=32, out_channels=64, kernel_sizes=(3, 5), num_blocks=2, use_se=True)
    x_stage = torch.randn(1, 32, 64)
    # test before reparameterization
    stage.train()
    with torch.no_grad():
        _ = stage(x_stage.repeat(10, 1, 1))  # initialize bn stats
    stage.eval()
    with torch.no_grad():
        out_before_stage = stage(x_stage)
        # reparameterize all blocks in the stage
        for block in stage:
            if hasattr(block, "reparameterize"):
                block.reparameterize()
        out_after_stage = stage(x_stage)
    match_stage = torch.allclose(out_before_stage, out_after_stage, atol=1e-4)  # slightly relaxed tolerance for complex stage
    max_diff_stage = (out_before_stage - out_after_stage).abs().max().item()
    print(f"Before reparam: {out_before_stage[0, :3, 0]}")
    print(f"After reparam: {out_after_stage[0, :3, 0]}")
    print(f"Stage outputs match: {match_stage} (max diff: {max_diff_stage:.2e})")
    print(f"Input shape: {x_stage.shape} -> Output shape: {out_before_stage.shape}")
    results[stage_name] = match_stage

    # report what actually happened instead of a fixed success message
    failed = [name for name, ok in results.items() if not ok]
    print(f"\n{'='*60}")
    print("β Core reparameterization tests completed!")
    print(f"{len(results) - len(failed)}/{len(results)} configurations matched before and after reparameterization.")
    if failed:
        print("Mismatched configurations: " + ", ".join(failed))
    print("DW+PW stages demonstrate how to build complex reparameterizable architectures!")
    return not failed
def profile_performance(iterations: int = 100, warmup: int = 10):
    """Profile the performance difference between training and inference modes.

    Measures execution time, parameter count, and mult-adds for both modes
    to demonstrate the efficiency gains from reparameterization.

    Parameters
    ----------
    iterations : int, default=100
        Number of timed forward passes per mode
    warmup : int, default=10
        Untimed forward passes per mode before timing starts

    Raises
    ------
    ValueError
        If ``iterations`` is less than 1
    """
    import copy

    if iterations < 1:
        raise ValueError(f"iterations must be at least 1, got {iterations}")

    print(f"\n{'='*60}")
    print("π Performance Profiling")
    torch.manual_seed(42)
    # create a complex block with multiple branches
    block = ReparamConv1DBlock(32, 32, kernel_sizes=(3, 3, 5, 7), use_scale_branch=True, use_skip=True, use_se=True)
    # input is (bs, ch, length)
    x = torch.randn(7, 32, 100)
    # a training-mode pass populates the batchnorm running statistics; no gradients needed
    with torch.no_grad():
        _ = block(x)
    block.eval()
    # create a copy for reparameterization
    block_reparam = copy.deepcopy(block)
    block_reparam.reparameterize()

    # detailed profiling with torchinfo; pin it to the input's device so the
    # models stay where the timing loop below expects them
    input_size = (32, 100)  # channels, length (batch size handled separately)
    print("\nπ Training Mode Summary:")
    training_summary = summary(block, input_size=input_size, batch_dim=0, device=x.device, verbose=0)
    print(training_summary)  # verbose=0 only suppresses torchinfo's own printing
    print("\nπ Inference Mode Summary:")
    inference_summary = summary(block_reparam, input_size=input_size, batch_dim=0, device=x.device, verbose=0)
    print(inference_summary)

    # extract key metrics
    training_params = training_summary.total_params
    training_mult_adds = training_summary.total_mult_adds
    inference_params = inference_summary.total_params
    inference_mult_adds = inference_summary.total_mult_adds
    print("\nπ Model Comparison:")
    print(f"Parameters: {training_params:,} β {inference_params:,} ({inference_params/training_params:.1%})")
    print(f"Mult-Adds: {training_mult_adds:,} β {inference_mult_adds:,} ({inference_mult_adds/training_mult_adds:.1%})")
    # total_param_bytes covers parameters only, not activations
    print(f"Parameter Memory: {training_summary.total_param_bytes / (1024**2):.2f} MB β {inference_summary.total_param_bytes / (1024**2):.2f} MB")

    def _time_forward(module):
        """Return wall-clock seconds for ``iterations`` forward passes after ``warmup`` untimed ones."""
        with torch.no_grad():
            for _ in range(warmup):
                module(x)
            start = time.perf_counter()
            for _ in range(iterations):
                module(x)
            return time.perf_counter() - start

    # benchmark execution time
    training_time = _time_forward(block)
    inference_time = _time_forward(block_reparam)
    speedup = training_time / inference_time
    print(f"\nβ±οΈ Execution Time ({iterations} iterations):")
    print(f"Training: {training_time:.4f}s")
    print(f"Inference: {inference_time:.4f}s")
    print(f"Speedup: {speedup:.2f}x faster")
if __name__ == "__main__":
    separator = "=" * 60
    print("π§© ReparamConv1DBlock Demonstration")
    print(separator)
    # equivalence checks across block configurations
    demo_reparameterization()
    # size / speed comparison between the two modes
    profile_performance()
    print("\n" + separator)
    for line in (
        "π Demo completed! The ReparamConv1DBlock successfully converts",
        " from a multi-branch training architecture to a single-conv",
        " inference architecture while maintaining mathematical equivalence.",
    ):
        print(line)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment