The code for JIRA-Net as described in the SIGBOVIK 2025 paper submission
import os
import random
import timeit

import numpy as np
import torch
from torch import nn
from torch.nn import CrossEntropyLoss, Module
from torch.optim import SGD
from torch.utils.data import DataLoader
# col_names is supported by the torch-summary / torchinfo rewrite,
# not the original torchsummary package
from torchsummary import summary
from torchvision.datasets import mnist
from torchvision.transforms import Compose, Resize, ToTensor
def print_model_weights(model):
    print("Model Weights...")
    for name, param in model.named_parameters():
        if param.requires_grad:  # Only print trainable parameters
            print(f"Layer: {name}")
            print(f"Shape: {param.data.shape}")
            print(f"Weights:\n{param.data}\n")
def save_model_weights(model, output_dir="weights_dump"):
    os.makedirs(output_dir, exist_ok=True)
    for name, param in model.named_parameters():
        # Save weights to a text file
        file_path = os.path.join(output_dir, f"{name.replace('.', '_')}.txt")
        with open(file_path, "w") as f:
            # Write the weights in a readable format
            f.write(f"Layer: {name}\n")
            f.write(f"Weights:\n{param.data.cpu().numpy()}\n")
            if param.requires_grad:
                f.write(f"Gradients:\n{param.grad.cpu().numpy() if param.grad is not None else None}\n")
    print(f"Weights saved in {output_dir}")
def save_input(input_tensor, file_path="input.txt"):
    with open(file_path, "w") as f:
        f.write("Input Tensor:\n")
        f.write(np.array2string(input_tensor.cpu().numpy(), precision=5, separator=','))
    print(f"Input saved to {file_path}")
class JIRANet(Module):
    def __init__(self):
        super(JIRANet, self).__init__()
        # Block 1 - 1x14x14: three parallel convs (3x3, 5x5, 7x7), summed
        inChannels = 1
        self.c1a = nn.Conv2d(inChannels, 1, 3, padding=1)
        self.c1b = nn.Conv2d(inChannels, 1, 5, padding=2)
        self.c1c = nn.Conv2d(inChannels, 1, 7, padding=3)
        self.pool1 = nn.MaxPool2d(2)

        # Block 2 - 1x7x7: asymmetric 1x3 / 3x1 convs plus a 1x1 conv, concatenated
        inChannels = 1
        self.c2a = nn.Conv2d(in_channels=inChannels, out_channels=1, kernel_size=(1, 3), padding=(0, 1))
        self.c2b = nn.Conv2d(in_channels=inChannels, out_channels=1, kernel_size=(3, 1), padding=(1, 0))
        self.c2c = nn.Conv2d(inChannels, 2, 1)
        self.pool2 = nn.MaxPool2d(2)

        # Block 3 - 1x5x5: parallel 3x3 and 1x1 convs, concatenated
        inChannels = 4
        self.c3a = nn.Conv2d(inChannels, 2, 3, padding=1)
        self.c3b = nn.Conv2d(inChannels, 4, 1)
        self.pool3 = nn.MaxPool2d(2)

        # Block 4 - 1x3x3
        inChannels = 6
        c4_out_chan = 5
        c5_out_chan = 4
        self.c4 = nn.Conv2d(inChannels, c4_out_chan, 1)
        self.c5 = nn.Conv2d(c4_out_chan, c5_out_chan, 3, padding=0)

        # Classifier head
        fc1_outC = 19
        fc2_outC = 10
        final_out = 10
        self.fc1 = nn.Linear(c5_out_chan, fc1_outC)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(fc1_outC, fc2_outC)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(fc2_outC, final_out)
        self.relu5 = nn.ReLU()

    def forward(self, netIn):
        # print("Forward...")
        # Block 1 - 1x14x14: sum the three parallel branches
        c1a = self.c1a(netIn)
        c1b = self.c1b(netIn)
        c1c = self.c1c(netIn)
        add1 = c1a + c1b + c1c
        pool1 = self.pool1(add1)
        # Block 2 - 1x7x7: concatenate the three branches along channels
        c2a = self.c2a(pool1)
        c2b = self.c2b(pool1)
        c2c = self.c2c(pool1)
        concat2 = torch.concat((c2a, c2b, c2c), dim=1)
        pool2 = self.pool2(concat2)
        # Block 3 - 1x5x5: concatenate the two branches along channels
        c3a = self.c3a(pool2)
        c3b = self.c3b(pool2)
        concat3 = torch.concat((c3a, c3b), dim=1)
        # pool3 = self.pool3(concat3)
        # Block 4
        c4 = self.c4(concat3)
        c5 = self.c5(c4)
        y = c5.view(c5.shape[0], -1)
        y = self.fc1(y)
        y = self.relu3(y)
        y = self.fc2(y)
        y = self.relu4(y)
        y = self.fc3(y)
        y = self.relu5(y)
        return y
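
# Shape sanity check (my addition, not part of the original gist): an untrained
# JIRANet should map a 1x14x14 input to 10 logits. Channel bookkeeping: block 1
# sums three parallel convs (1 channel), block 2 concatenates 1 + 1 + 2 = 4
# channels, block 3 concatenates 2 + 4 = 6 channels, and c5's unpadded 3x3
# kernel collapses the 3x3 feature map to 1x1 before the fully connected layers.
def _check_jiranet_shapes():
    out = JIRANet()(torch.randn(1, 1, 14, 14))
    assert out.shape == (1, 10), f"unexpected output shape: {out.shape}"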
my_transforms = Compose([
    ToTensor()
])

jira_transforms = Compose([
    Resize((14, 14)),
    ToTensor()
])
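
def _check_jira_transforms():
    # Hypothetical helper (my addition): MNIST yields 28x28 PIL images, so
    # Resize((14, 14)) halves them before ToTensor produces the 1x14x14 float
    # tensor in [0, 1] that JIRANet expects. Assumes Pillow, which torchvision
    # already depends on.
    from PIL import Image
    assert jira_transforms(Image.new('L', (28, 28))).shape == (1, 14, 14)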
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    batch_size = 256
    seed = 1688469032097200
    torch.manual_seed(seed)
    random.seed(0)
    # torch.initial_seed() reports the seed set above; torch.seed() would
    # re-seed the RNG nondeterministically and break reproducibility.
    print("Seed: ", torch.initial_seed())

    train_dataset = mnist.MNIST(root='./train', train=True, transform=jira_transforms, download=True)
    test_dataset = mnist.MNIST(root='./test', train=False, transform=jira_transforms, download=True)
    model = JIRANet().to(device)
    stats = summary(model, (1, 14, 14), col_names=["output_size", "num_params"])
    print(stats)

    print("Train...")
    train_loader = DataLoader(train_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    test_input, test_label = next(iter(test_loader))
    print("Test Input Shape:", test_input.shape)
    save_input(test_input)

    execution_time = timeit.timeit(lambda: model(test_input.to(device)), number=1)
    print(f"Forward pass took {execution_time:.6f} seconds to execute.")
    sgd = SGD(model.parameters(), lr=1e-1)
    loss_fn = CrossEntropyLoss()
    all_epoch = 100
    prev_acc = 0
    for current_epoch in range(all_epoch):
        model.train()
        for idx, (train_x, train_label) in enumerate(train_loader):
            train_x = train_x.to(device)
            train_label = train_label.to(device)
            sgd.zero_grad()
            predict_y = model(train_x.float())
            loss = loss_fn(predict_y, train_label.long())
            loss.backward()
            sgd.step()

        # Evaluate on the test set after every epoch
        all_correct_num = 0
        all_sample_num = 0
        model.eval()
        for idx, (test_x, test_label) in enumerate(test_loader):
            test_x = test_x.to(device)
            test_label = test_label.to(device)
            predict_y = model(test_x.float()).detach()
            predict_y = torch.argmax(predict_y, dim=-1)
            current_correct_num = predict_y == test_label
            all_correct_num += np.sum(current_correct_num.to('cpu').numpy(), axis=-1)
            all_sample_num += current_correct_num.shape[0]
        acc = all_correct_num / all_sample_num
        print(f'epoch #{current_epoch}. accuracy: {acc:.4f}')
        if not os.path.isdir("models"):
            os.mkdir("models")
        torch.save(model, 'models/mnist_{:.3f}.pkl'.format(acc))
        # Stop early once accuracy has plateaued between epochs
        if np.abs(acc - prev_acc) < 1e-4:
            break
        prev_acc = acc
print("Model finished training")
# Example: Apply this to your model
print_model_weights(model)
save_model_weights(model)
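    # Reload sketch (my addition, not part of the original gist): torch.save
    # above pickles the entire Module, so a checkpoint restores directly with
    # torch.load; on recent PyTorch you may need weights_only=False to load
    # pickled modules.
    #   reloaded = torch.load('models/mnist_{:.3f}.pkl'.format(acc))
    #   reloaded.eval()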