Contrastive Learning of Random Vectors
| """ | |
| Contrastive Learning Script | |
| Aims: | |
| - Train model to move embeddings closer that are of the same class | |
| - Visualize embeddings in 2D space before and after contrastive learning | |
| - Calculate and track intraclass variance metrics | |
| """ | |
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

NUM_SAMPLES = 300
EMBEDDING_SIZE = 64
NUM_CLASSES = 3
OUTPUT_DIR = "tsne_plots"

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

def calculate_intraclass_metrics(embeddings, labels):
    """
    Calculate intraclass distance and variance metrics.

    Args:
        embeddings (torch.Tensor or np.ndarray): shape (n_samples, embedding_dim)
        labels (torch.Tensor or np.ndarray): shape (n_samples,)

    Returns:
        dict: metrics including mean intraclass distance and variance per class
    """
    if isinstance(embeddings, torch.Tensor):
        embeddings = embeddings.cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.cpu().numpy()

    metrics = {
        'mean_intraclass_distance': 0.0,
        'mean_intraclass_variance': 0.0,
        'per_class_distance': {},
        'per_class_variance': {}
    }

    unique_labels = np.unique(labels)
    all_distances = []
    all_variances = []

    for label in unique_labels:
        # Get all embeddings for this class
        class_mask = labels == label
        class_embeddings = embeddings[class_mask]
        if len(class_embeddings) < 2:
            continue

        # Calculate pairwise distances within the class
        distances = []
        for i in range(len(class_embeddings)):
            for j in range(i + 1, len(class_embeddings)):
                dist = np.linalg.norm(class_embeddings[i] - class_embeddings[j])
                distances.append(dist)

        # Calculate variance (spread from the class centroid)
        centroid = np.mean(class_embeddings, axis=0)
        variances = [np.linalg.norm(emb - centroid) for emb in class_embeddings]

        mean_dist = np.mean(distances) if distances else 0.0
        mean_var = np.mean(variances) if variances else 0.0
        metrics['per_class_distance'][int(label)] = mean_dist
        metrics['per_class_variance'][int(label)] = mean_var
        all_distances.extend(distances)
        all_variances.extend(variances)

    metrics['mean_intraclass_distance'] = np.mean(all_distances) if all_distances else 0.0
    metrics['mean_intraclass_variance'] = np.mean(all_variances) if all_variances else 0.0
    return metrics
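
# Quick sanity check (hypothetical values, not from the original gist): two perfectly
# tight clusters should yield a mean intraclass distance of 0.0, e.g.
#   emb = np.vstack([np.zeros((5, 2)), np.ones((5, 2))])
#   lab = np.array([0] * 5 + [1] * 5)
#   calculate_intraclass_metrics(emb, lab)['mean_intraclass_distance']  # -> 0.0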

def plot_embeddings(embeddings, labels, title, save_path=None, metrics=None, axis_limits=None):
    """
    Plot embeddings using TSNE and optionally save to file.

    Args:
        embeddings (np.ndarray): shape (n_samples, embedding_dim)
        labels (torch.Tensor or np.ndarray): shape (n_samples,)
        title (str): plot title
        save_path (str, optional): path to save the plot
        metrics (dict, optional): intraclass metrics to display
        axis_limits (tuple, optional): (xmin, xmax, ymin, ymax) for consistent axes
    """
    if isinstance(labels, torch.Tensor):
        labels_np = labels.cpu().numpy()
    else:
        labels_np = labels

    tsne = TSNE(n_components=2, random_state=42)
    reduced = tsne.fit_transform(embeddings)

    fig, ax = plt.subplots(figsize=(10, 8), facecolor='white')
    ax.set_facecolor('white')

    unique_labels = np.unique(labels_np)
    for label in unique_labels:
        idxs = labels_np == label
        ax.scatter(reduced[idxs, 0], reduced[idxs, 1], label=f"Class {label}", alpha=0.6, s=50)
    ax.legend()

    # Set fixed axis limits if provided
    if axis_limits is not None:
        ax.set_xlim(axis_limits[0], axis_limits[1])
        ax.set_ylim(axis_limits[2], axis_limits[3])

    # Add mean intraclass distance to the title if metrics are provided
    if metrics:
        title_with_metric = f"{title}\nMean Intraclass Distance: {metrics['mean_intraclass_distance']:.2f}"
        ax.set_title(title_with_metric, fontsize=14, fontweight='bold')
    else:
        ax.set_title(title, fontsize=14, fontweight='bold')

    # Add metrics as a text box if provided
    if metrics:
        metrics_text = f"Mean Intraclass Variance: {metrics['mean_intraclass_variance']:.2f}\n"
        metrics_text += "\nPer-Class Variance:\n"
        for cls, var in sorted(metrics['per_class_variance'].items()):
            metrics_text += f"  Class {cls}: {var:.2f}\n"
        ax.text(0.02, 0.98, metrics_text, transform=ax.transAxes,
                fontsize=9, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='white', edgecolor='black', alpha=0.8))

    ax.set_xlabel("TSNE Dimension 1")
    ax.set_ylabel("TSNE Dimension 2")
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight', facecolor='white', edgecolor='white')
        print(f"Saved plot to {save_path}")
    plt.close()

class ContrastiveEncoder(nn.Module):
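    """
    Small MLP encoder that maps input vectors to L2-normalized embeddings,
    so all embeddings lie on the unit hypersphere.
    """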
    def __init__(self, input_size, embedding_size):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Linear(128, embedding_size)
        )

    def forward(self, x):
        x = self.model(x)
        x = F.normalize(x, p=2, dim=1)
        return x

class TripletLoss(nn.Module):
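    """
    Standard triplet margin loss: for each (anchor, positive, negative) triplet,
    pull the anchor toward the positive and push it away from the negative until
    the gap exceeds `margin`:

        L = mean( max(0, d(anchor, positive) - d(anchor, negative) + margin) )
    """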
    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        pos_dist = F.pairwise_distance(anchor, positive)
        neg_dist = F.pairwise_distance(anchor, negative)
        losses = F.relu(pos_dist - neg_dist + self.margin)
        return losses.mean()

class InfoNCELoss(nn.Module):
    """
    InfoNCE loss from the CPC paper (van den Oord et al., 2018).

    Uses cross-entropy with in-batch negatives: each anchor[i] should match
    positive[i] (at index i), and all other samples (positives + negatives)
    serve as negatives.
    """
    def __init__(self, temperature=0.1):
        super(InfoNCELoss, self).__init__()
        self.temperature = temperature

    def forward(self, anchor, positive, negative):
        """
        Args:
            anchor: (batch_size, embedding_dim)
            positive: (batch_size, embedding_dim) - positive samples
            negative: (batch_size, embedding_dim) - negative samples

        For each anchor[i], positive[i] is the correct match (label=i).
        All samples in [positive, negative] serve as the candidate set.
        """
        batch_size = anchor.shape[0]

        # Normalize embeddings for cosine similarity
        anchor = F.normalize(anchor, p=2, dim=-1)
        positive = F.normalize(positive, p=2, dim=-1)
        negative = F.normalize(negative, p=2, dim=-1)
        # Concatenate positives and negatives to form the candidate set
        candidates = torch.cat([positive, negative], dim=0)  # (2*batch_size, embedding_dim)

        # Compute similarity matrix: (batch_size, 2*batch_size)
        similarity_matrix = torch.matmul(anchor, candidates.T)
        similarity_matrix /= self.temperature

        # Labels: anchor[i] should match positive[i], which sits at index i in candidates
        labels = torch.arange(batch_size, device=similarity_matrix.device)

        # Cross entropy: -log(exp(sim[i,i]) / sum_j(exp(sim[i,j])))
        return F.cross_entropy(similarity_matrix, labels)
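
# NOTE: main() below also lists an "NTXent" configuration, but the original script
# defines no matching loss class, so that run would fail. The class below is a
# minimal NT-Xent sketch (normalized temperature-scaled cross entropy, as used in
# SimCLR, Chen et al., 2020), added as an assumption so the third configuration has
# an implementation; it is not part of the original gist.
class NTXentLoss(nn.Module):
    """
    NT-Xent loss. Treats anchor[i] and positive[i] as two views of the same
    sample and uses every other embedding in the concatenated batch as a
    negative. The explicit `negative` argument is accepted only for interface
    compatibility with the other losses and is ignored.
    """
    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, anchor, positive, negative=None):
        batch_size = anchor.shape[0]
        # Stack both views into one (2*batch_size, embedding_dim) matrix of unit vectors
        z = torch.cat([
            F.normalize(anchor, p=2, dim=-1),
            F.normalize(positive, p=2, dim=-1)
        ], dim=0)

        # Cosine similarity between every pair, scaled by temperature
        similarity = torch.matmul(z, z.T) / self.temperature

        # Mask out self-similarity so a sample is never its own candidate
        self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=similarity.device)
        similarity = similarity.masked_fill(self_mask, float('-inf'))

        # The positive for row i is its paired view: i <-> i + batch_size
        targets = torch.cat([
            torch.arange(batch_size, 2 * batch_size),
            torch.arange(0, batch_size)
        ]).to(similarity.device)
        return F.cross_entropy(similarity, targets)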

def get_triplet_batch(data, labels, batch_size):
    """
    Sample a batch of triplets (anchor, positive, negative) for contrastive training.

    For each triplet:
    - Anchor: a random sample from data
    - Positive: another random sample with the same label as the anchor
    - Negative: a random sample with a different label from the anchor

    Args:
        data (torch.Tensor): shape (n_samples, embedding_dim)
        labels (torch.Tensor): shape (n_samples,), class label for each sample
        batch_size (int): number of triplets to sample

    Returns:
        anchors (torch.Tensor): shape (batch_size, embedding_dim)
        positives (torch.Tensor): shape (batch_size, embedding_dim)
        negatives (torch.Tensor): shape (batch_size, embedding_dim)
    """
    n = len(data)
    # Index with NumPy on the CPU; np.where cannot operate on CUDA tensors directly
    labels_np = labels.detach().cpu().numpy() if isinstance(labels, torch.Tensor) else np.asarray(labels)

    anchors, positives, negatives = [], [], []
    for _ in range(batch_size):
        anchor_idx = np.random.randint(0, n)
        anchor_label = labels_np[anchor_idx]
        positive_idx = np.random.choice(np.where(labels_np == anchor_label)[0])
        negative_idx = np.random.choice(np.where(labels_np != anchor_label)[0])
        anchors.append(data[anchor_idx])
        positives.append(data[positive_idx])
        negatives.append(data[negative_idx])
    return torch.stack(anchors), torch.stack(positives), torch.stack(negatives)

def train_model(data, labels, dataloader, num_epochs=20, batch_size=32, config=None, loss_name="InfoNCE"):
    """
    Train the ContrastiveEncoder using the specified loss function.
    Saves a TSNE plot of the raw data before training and of the learned
    embeddings after training, along with intraclass metrics.

    Args:
        data (torch.Tensor): full dataset
        labels (torch.Tensor): full labels
        dataloader (DataLoader): dataloader for batching
        num_epochs (int): number of training epochs
        batch_size (int): size of each training batch
        config (dict): configuration dictionary containing hyperparameters
        loss_name (str): name of the loss function for organizing output

    Returns:
        model (ContrastiveEncoder): trained model
        metrics_history (list[dict]): per-epoch loss plus initial and final metrics
        loss_output_dir (str): directory the plots were saved to
    """
    model = ContrastiveEncoder(EMBEDDING_SIZE, EMBEDDING_SIZE).to(DEVICE)

    # Select loss function based on name
    if loss_name.lower() == "infonce":
        criterion = InfoNCELoss(temperature=config.get('temperature', 0.1))
    elif loss_name.lower() == "triplet":
        criterion = TripletLoss(margin=config.get('margin', 1.0))
    elif loss_name.lower() == "ntxent":
        # Uses the NT-Xent sketch defined above
        criterion = NTXentLoss(temperature=config.get('temperature', 0.5))
    else:
        raise ValueError(f"Unsupported loss function name: {loss_name}")

    optimizer = optim.Adam(model.parameters(), lr=config.get('learning_rate', 1e-3))

    # Create a subfolder for this loss function
    loss_output_dir = os.path.join(OUTPUT_DIR, loss_name)
    os.makedirs(loss_output_dir, exist_ok=True)
    metrics_history = []

    # Save initial embeddings (raw data, not passed through the model)
    initial_embeddings = data.cpu().numpy()
    initial_metrics = calculate_intraclass_metrics(initial_embeddings, labels)
    plot_embeddings(
        initial_embeddings,
        labels,
        "Initial Embeddings (Epoch 0)",
        save_path=os.path.join(loss_output_dir, "epoch_000.png"),
        metrics=initial_metrics,
    )
    metrics_history.append({
        'epoch': 0,
        'loss': 0.0,
        **initial_metrics
    })
    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for batch_data, batch_labels in dataloader:
            batch_data, batch_labels = batch_data.to(DEVICE), batch_labels.to(DEVICE)
            anchors, positives, negatives = get_triplet_batch(batch_data, batch_labels, batch_size)

            anchor_emb = model(anchors)
            positive_emb = model(positives)
            negative_emb = model(negatives)
            loss = criterion(anchor_emb, positive_emb, negative_emb)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(dataloader)

        # Track loss only; metrics are computed before and after training
        metrics_history.append({
            'epoch': epoch + 1,
            'loss': avg_loss,
        })

        # Print progress
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.2f}")

    # Final evaluation after training is complete
    model.eval()
    with torch.no_grad():
        learned_embeddings = model(data.to(DEVICE)).cpu()
    final_metrics = calculate_intraclass_metrics(learned_embeddings, labels)

    # Save the final TSNE plot
    save_path = os.path.join(loss_output_dir, f"epoch_{num_epochs:03d}.png")
    plot_embeddings(
        learned_embeddings.numpy(),
        labels,
        f"Learned Embeddings (Epoch {num_epochs})",
        save_path=save_path,
        metrics=final_metrics,
    )

    # Add final metrics to the last entry in metrics_history
    metrics_history[-1].update(final_metrics)
    return model, metrics_history, loss_output_dir

def main():
    """Main execution function."""
    # Generate synthetic data (same data for all loss functions for a fair comparison)
    print("Generating synthetic data...")
    data = torch.randn(NUM_SAMPLES, EMBEDDING_SIZE)
    labels = torch.randint(0, NUM_CLASSES, (NUM_SAMPLES,))

    # Define loss functions to compare
    loss_configs = [
        {'name': 'Triplet', 'config': {'margin': 1.0, 'learning_rate': 1e-3}},
        {'name': 'InfoNCE', 'config': {'temperature': 0.1, 'learning_rate': 1e-3}},
        {'name': 'NTXent', 'config': {'temperature': 0.5, 'learning_rate': 1e-3}},
    ]

    all_results = []

    # Train with each loss function
    for loss_setup in loss_configs:
        loss_name = loss_setup['name']
        config = loss_setup['config']

        print("\n" + "=" * 70)
        print(f"TRAINING WITH {loss_name.upper()} LOSS")
        print("=" * 70)

        # Create a dataloader (fresh for each training run)
        dataset = TensorDataset(data, labels)
        dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

        print(f"\nTraining model for 100 epochs using {loss_name} loss...")
        print(f"Plots will be saved to the '{OUTPUT_DIR}/{loss_name}/' directory\n")

        trained_model, metrics_history, loss_output_dir = train_model(
            data=data,
            labels=labels,
            dataloader=dataloader,
            num_epochs=100,
            batch_size=32,
            config=config,
            loss_name=loss_name
        )

        # Print a summary of the final metrics
        print("\n" + "-" * 60)
        print(f"{loss_name} Training Complete!")
        print("-" * 60)
        print(f"Total epochs: {len(metrics_history) - 1}")
        print(f"Final loss: {metrics_history[-1]['loss']:.2f}")
        print(f"Final mean intraclass distance: {metrics_history[-1]['mean_intraclass_distance']:.2f}")
        print(f"Final mean intraclass variance: {metrics_history[-1]['mean_intraclass_variance']:.2f}")
        print(f"\nInitial mean intraclass distance: {metrics_history[0]['mean_intraclass_distance']:.2f}")
        print(f"Initial mean intraclass variance: {metrics_history[0]['mean_intraclass_variance']:.2f}")
        print(f"\nAll plots saved to the '{loss_output_dir}/' directory")

        all_results.append({
            'loss_name': loss_name,
            'metrics_history': metrics_history,
            'output_dir': loss_output_dir
        })

    # Print comparison summary
    print("\n" + "=" * 70)
    print("FINAL COMPARISON SUMMARY")
    print("=" * 70)
    print(f"{'Loss Function':<15} {'Initial Dist':<15} {'Final Dist':<15} {'Improvement':<15}")
    print("-" * 70)
    for result in all_results:
        initial_dist = result['metrics_history'][0]['mean_intraclass_distance']
        final_dist = result['metrics_history'][-1]['mean_intraclass_distance']
        improvement = initial_dist - final_dist
        print(f"{result['loss_name']:<15} {initial_dist:<15.2f} {final_dist:<15.2f} {improvement:<15.2f}")
    print("=" * 70)
    print(f"\nAll results saved to '{OUTPUT_DIR}/' with subfolders for each loss function")
    print("=" * 70)


if __name__ == "__main__":
    main()