import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm

"""
PLAN OF ACTION
- Generate synthetic training data on the fly, on the GPU, during training
- Measure performance on the on-the-fly generated data and the original data
"""


def normalise_ellipses_batch(xy: torch.Tensor) -> torch.Tensor:
    """Rescale a batch of curves so coordinates lie roughly in [-1, 1].

    A single scale (the largest per-axis range over the whole batch) is
    shared across axes, so the aspect ratio of each ellipse is preserved.
    """
    axes = [xy[..., i] for i in range(xy.shape[-1])]
    # element-wise max over the two scalar ranges (x-range, y-range)
    max_range = torch.max(*(ax.max() - ax.min() for ax in axes))
    normalised = [((ax - ax.min()) / max_range) * 2.0 - 1.0 for ax in axes]
    return torch.stack(normalised, dim=-1)
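
# Illustrative check (not part of the training flow; shapes are arbitrary):
#   batch = torch.rand(8, 32, 2) * 10.0          # (ne, num_points, 2)
#   normed = normalise_ellipses_batch(batch)     # same shape
#   # the widest axis now spans exactly [-1, 1]; the other axis starts at -1
#   # and ends at <= 1, so proportions within the batch are preserved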


def generate_ellipses_batch(opt: dict):
    """Generate a batch of random ellipses on the target device and split
    each curve into an (input_seq, target_seq) pair at a random index."""
    ne = opt["batch_size"]
    num_points = opt["num_points"]
    device = opt["device"]
    tht = torch.linspace(0, 2 * np.pi, num_points, device=device)
    # semi-major axis `a`; semi-minor axis `b = a / aspect_ratio`
    r = torch.empty(ne, 1, device=device).uniform_(*opt["aspect_ratio_range"])
    a = torch.empty(ne, 1, device=device).uniform_(*opt["smj_range"])
    b = a / r
    # random centre and orientation
    xc = torch.empty(ne, 1, device=device).uniform_(*opt["xc_range"])
    yc = torch.empty(ne, 1, device=device).uniform_(*opt["yc_range"])
    theta = torch.empty(ne, 1, device=device).uniform_(*opt["theta_range"])
    # axis-aligned ellipse, then rotate by theta and translate to (xc, yc)
    x_ = a * torch.cos(tht)
    y_ = b * torch.sin(tht)
    x = xc + (x_ * torch.cos(theta) - y_ * torch.sin(theta))
    y = yc + (x_ * torch.sin(theta) + y_ * torch.cos(theta))
    ellipses = torch.stack((x, y), dim=2)  # (ne, num_points, 2)
    if opt.get("normalise", False):
        ellipses = normalise_ellipses_batch(ellipses)
    # random split point, shared by all ellipses in the batch
    min_idx, max_idx = [int(i * num_points) for i in opt["split_range"]]
    split_index = torch.randint(min_idx, max_idx, (1,), device=device).item()
    input_seq = ellipses[:, :split_index, :]  # (ne, split_index, 2)
    target_seq = ellipses[:, split_index:, :]  # (ne, num_points - split_index, 2)
    # Alternative: a different split point per ellipse, padded to equal length
    # split_indices = torch.randint(min_idx, max_idx, (ne,), device=device)
    # input_seq = [ell[:split_indices[i]] for (i, ell) in enumerate(ellipses)]
    # target_seq = [ell[split_indices[i]:] for (i, ell) in enumerate(ellipses)]
    # input_seq = torch.nn.utils.rnn.pad_sequence(input_seq, batch_first=True)
    # target_seq = torch.nn.utils.rnn.pad_sequence(target_seq, batch_first=True)
    # Here, input_seq & target_seq are of shape (batch_size, max_seq_len, 2):
    # from a list of B sequences of lengths s_i, a single tensor of shape
    # (B, max(s_i), 2) is created, where shorter sequences are zero-padded.
    # batch_first=True puts the batch dimension first, sequence length second.
    return input_seq, target_seq
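
# Illustrative shapes (values assumed, matching the defaults in main()):
# with num_points=32 and split_range=(0.35, 0.75), split_index lands in
# [11, 23], so input_seq is (batch_size, split_index, 2) and target_seq is
# (batch_size, 32 - split_index, 2).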


def serve_tensor_as_np(tensor):
    """Move a tensor to the CPU and return it as a NumPy array.

    NumPy arrays are passed through unchanged.
    """
    if isinstance(tensor, torch.Tensor):
        # detach() guards against tensors that still require grad
        return tensor.detach().cpu().numpy()
    elif isinstance(tensor, np.ndarray):
        return tensor
    else:
        raise ValueError("Input must be a torch.Tensor or np.ndarray.")
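
# e.g. serve_tensor_as_np(torch.zeros(2, 3, device="cuda")) returns a NumPy
# array of shape (2, 3); a plain np.ndarray is returned unchanged.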


def plot_curves(input_curves, target_curves, pred_curves, save_path=None):
    """Plot the input (red), target (dotted black) and predicted (blue) curves."""
    input_curves = serve_tensor_as_np(input_curves)
    target_curves = serve_tensor_as_np(target_curves)
    pred_curves = serve_tensor_as_np(pred_curves)
    assert (
        input_curves.shape[0] == target_curves.shape[0] == pred_curves.shape[0]
    ), "Input, target and predicted curves must have the same batch size."
    bs = input_curves.shape[0]
    nc = int(np.sqrt(bs))
    nr = int(np.ceil(bs / nc))  # round up so nr * nc >= bs for any batch size
    _, axs = plt.subplots(nr, nc, figsize=(nc * 5, nr * 5), squeeze=False)
    axs = axs.flatten()
    for i in range(bs):
        axs[i].plot(
            input_curves[i, :, 0], input_curves[i, :, 1], color="r", lw=3.0, ls="solid"
        )
        axs[i].plot(
            target_curves[i, :, 0],
            target_curves[i, :, 1],
            color="k",
            ls="dotted",
            lw=2.0,
        )
        axs[i].plot(
            pred_curves[i, :, 0], pred_curves[i, :, 1], color="b", ls="solid", lw=2.0
        )
        axs[i].set_title(f"Curve {i + 1}")
        axs[i].axis("equal")
    plt.tight_layout()
    if save_path is not None:
        plt.savefig(save_path)
    plt.close()
    return


def train(model, options):
    """Train the model with data generated on the fly."""
    num_steps = options["num_steps"]
    options["device"] = torch.device(
        options.get("device", "cuda" if torch.cuda.is_available() else "cpu")
    )
    model.to(options["device"])
    print("Given Arguments:")
    print("=========")
    for k, v in options.items():
        print(f"{k}: {v}")
    print("\n\nModel Parameters:")
    print("=========")
    print(f"Name: {model.__class__.__name__}")
    print(f"Cell Type: {model.cell_type}")
    print("=========")
    # a fixed validation batch, so the periodic plots are comparable across steps
    val_inp_seq, val_target_seq = generate_ellipses_batch({**options, "batch_size": 16})
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=options["lr"])
    loop = tqdm(range(num_steps), desc="Training", unit="step")
    step_losses = []
    for step in loop:
        model.train()
        input_seq, target_seq = generate_ellipses_batch(options)
        # linearly decay the teacher-forcing ratio from 1.0 to min_tf_ratio
        # over the first tf_decay_steps steps, then hold it constant
        min_tf_ratio = 0.1
        if step < options["tf_decay_steps"]:
            tf_ratio = 1.0 - (step / options["tf_decay_steps"]) * (1.0 - min_tf_ratio)
        else:
            tf_ratio = min_tf_ratio
        # forward pass
        pred_seq = model(
            input_seq,
            target_seq=target_seq,
            teacher_forcing_ratio=tf_ratio,
        )
        loss = criterion(pred_seq, target_seq)
        # backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        step_losses.append(loss.item())
        loop.set_description(f"Step {step}/{num_steps} | Loss: {loss.item():.4f}")
        if step % options["log_interval"] == 0:
            model.eval()
            pred_seq_ = model.predict(val_inp_seq, val_target_seq.size(1))
            plot_curves(
                input_curves=val_inp_seq,
                target_curves=val_target_seq,
                pred_curves=pred_seq_,
                save_path=options["log_dir"].joinpath(f"ellipse_{step}.png"),
            )
    fig, axs = plt.subplots(figsize=(8, 8))
    axs.plot(step_losses)
    axs.set_title("Step Losses")
    axs.set_xlabel("Step")
    axs.set_ylabel("Loss")
    axs.set_yscale("log")
    axs.grid()
    plt.tight_layout()
    plt.savefig(options["log_dir"].joinpath("loss.png"))
    plt.close(fig)
    return
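
# Schedule sanity check (using the values set in main()): with num_steps=10000
# and tf_decay_steps=5000, tf_ratio falls linearly from 1.0 at step 0 to 0.1
# at step 5000 and is held at 0.1 for the remaining steps.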


class EllipseCompletionModelRNN(torch.nn.Module):
    def __init__(
        self,
        input_size=2,
        hidden_size=64,
        output_size=2,
        num_layers=2,
        dropout=0.0,
        cell_type="LSTM",
    ):
        """
        A sequence-to-sequence model that completes a partially observed
        ellipse. `cell_type` selects the recurrent cell: "RNN", "GRU" or "LSTM".
        """
        super().__init__()
        self.cell_type = cell_type
        self.output_size = output_size
        # encoder: consumes the observed points and summarises them in its
        # final hidden state
        self.encoder = getattr(torch.nn, cell_type)(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
        )
        # decoder: rolls out the remaining points one step at a time
        self.decoder = getattr(torch.nn, cell_type)(
            input_size=output_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
        )
        # fully connected layer: maps hidden states to (x, y) points
        self.fc = torch.nn.Linear(hidden_size, output_size)

    def forward(
        self,
        input_seq,
        *,
        target_seq=None,
        target_len=None,
        start_point="INP",  # "INP" or "ZERO"
        teacher_forcing_ratio=1.0,  # i.e. 100% teacher forcing by default
    ):
        batch_size = input_seq.size(0)
        device = input_seq.device
        # ENCODER
        _, hidden = self.encoder(input_seq)
        # Make the hidden state contiguous; LSTMs return an (h, c) tuple,
        # while RNN/GRU return a single tensor
        if isinstance(hidden, tuple):
            hidden = tuple(h.contiguous() for h in hidden)
        else:
            hidden = hidden.contiguous()
        # Initialize the decoder input
        if start_point == "INP":  # use the last point of the input sequence
            decoder_inp = input_seq[:, -1:, :]
        else:
            decoder_inp = torch.zeros(
                batch_size,
                1,
                self.output_size,
                device=device,
            )
        # Find the target length
        if target_len is None:
            if target_seq is None:
                raise ValueError(
                    "target_len & target_seq cannot be None simultaneously."
                )
            target_len = target_seq.size(1)
        outputs = []
        for t in range(target_len):
            out, hidden = self.decoder(decoder_inp, hidden)
            out = self.fc(out)
            # append the output to the list
            outputs.append(out)
            # Prepare the next decoder input: during training, teacher forcing
            # feeds the ground-truth point with probability
            # teacher_forcing_ratio; in all other cases (including eval with a
            # target_seq) the model's own prediction is fed back
            # auto-regressively.
            if (
                self.training
                and target_seq is not None
                and torch.rand(1).item() < teacher_forcing_ratio
            ):
                decoder_inp = target_seq[:, t].unsqueeze(1)
            else:
                decoder_inp = out
        return torch.cat(outputs, dim=1)

    def predict(self, input_seq, num_points):
        """Predict the next `num_points` points of the ellipse
        auto-regressively, with gradients disabled."""
        self.eval()
        with torch.no_grad():
            predictions = self.forward(
                input_seq=input_seq,
                target_len=num_points,
                start_point="INP",
            )
        return predictions
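
# Standalone usage sketch (shapes are assumptions; train() builds its real
# inputs with generate_ellipses_batch):
#   model = EllipseCompletionModelRNN(cell_type="GRU")
#   seen = torch.randn(4, 12, 2)         # first 12 points of 4 curves
#   completed = model.predict(seen, 20)  # -> (4, 20, 2) predicted points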


def main():
    CSD = Path(__file__).resolve().parent  # current script directory
    options = {
        "num_steps": 10000,
        "batch_size": 256,
        "num_points": 32,
        "normalise": True,
        "aspect_ratio_range": (1.0, 5.0),
        "smj_range": (0.2, 10.5),  # semi-major axis length
        "xc_range": (-10.0, 10.0),
        "yc_range": (-10.0, 10.0),
        "theta_range": (0, 2.0 * np.pi),
        "split_range": (0.35, 0.75),
        "lr": 1e-3,
        "log_interval": 200,
        "log_dir": CSD.joinpath("logs"),
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "model_kwargs": {
            "cell_type": "LSTM",
            "input_size": 2,
            "hidden_size": 32,
            "num_layers": 2,
            "dropout": 0.0,
            "output_size": 2,
        },
    }
    # Start with full teacher forcing and decay it over half the steps;
    # train() computes the per-step ratio from tf_decay_steps.
    options["teacher_forcing_ratio"] = 1.0
    options["tf_decay_steps"] = options["num_steps"] // 2
    options["log_dir"].mkdir(parents=True, exist_ok=True)
    # clear plots left over from previous runs
    for file in options["log_dir"].glob("*.png"):
        file.unlink()
    model = EllipseCompletionModelRNN(**options["model_kwargs"])
    train(model, options)
    return


if __name__ == "__main__":
    main()