@innat
Created November 1, 2025 18:03
Adaptive pooling in Keras 3
# ATTENTION: In training, this works with the torch and tensorflow backends but not with jax,
# likely because the bin slice bounds are computed as tensors and jax's jit tracing does not
# support slices with dynamic sizes.
import os
os.environ["KERAS_BACKEND"] = "torch" # tensorflow, torch, jax
import keras
from keras import layers, ops
from medicai.utils.swi_utils import ensure_tuple_rep
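
# NOTE: `ensure_tuple_rep` is imported from the medicai package. If medicai is
# not installed, a minimal stand-in could be dropped in instead. The sketch
# below is an assumption about its behavior (repeat a scalar into an n-tuple,
# otherwise validate the sequence length), not the package's implementation:
#
#     def ensure_tuple_rep(value, n):
#         if isinstance(value, (tuple, list)):
#             if len(value) != n:
#                 raise ValueError(f"Expected {n} values, got {value}.")
#             return tuple(value)
#         return (value,) * n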

@keras.saving.register_keras_serializable(package="medicai")
class AdaptivePooling2D(layers.Layer):
    """Parent class for 2D pooling layers with adaptive kernel size.

    This layer performs pooling over the input height (H) and width (W)
    dimensions such that the output dimensions match the specified
    `output_size`. It supports arbitrary input sizes, even when the input
    H/W is not divisible by the output H/W.

    It assumes the 'channels_last' data format: (batch, H, W, C).

    Args:
        reduce_function: The reduction method to apply, e.g. `keras.ops.mean`
            or `keras.ops.max`.
        output_size: An integer or tuple/list of 2 integers specifying
            (pooled_rows, pooled_cols). The new size of the H and W dimensions.
    """

    def __init__(
        self,
        reduce_function,
        output_size,
        **kwargs,
    ):
        self.reduce_function = reduce_function
        self.output_size = ensure_tuple_rep(output_size, 2)
        super().__init__(**kwargs)

    def build(self, input_shape):
        if len(input_shape) != 4:
            raise ValueError(
                f"{self.__class__.__name__} expects input with 4 dims (batch, H, W, C), "
                f"but got {input_shape}"
            )
        super().build(input_shape)

    def call(self, inputs):
        # (batch, H, W, C)
        h_bins = self.output_size[0]
        w_bins = self.output_size[1]

        # Get input dimensions H and W
        input_shape = ops.shape(inputs)
        h = input_shape[1]
        w = input_shape[2]

        # Calculate the start and end indices for each bin using linspace
        # (see the worked example after this class)
        h_idx = ops.linspace(0, h, h_bins + 1)
        w_idx = ops.linspace(0, w, w_bins + 1)

        outputs = []
        for i in range(h_bins):
            row_outputs = []
            # Calculate height indices (axis 1)
            h_start = ops.cast(ops.floor(h_idx[i]), "int32")
            h_end = ops.cast(ops.ceil(h_idx[i + 1]), "int32")
            h_end = ops.where(h_end > h, h, h_end)
            for j in range(w_bins):
                # Calculate width indices (axis 2)
                w_start = ops.cast(ops.floor(w_idx[j]), "int32")
                w_end = ops.cast(ops.ceil(w_idx[j + 1]), "int32")
                w_end = ops.where(w_end > w, w, w_end)
                # Slicing: inputs[:, H_slice, W_slice, :]
                region = inputs[:, h_start:h_end, w_start:w_end, :]
                # Reduction axes are H (1) and W (2)
                pooled = self.reduce_function(region, axis=[1, 2], keepdims=True)
                row_outputs.append(pooled)
            # Concatenate pooled regions along the width axis (axis 2)
            outputs.append(ops.concatenate(row_outputs, axis=2))
        # Concatenate pooled rows along the height axis (axis 1)
        outputs = ops.concatenate(outputs, axis=1)
        return outputs

    def compute_output_shape(self, input_shape):
        shape = (
            input_shape[0],
            self.output_size[0],
            self.output_size[1],
            input_shape[3],
        )
        return shape

    def get_config(self):
        config = {
            "output_size": self.output_size,
        }
        base_config = super().get_config()
        return {**base_config, **config}
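
# To make the binning above concrete: for an input with H = 8 pooled to
# h_bins = 5 (the 2D test case further down), linspace(0, 8, 6) gives the bin
# edges [0.0, 1.6, 3.2, 4.8, 6.4, 8.0]. Taking floor(start) and ceil(end) per
# bin yields the (possibly overlapping) row slices
#     [0:2], [1:4], [3:5], [4:7], [6:8]
# which is the same start = floor(i * H / out), end = ceil((i + 1) * H / out)
# scheme PyTorch uses for adaptive pooling, hence the matching outputs in the
# test at the bottom. (These numbers are worked out from the code above, not
# taken from the gist.)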

@keras.saving.register_keras_serializable(package="medicai")
class AdaptiveAveragePooling2D(AdaptivePooling2D):
    """Adaptive Average Pooling 2D layer (channels_last).

    This layer resizes the 2D input (H, W) to a fixed size using average pooling.

    Args:
        output_size: An integer or tuple/list of 2 integers specifying
            (pooled_rows, pooled_cols).
    """

    def __init__(self, output_size, **kwargs):
        super().__init__(ops.mean, output_size, **kwargs)


@keras.saving.register_keras_serializable(package="medicai")
class AdaptiveMaxPooling2D(AdaptivePooling2D):
    """Adaptive Max Pooling 2D layer (channels_last).

    This layer resizes the 2D input (H, W) to a fixed size using max pooling.

    Args:
        output_size: An integer or tuple/list of 2 integers specifying
            (pooled_rows, pooled_cols).
    """

    def __init__(self, output_size, **kwargs):
        super().__init__(ops.max, output_size, **kwargs)

@keras.saving.register_keras_serializable(package="medicai")
class AdaptivePooling3D(layers.Layer):
    """Parent class for 3D pooling layers with adaptive kernel size.

    This layer performs pooling over the input depth (D), height (H), and
    width (W) dimensions such that the output dimensions match the specified
    `output_size`. It supports arbitrary input sizes.

    It assumes the 'channels_last' data format: (batch, D, H, W, C).

    Args:
        reduce_function: The reduction method to apply, e.g. `keras.ops.mean`
            or `keras.ops.max`.
        output_size: An integer or tuple/list of 3 integers specifying
            (pooled_depth, pooled_rows, pooled_cols). The new size of
            the D, H, and W dimensions.
    """

    def __init__(
        self,
        reduce_function,
        output_size,
        **kwargs,
    ):
        self.reduce_function = reduce_function
        self.output_size = ensure_tuple_rep(output_size, 3)
        super().__init__(**kwargs)

    def build(self, input_shape):
        if len(input_shape) != 5:
            raise ValueError(
                f"{self.__class__.__name__} expects input with 5 dims (batch, D, H, W, C), "
                f"but got {input_shape}"
            )
        super().build(input_shape)

    def call(self, inputs):
        # (batch, D, H, W, C)
        d_bins = self.output_size[0]
        h_bins = self.output_size[1]
        w_bins = self.output_size[2]

        # Get input dimensions D, H, W
        input_shape = ops.shape(inputs)
        d = input_shape[1]  # Depth
        h = input_shape[2]  # Height
        w = input_shape[3]  # Width

        # Calculate the start and end indices for each bin using linspace
        d_idx = ops.linspace(0, d, d_bins + 1)
        h_idx = ops.linspace(0, h, h_bins + 1)
        w_idx = ops.linspace(0, w, w_bins + 1)

        depth_outputs = []
        for i in range(d_bins):
            # Calculate depth indices (axis 1)
            d_start = ops.cast(ops.floor(d_idx[i]), "int32")
            d_end = ops.cast(ops.ceil(d_idx[i + 1]), "int32")
            d_end = ops.where(d_end > d, d, d_end)
            row_outputs = []
            for j in range(h_bins):
                # Calculate height indices (axis 2)
                h_start = ops.cast(ops.floor(h_idx[j]), "int32")
                h_end = ops.cast(ops.ceil(h_idx[j + 1]), "int32")
                h_end = ops.where(h_end > h, h, h_end)
                col_outputs = []
                for k in range(w_bins):
                    # Calculate width indices (axis 3)
                    w_start = ops.cast(ops.floor(w_idx[k]), "int32")
                    w_end = ops.cast(ops.ceil(w_idx[k + 1]), "int32")
                    w_end = ops.where(w_end > w, w, w_end)
                    # Slicing: inputs[:, D_slice, H_slice, W_slice, :]
                    region = inputs[:, d_start:d_end, h_start:h_end, w_start:w_end, :]
                    # Reduction axes are D (1), H (2), W (3)
                    pooled = self.reduce_function(region, axis=[1, 2, 3], keepdims=True)
                    col_outputs.append(pooled)
                # Concatenate pooled regions along the width axis (axis 3)
                row_outputs.append(ops.concatenate(col_outputs, axis=3))
            # Concatenate pooled rows along the height axis (axis 2)
            depth_outputs.append(ops.concatenate(row_outputs, axis=2))
        # Concatenate pooled depth slices along the depth axis (axis 1)
        outputs = ops.concatenate(depth_outputs, axis=1)
        return outputs

    def compute_output_shape(self, input_shape):
        shape = (
            input_shape[0],
            self.output_size[0],
            self.output_size[1],
            self.output_size[2],
            input_shape[4],
        )
        return shape

    def get_config(self):
        config = {
            "output_size": self.output_size,
        }
        base_config = super().get_config()
        return {**base_config, **config}

@keras.saving.register_keras_serializable(package="medicai")
class AdaptiveAveragePooling3D(AdaptivePooling3D):
    """Adaptive Average Pooling 3D layer (channels_last).

    This layer resizes the 3D input (D, H, W) to a fixed size using average pooling.

    Args:
        output_size: An integer or tuple/list of 3 integers specifying
            (pooled_depth, pooled_rows, pooled_cols).
    """

    def __init__(self, output_size, **kwargs):
        super().__init__(ops.mean, output_size, **kwargs)


@keras.saving.register_keras_serializable(package="medicai")
class AdaptiveMaxPooling3D(AdaptivePooling3D):
    """Adaptive Max Pooling 3D layer (channels_last).

    This layer resizes the 3D input (D, H, W) to a fixed size using max pooling.

    Args:
        output_size: An integer or tuple/list of 3 integers specifying
            (pooled_depth, pooled_rows, pooled_cols).
    """

    def __init__(self, output_size, **kwargs):
        super().__init__(ops.max, output_size, **kwargs)
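
# Because each class is registered via `keras.saving.register_keras_serializable`
# (package="medicai") and serializes `output_size` in `get_config`, a model that
# contains these layers can be saved and reloaded with the standard Keras 3 API.
# A minimal sketch (the file name is illustrative, not from the gist):
#
#     model.save("adaptive_pool_model.keras")
#     restored = keras.saving.load_model("adaptive_pool_model.keras")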

# Test
## Torch
import numpy as np
import torch
from torch import nn

pool_size2d = (5, 7)
m2 = nn.AdaptiveAvgPool2d(pool_size2d)
input2 = torch.randn(1, 64, 8, 9)  # bs, channel, h, w
output2 = m2(input2)
print('pool size ', pool_size2d)
print('input size ', input2.shape)
print('output size ', output2.shape)

pool_size3d = (5, 7, 9)
m3 = nn.AdaptiveAvgPool3d(pool_size3d)
input3 = torch.randn(1, 64, 8, 9, 10)  # bs, channel, depth, h, w
output3 = m3(input3)
print('\npool size ', pool_size3d)
print('input size ', input3.shape)
print('output size ', output3.shape)

# Keras 3
km2 = AdaptiveAveragePooling2D(output_size=pool_size2d)
kinput2 = ops.convert_to_numpy(input2).transpose(0, 2, 3, 1)  # to channels_last
koutput2 = km2(kinput2)
print('pool size ', pool_size2d)
print('input size ', kinput2.shape)
print('output size ', koutput2.shape)

km3 = AdaptiveAveragePooling3D(output_size=pool_size3d)
kinput3 = ops.convert_to_numpy(input3).transpose(0, 2, 3, 4, 1)  # to channels_last
koutput3 = km3(kinput3)
print('\npool size ', pool_size3d)
print('input size ', kinput3.shape)
print('output size ', koutput3.shape)

np.testing.assert_allclose(
    koutput2,
    output2.detach().numpy().transpose(0, 2, 3, 1),
    rtol=1e-6,
    atol=1e-6,
)  # OK

np.testing.assert_allclose(
    koutput3,
    output3.detach().numpy().transpose(0, 2, 3, 4, 1),
    rtol=1e-6,
    atol=1e-6,
)  # OK
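
# Example (not part of the original gist): a minimal sketch of using the layer
# as a fixed-size pooling head, so inputs of arbitrary spatial size end up as a
# (4, 4) feature map before the dense classifier. The conv width, pooled size,
# and 10-class head below are illustrative assumptions.
head_inputs = keras.Input(shape=(None, None, 3))  # arbitrary H and W
x = layers.Conv2D(16, 3, padding="same", activation="relu")(head_inputs)
x = AdaptiveAveragePooling2D(output_size=(4, 4))(x)  # always (batch, 4, 4, 16)
x = layers.Flatten()(x)
head_outputs = layers.Dense(10)(x)
toy_model = keras.Model(head_inputs, head_outputs)
print('\ntoy model output ', toy_model(ops.ones((2, 31, 45, 3))).shape)  # (2, 10)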