import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
"""
PLAN OF ACTION
- Generate (synthetic) training data on-fly during training on GPU
- Measure performance on the on-fly generated data and the original data
"""
def normalise_ellipses_batch(xy: torch.Tensor) -> torch.Tensor:
    """Normalise a batch of ellipses into [-1, 1].

    Every axis is scaled by the same factor (the largest coordinate range in
    the batch), so the aspect ratio of each ellipse is preserved.
    """
    axes = [xy[..., i] for i in range(xy.shape[-1])]
    max_range = torch.max(*(ax.max() - ax.min() for ax in axes))
    normalised = [(((ax - ax.min()) / max_range) * 2.0 - 1.0) for ax in axes]
    return torch.stack(normalised, dim=-1)
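

# Illustrative sanity check (a minimal sketch added for clarity; not part of
# the training flow). The widest coordinate range maps onto [-1, 1] and the
# other axis is scaled by the same factor, preserving aspect ratio.
def _demo_normalise_ellipses_batch():
    demo = torch.tensor([[[0.0, 0.0], [4.0, 1.0]]])  # one curve, two points
    out = normalise_ellipses_batch(demo)
    # x spans 4 units (the widest range), so it maps exactly onto [-1, 1];
    # y spans 1 unit and is scaled by the same 1/4 factor, giving [-1.0, -0.5]
    assert torch.allclose(out[..., 0], torch.tensor([[-1.0, 1.0]]))
    assert torch.allclose(out[..., 1], torch.tensor([[-1.0, -0.5]]))

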
def generate_ellipses_batch(opt: dict):
    """Generate a batch of ellipses directly on the configured device."""
    ne = opt["batch_size"]
    num_points = opt["num_points"]
    device = opt["device"]
    tht = torch.linspace(0, 2 * np.pi, num_points, device=device)
    # sample ellipse parameters: aspect ratio, semi-major axis, centre, tilt
    r = torch.empty(ne, 1, device=device).uniform_(*opt["aspect_ratio_range"])
    a = torch.empty(ne, 1, device=device).uniform_(*opt["smj_range"])
    b = a / r
    xc = torch.empty(ne, 1, device=device).uniform_(*opt["xc_range"])
    yc = torch.empty(ne, 1, device=device).uniform_(*opt["yc_range"])
    theta = torch.empty(ne, 1, device=device).uniform_(*opt["theta_range"])
    x_ = a * torch.cos(tht)
    y_ = b * torch.sin(tht)
    x = xc + (x_ * torch.cos(theta) - y_ * torch.sin(theta))
    y = yc + (x_ * torch.sin(theta) + y_ * torch.cos(theta))
    ellipses = torch.stack((x, y), dim=2)  # (ne, num_points, 2)
    if opt.get("normalise", False):
        ellipses = normalise_ellipses_batch(ellipses)
    # random split point, shared across the batch
    min_idx, max_idx = [int(i * num_points) for i in opt["split_range"]]
    split_index = torch.randint(min_idx, max_idx, (1,), device=device).item()
    input_seq = ellipses[:, :split_index, :]  # (ne, split_index, 2)
    target_seq = ellipses[:, split_index:, :]  # (ne, num_points - split_index, 2)
    # Alternative: per-ellipse split points with padded ragged sequences.
    # split_indices = torch.randint(min_idx, max_idx, (ne,), device=device)
    # input_seq = [ell[:split_indices[i]] for (i, ell) in enumerate(ellipses)]
    # target_seq = [ell[split_indices[i]:] for (i, ell) in enumerate(ellipses)]
    # input_seq = torch.nn.utils.rnn.pad_sequence(input_seq, batch_first=True)
    # target_seq = torch.nn.utils.rnn.pad_sequence(target_seq, batch_first=True)
    # Here, input_seq & target_seq are of shape (batch_size, max_seq_len, 2).
    # From a list of B sequences of lengths s_i, a single tensor of shape
    # (B, max(s_i), 2) is created, where shorter sequences are zero-padded.
    # batch_first=True puts the batch dimension first, followed by the
    # maximum sequence length.
    return input_seq, target_seq
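

# Quick shape check (illustrative; the option values below are assumptions,
# not from the training configuration). The input and target splits always
# partition the full set of sampled points.
def _demo_generate_ellipses_batch():
    opt = {
        "batch_size": 4,
        "num_points": 32,
        "device": torch.device("cpu"),
        "aspect_ratio_range": (1.0, 5.0),
        "smj_range": (0.2, 10.5),
        "xc_range": (-10.0, 10.0),
        "yc_range": (-10.0, 10.0),
        "theta_range": (0.0, 2.0 * np.pi),
        "split_range": (0.35, 0.75),
        "normalise": True,
    }
    inp, tgt = generate_ellipses_batch(opt)
    assert inp.size(0) == tgt.size(0) == opt["batch_size"]
    assert inp.size(1) + tgt.size(1) == opt["num_points"]

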
def serve_tensor_as_np(tensor):
    """Move a tensor to the CPU and convert it to a numpy array."""
    if isinstance(tensor, torch.Tensor):
        return tensor.detach().cpu().numpy()
    elif isinstance(tensor, np.ndarray):
        return tensor
    else:
        raise TypeError("Input must be a torch.Tensor or np.ndarray.")
def plot_curves(input_curves, target_curves, pred_curves, save_path=None):
    """Plot the input, target and predicted curves on a grid of axes."""
    input_curves = serve_tensor_as_np(input_curves)
    target_curves = serve_tensor_as_np(target_curves)
    pred_curves = serve_tensor_as_np(pred_curves)
    assert (
        input_curves.shape[0] == target_curves.shape[0] == pred_curves.shape[0]
    ), "Input, target and predicted curves must have the same batch size."
    bs = input_curves.shape[0]
    nc = int(np.sqrt(bs))
    nr = int(np.ceil(bs / nc))  # ensure nr * nc >= bs for non-square batches
    _, axs = plt.subplots(nr, nc, figsize=(nc * 5, nr * 5))
    axs = np.asarray(axs).flatten()
    for i in range(bs):
        axs[i].plot(
            input_curves[i, :, 0], input_curves[i, :, 1], color="r", lw=3.0, ls="solid"
        )
        axs[i].plot(
            target_curves[i, :, 0],
            target_curves[i, :, 1],
            color="k",
            ls="dotted",
            lw=2.0,
        )
        axs[i].plot(
            pred_curves[i, :, 0], pred_curves[i, :, 1], color="b", ls="solid", lw=2.0
        )
        axs[i].set_title(f"Curve {i + 1}")
        axs[i].axis("equal")
    plt.tight_layout()
    if save_path is not None:
        plt.savefig(save_path)
    plt.close()
    return
def train(model, options):
    """Train the model with data generated on the fly."""
    num_steps = options["num_steps"]
    options["device"] = torch.device(
        options.get("device", "cuda" if torch.cuda.is_available() else "cpu")
    )
    model.to(options["device"])
    print("Given Arguments:")
    print("=========")
    for k, v in options.items():
        print(f"{k}: {v}")
    print("\n\nModel Parameters:")
    print("=========")
    print(f"Name: {model.__class__.__name__}")
    print(f"Cell Type: {model.cell_type}")
    print("=========")
    # fixed validation batch, generated once so progress plots are comparable
    val_inp_seq, val_target_seq = generate_ellipses_batch({**options, "batch_size": 16})
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=options["lr"])
    loop = tqdm(range(num_steps), desc="Training", unit="step")
    step_losses = []
    for step in loop:
        model.train()
        input_seq, target_seq = generate_ellipses_batch(options)
        # teacher forcing ratio: linear decay from 1.0 down to min_tf_ratio
        min_tf_ratio = 0.1
        if step < options["tf_decay_steps"]:
            tf_ratio = 1.0 - (step / options["tf_decay_steps"]) * (1.0 - min_tf_ratio)
        else:
            tf_ratio = min_tf_ratio
        # forward pass
        pred_seq = model(
            input_seq,
            target_seq=target_seq,
            teacher_forcing_ratio=tf_ratio,
        )
        loss = criterion(pred_seq, target_seq)
        # backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        step_losses.append(loss.item())
        loop.set_description(f"[(Step {step}/{num_steps}) | Loss: {loss.item():.4f}]")
        if step % options["log_interval"] == 0:
            model.eval()
            pred_seq_ = model.predict(val_inp_seq, val_target_seq.size(1))
            plot_curves(
                input_curves=val_inp_seq,
                target_curves=val_target_seq,
                pred_curves=pred_seq_,
                save_path=options["log_dir"].joinpath(f"ellipse_{step}.png"),
            )
    fig, axs = plt.subplots(figsize=(8, 8))
    axs.plot(step_losses)
    axs.set_title("Step Losses")
    axs.set_xlabel("Step")
    axs.set_ylabel("Loss")
    axs.set_yscale("log")
    axs.grid()
    plt.tight_layout()
    plt.savefig(options["log_dir"].joinpath("loss.png"))
    plt.close(fig)
    return
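

# The schedule above decays teacher forcing linearly from 1.0 to 0.1 over
# tf_decay_steps, then holds. The same logic is mirrored here, factored out
# only as an illustrative check (not called by train()):
def _demo_tf_schedule(step, tf_decay_steps, min_tf_ratio=0.1):
    if step < tf_decay_steps:
        return 1.0 - (step / tf_decay_steps) * (1.0 - min_tf_ratio)
    return min_tf_ratio
# e.g. _demo_tf_schedule(0, 5000) == 1.0, _demo_tf_schedule(2500, 5000) == 0.55
# and _demo_tf_schedule(5000, 5000) == 0.1 (assumed step counts).

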
class EllipseCompletionModelRNN(torch.nn.Module):
    def __init__(
        self,
        input_size=2,
        hidden_size=64,
        output_size=2,
        num_layers=2,
        dropout=0.0,
        cell_type="LSTM",
    ):
        """
        A sequence-to-sequence prediction model for ellipses.
        """
        super().__init__()
        self.cell_type = cell_type
        self.output_size = output_size
        # encoder
        self.encoder = getattr(torch.nn, cell_type)(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
        )
        # decoder
        self.decoder = getattr(torch.nn, cell_type)(
            input_size=output_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
        )
        # fully connected layer projecting hidden states to output coordinates
        self.fc = torch.nn.Linear(hidden_size, output_size)
    def forward(
        self,
        input_seq,
        *,
        target_seq=None,
        target_len=None,
        start_point="INP",  # "INP" or "ZERO"
        teacher_forcing_ratio=1.0,  # i.e. 100% teacher forcing by default
    ):
        batch_size = input_seq.size(0)
        device = input_seq.device
        # ENCODER
        _, hidden = self.encoder(input_seq)
        # LSTMs return a (h, c) tuple; GRU/RNN return a single tensor
        if isinstance(hidden, tuple):
            hidden = tuple(h.contiguous() for h in hidden)
        else:
            hidden = hidden.contiguous()
        # Initialize the decoder input
        if start_point == "INP":  # use the last point of the input sequence
            decoder_inp = input_seq[:, -1:, :]
        else:
            decoder_inp = torch.zeros(
                batch_size,
                1,
                self.output_size,
                device=device,
            )
        # Find the target length
        if target_len is None:
            if target_seq is None:
                raise ValueError(
                    "target_len & target_seq cannot be None simultaneously."
                )
            target_len = target_seq.size(1)
        outputs = []
        for t in range(target_len):
            out, hidden = self.decoder(decoder_inp, hidden)
            out = self.fc(out)
            # append the output to the list
            outputs.append(out)
            # prepare the next input for the decoder
            if target_seq is not None and self.training:
                # teacher forcing: feed the ground truth with probability
                # teacher_forcing_ratio, else the model's own prediction
                if torch.rand(1).item() < teacher_forcing_ratio:
                    decoder_inp = target_seq[:, t].unsqueeze(1)
                else:
                    decoder_inp = out
            else:
                decoder_inp = out  # auto-regressive for inference
        return torch.cat(outputs, dim=1)
    def predict(self, input_seq, num_points):
        """
        Predict the next num_points points of the ellipse.
        """
        self.eval()
        with torch.no_grad():
            predictions = self.forward(
                input_seq=input_seq,
                target_len=num_points,
                start_point="INP",
            )
        return predictions
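

# Minimal usage sketch (the sizes here are assumptions, not from the training
# configuration): feed the model the first half of a curve and ask it to
# auto-regressively generate the remaining points.
def _demo_predict():
    model = EllipseCompletionModelRNN(hidden_size=32, num_layers=2)
    partial = torch.zeros(4, 16, 2)  # (batch, observed points, xy) placeholder
    completed = model.predict(partial, num_points=16)
    assert completed.shape == (4, 16, 2)

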
def main():
    CSD = Path(__file__).resolve().parent
    options = {
        "num_steps": 10000,
        "batch_size": 256,
        "num_points": 32,
        "normalise": True,
        "aspect_ratio_range": (1.0, 5.0),
        "smj_range": (0.2, 10.5),
        "xc_range": (-10.0, 10.0),
        "yc_range": (-10.0, 10.0),
        "theta_range": (0, 2.0 * np.pi),
        "split_range": (0.35, 0.75),
        "lr": 1e-3,
        "log_interval": 200,
        "log_dir": CSD.joinpath("logs"),
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "model_kwargs": {
            "cell_type": "LSTM",
            "input_size": 2,
            "hidden_size": 32,
            "num_layers": 2,
            "dropout": 0.0,
            "output_size": 2,
        },
    }
    options["teacher_forcing_ratio"] = 1.0  # Start with full teacher forcing
    options["tf_decay_steps"] = options["num_steps"] // 2
    # That is, decay over half the number of steps
    options["log_dir"].mkdir(parents=True, exist_ok=True)
    for file in options["log_dir"].glob("*.png"):
        file.unlink()
    model = EllipseCompletionModelRNN(**options["model_kwargs"])
    train(model, options)
    return


if __name__ == "__main__":
    main()