Contrastive Learning of Random Vectors
"""
Contrastive Learning Script
Aims:
- Train a model to pull embeddings of the same class closer together
- Visualize embeddings in 2D space before and after contrastive learning
- Calculate and track intraclass variance metrics
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
# Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")
NUM_SAMPLES = 300
EMBEDDING_SIZE = 64
NUM_CLASSES = 3
OUTPUT_DIR = "tsne_plots"
# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)
def calculate_intraclass_metrics(embeddings, labels):
"""
Calculate intraclass distance and variance metrics.
Args:
embeddings (torch.Tensor or np.ndarray): shape (n_samples, embedding_dim)
labels (torch.Tensor or np.ndarray): shape (n_samples,)
Returns:
dict: metrics including mean intraclass distance and variance per class
"""
if isinstance(embeddings, torch.Tensor):
embeddings = embeddings.cpu().numpy()
if isinstance(labels, torch.Tensor):
labels = labels.cpu().numpy()
metrics = {
'mean_intraclass_distance': 0.0,
'mean_intraclass_variance': 0.0,
'per_class_distance': {},
'per_class_variance': {}
}
unique_labels = np.unique(labels)
all_distances = []
all_variances = []
for label in unique_labels:
# Get all embeddings for this class
class_mask = labels == label
class_embeddings = embeddings[class_mask]
if len(class_embeddings) < 2:
continue
# Calculate pairwise distances within class
distances = []
for i in range(len(class_embeddings)):
for j in range(i + 1, len(class_embeddings)):
dist = np.linalg.norm(class_embeddings[i] - class_embeddings[j])
distances.append(dist)
# Calculate variance (spread from centroid)
centroid = np.mean(class_embeddings, axis=0)
variances = [np.linalg.norm(emb - centroid) for emb in class_embeddings]
mean_dist = np.mean(distances) if distances else 0.0
mean_var = np.mean(variances) if variances else 0.0
metrics['per_class_distance'][int(label)] = mean_dist
metrics['per_class_variance'][int(label)] = mean_var
all_distances.extend(distances)
all_variances.extend(variances)
metrics['mean_intraclass_distance'] = np.mean(all_distances) if all_distances else 0.0
metrics['mean_intraclass_variance'] = np.mean(all_variances) if all_variances else 0.0
return metrics
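
# Illustrative usage sketch for the metric above (a minimal example; the cluster
# centers, noise scale, and sample counts below are assumptions chosen for the
# demo, and this helper is never called by main()):
def _demo_intraclass_metrics():
    rng = np.random.default_rng(0)
    cluster_a = rng.normal(loc=0.0, scale=0.01, size=(10, 4))   # tight cluster for class 0
    cluster_b = rng.normal(loc=5.0, scale=0.01, size=(10, 4))   # tight cluster for class 1
    embeddings = np.concatenate([cluster_a, cluster_b], axis=0)
    labels = np.array([0] * 10 + [1] * 10)
    metrics = calculate_intraclass_metrics(embeddings, labels)
    # Both values should be close to zero for such tight clusters
    print(metrics['mean_intraclass_distance'], metrics['mean_intraclass_variance'])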
def plot_embeddings(embeddings, labels, title, save_path=None, metrics=None, axis_limits=None):
"""
Plot embeddings using TSNE and optionally save to file.
Args:
embeddings (np.ndarray): shape (n_samples, embedding_dim)
labels (torch.Tensor or np.ndarray): shape (n_samples,)
title (str): plot title
save_path (str, optional): path to save the plot
metrics (dict, optional): intraclass metrics to display
axis_limits (tuple, optional): (xmin, xmax, ymin, ymax) for consistent axes
"""
if isinstance(labels, torch.Tensor):
labels_np = labels.cpu().numpy()
else:
labels_np = labels
tsne = TSNE(n_components=2, random_state=42)
reduced = tsne.fit_transform(embeddings)
fig, ax = plt.subplots(figsize=(10, 8), facecolor='white')
ax.set_facecolor('white')
unique_labels = np.unique(labels_np)
for label in unique_labels:
idxs = labels_np == label
ax.scatter(reduced[idxs, 0], reduced[idxs, 1], label=f"Class {label}", alpha=0.6, s=50)
ax.legend()
# Set fixed axis limits if provided
if axis_limits is not None:
ax.set_xlim(axis_limits[0], axis_limits[1])
ax.set_ylim(axis_limits[2], axis_limits[3])
# Add mean intraclass distance to title if metrics provided
if metrics:
title_with_metric = f"{title}\nMean Intraclass Distance: {metrics['mean_intraclass_distance']:.2f}"
ax.set_title(title_with_metric, fontsize=14, fontweight='bold')
else:
ax.set_title(title, fontsize=14, fontweight='bold')
# Add metrics as text if provided
if metrics:
metrics_text = f"Mean Intraclass Variance: {metrics['mean_intraclass_variance']:.2f}\n"
metrics_text += "\nPer-Class Variance:\n"
for cls, var in sorted(metrics['per_class_variance'].items()):
metrics_text += f" Class {cls}: {var:.2f}\n"
ax.text(0.02, 0.98, metrics_text, transform=ax.transAxes,
fontsize=9, verticalalignment='top',
bbox=dict(boxstyle='round', facecolor='white', edgecolor='black', alpha=0.8))
ax.set_xlabel("TSNE Dimension 1")
ax.set_ylabel("TSNE Dimension 2")
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=150, bbox_inches='tight', facecolor='white', edgecolor='white')
print(f"Saved plot to {save_path}")
plt.close()
class ContrastiveEncoder(nn.Module):
def __init__(self, input_size, embedding_size):
super().__init__()
self.model = nn.Sequential(
nn.Linear(input_size, 128),
nn.ReLU(),
nn.BatchNorm1d(128),
nn.Linear(128, embedding_size)
)
def forward(self, x):
x = self.model(x)
x = F.normalize(x, p=2, dim=1)
return x
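
# Because forward() L2-normalizes its output, embeddings live on the unit sphere,
# where Euclidean and cosine geometry coincide: ||a - b||^2 = 2 - 2 * cos(a, b).
# A minimal check of that property (illustrative only, not called by main(); the
# batch size of 8 is an arbitrary assumption):
def _check_unit_norm_encoder():
    enc = ContrastiveEncoder(input_size=EMBEDDING_SIZE, embedding_size=EMBEDDING_SIZE).eval()
    with torch.no_grad():
        out = enc(torch.randn(8, EMBEDDING_SIZE))
    print(out.norm(dim=1))  # every row norm should be ~1.0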
class TripletLoss(nn.Module):
def __init__(self, margin=1.0):
super().__init__()
self.margin = margin
def forward(self, anchor, positive, negative):
pos_dist = F.pairwise_distance(anchor, positive)
neg_dist = F.pairwise_distance(anchor, negative)
losses = F.relu(pos_dist - neg_dist + self.margin)
return losses.mean()
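
# Worked example of the hinge above (numbers are illustrative assumptions):
# with pos_dist = 0.2, neg_dist = 1.5, margin = 1.0 -> relu(0.2 - 1.5 + 1.0) = 0,
# i.e. the triplet is already satisfied; with neg_dist = 0.5 the loss is
# relu(0.2 - 0.5 + 1.0) = 0.7. A tiny runnable check (not called by main()):
def _triplet_loss_example():
    a = torch.tensor([[1.0, 0.0]])    # anchor
    p = torch.tensor([[0.99, 0.01]])  # positive, nearly identical to the anchor
    n = torch.tensor([[-1.0, 0.0]])   # negative, far from the anchor
    print(TripletLoss(margin=1.0)(a, p, n))  # expected: tensor(0.)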
class InfoNCELoss(nn.Module):
"""
InfoNCE Loss from CPC paper (van den Oord et al., 2018).
Uses cross-entropy with in-batch negatives.
Each anchor[i] should match with positive[i] (at index i).
All other samples (positives + negatives) serve as negatives.
"""
def __init__(self, temperature=0.1):
super(InfoNCELoss, self).__init__()
self.temperature = temperature
def forward(self, anchor, positive, negative):
"""
Args:
anchor: (batch_size, embedding_dim)
positive: (batch_size, embedding_dim) - positive samples
negative: (batch_size, embedding_dim) - negative samples
For each anchor[i], positive[i] is the correct match (label=i).
All samples in [positive, negative] serve as the candidate set.
"""
batch_size = anchor.shape[0]
# Normalize embeddings for cosine similarity
anchor = F.normalize(anchor, p=2, dim=-1)
positive = F.normalize(positive, p=2, dim=-1)
negative = F.normalize(negative, p=2, dim=-1)
# Concatenate positives and negatives to form candidate set
        # Shape: (2 * batch_size, embedding_dim)
candidates = torch.cat([positive, negative], dim=0) # (2*batch_size, embedding_dim)
# Compute similarity matrix: (batch_size, 2*batch_size)
similarity_matrix = torch.matmul(anchor, candidates.T)
similarity_matrix /= self.temperature
# Labels: anchor[i] should match positive[i] which is at index i in candidates
labels = torch.arange(batch_size, device=similarity_matrix.device)
# Cross entropy: -log(exp(sim[i,i]) / sum_j(exp(sim[i,j])))
return F.cross_entropy(similarity_matrix, labels)
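
# Minimal sanity check for the loss above (toy shapes are assumptions, not called
# by main()): when each anchor equals its own positive and is orthogonal to every
# other candidate, the correct logit dominates and the loss should be ~0 at low
# temperature.
def _infonce_sanity_check():
    anchor = torch.eye(8)[:4]    # 4 anchors: e0..e3
    positive = torch.eye(8)[:4]  # positive[i] == anchor[i]
    negative = torch.eye(8)[4:]  # e4..e7, orthogonal to every anchor
    print(InfoNCELoss(temperature=0.05)(anchor, positive, negative))  # ~0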
def get_triplet_batch(data, labels, batch_size):
"""
Sample a batch of triplets (anchor, positive, negative) for triplet loss training.
For each triplet:
- Anchor: a random sample from data
- Positive: another random sample with the same label as the anchor
- Negative: a random sample with a different label from the anchor
Args:
data (torch.Tensor): shape (n_samples, embedding_dim)
labels (torch.Tensor): shape (n_samples,), class label for each sample
batch_size (int): number of triplets to sample
Returns:
anchors (torch.Tensor): shape (batch_size, embedding_dim)
positives (torch.Tensor): shape (batch_size, embedding_dim)
negatives (torch.Tensor): shape (batch_size, embedding_dim)
"""
    n = len(data)
    # Work with labels as a CPU NumPy array for index selection so this also
    # works when batches live on the GPU; the data tensors stay on their device.
    labels_np = labels.cpu().numpy() if isinstance(labels, torch.Tensor) else np.asarray(labels)
    anchors, positives, negatives = [], [], []
    for _ in range(batch_size):
        anchor_idx = np.random.randint(0, n)
        anchor_label = labels_np[anchor_idx]
        same_class = np.where(labels_np == anchor_label)[0]
        # Prefer a positive other than the anchor itself (fall back if the class has a single sample)
        candidates_pos = same_class[same_class != anchor_idx]
        positive_idx = np.random.choice(candidates_pos if len(candidates_pos) else same_class)
        negative_idx = np.random.choice(np.where(labels_np != anchor_label)[0])
        anchors.append(data[anchor_idx])
        positives.append(data[positive_idx])
        negatives.append(data[negative_idx])
    return torch.stack(anchors), torch.stack(positives), torch.stack(negatives)
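
# Usage sketch for the sampler above (toy sizes are assumptions, not called by
# main()): every positive shares its anchor's label and every negative does not,
# and each returned tensor has shape (batch_size, feature_dim).
def _demo_triplet_batch():
    toy_data = torch.randn(6, EMBEDDING_SIZE)
    toy_labels = torch.tensor([0, 0, 0, 1, 1, 1])
    a, p, n = get_triplet_batch(toy_data, toy_labels, batch_size=4)
    print(a.shape, p.shape, n.shape)  # each: torch.Size([4, 64])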
def train_model(data, labels, dataloader, num_epochs=20, batch_size=32, config=None, loss_name="InfoNCE"):
"""
Train the ContrastiveEncoder using specified loss function.
    Saves a TSNE plot and intraclass metrics before training (epoch 0) and after the final epoch.
Args:
data (torch.Tensor): full dataset
labels (torch.Tensor): full labels
dataloader (DataLoader): dataloader for batching
num_epochs (int): number of training epochs
batch_size (int): size of each training batch
config (dict): configuration dictionary containing hyperparameters
loss_name (str): name of the loss function for organizing output
Returns:
model (ContrastiveEncoder): trained model
"""
    # Input dim equals EMBEDDING_SIZE because the synthetic data is generated directly in that space
    model = ContrastiveEncoder(EMBEDDING_SIZE, EMBEDDING_SIZE).to(DEVICE)
# Select loss function based on name
if loss_name.lower() == "infonce":
criterion = InfoNCELoss(temperature=config.get('temperature', 0.1))
elif loss_name.lower() == "triplet":
criterion = TripletLoss(margin=config.get('margin', 1.0))
else:
raise ValueError(f"Unsupported loss function name: {loss_name}")
optimizer = optim.Adam(model.parameters(), lr=config.get('learning_rate', 1e-3))
# Create subfolder for this loss function
loss_output_dir = os.path.join(OUTPUT_DIR, loss_name)
os.makedirs(loss_output_dir, exist_ok=True)
metrics_history = []
# Save initial embeddings (raw data, not passed through model)
initial_embeddings = data.cpu().numpy()
initial_metrics = calculate_intraclass_metrics(initial_embeddings, labels)
plot_embeddings(
initial_embeddings,
labels,
"Initial Embeddings (Epoch 0)",
save_path=os.path.join(loss_output_dir, "epoch_000.png"),
metrics=initial_metrics,
)
metrics_history.append({
'epoch': 0,
'loss': 0.0,
**initial_metrics
})
model.train()
for epoch in range(num_epochs):
epoch_loss = 0.0
for batch_data, batch_labels in dataloader:
batch_data, batch_labels = batch_data.to(DEVICE), batch_labels.to(DEVICE)
anchors, positives, negatives = get_triplet_batch(batch_data, batch_labels, batch_size)
anchor_emb = model(anchors)
positive_emb = model(positives)
negative_emb = model(negatives)
loss = criterion(anchor_emb, positive_emb, negative_emb)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
epoch_loss += loss.item()
avg_loss = epoch_loss / len(dataloader)
# Track loss only
metrics_history.append({
'epoch': epoch + 1,
'loss': avg_loss,
})
# Print progress
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.2f}")
# Final evaluation after training is complete
model.eval()
with torch.no_grad():
learned_embeddings = model(data.to(DEVICE)).cpu()
final_metrics = calculate_intraclass_metrics(learned_embeddings, labels)
# Save final TSNE plot
save_path = os.path.join(loss_output_dir, f"epoch_{num_epochs:03d}.png")
plot_embeddings(
learned_embeddings.numpy(),
labels,
f"Learned Embeddings (Epoch {num_epochs})",
save_path=save_path,
metrics=final_metrics,
)
# Add final metrics to the last entry in metrics_history
metrics_history[-1].update(final_metrics)
return model, metrics_history, loss_output_dir
def main():
"""Main execution function."""
# Generate synthetic data (same data for all loss functions for fair comparison)
print("Generating synthetic data...")
data = torch.randn(NUM_SAMPLES, EMBEDDING_SIZE)
labels = torch.randint(0, NUM_CLASSES, (NUM_SAMPLES,))
# Define loss functions to compare
loss_configs = [
{'name': 'Triplet', 'config': {'margin': 1.0, 'learning_rate': 1e-3}},
{'name': 'InfoNCE', 'config': {'temperature': 0.1, 'learning_rate': 1e-3}},
        # NT-Xent is not implemented above; passing it to train_model would raise a ValueError
]
all_results = []
# Train with each loss function
for loss_setup in loss_configs:
loss_name = loss_setup['name']
config = loss_setup['config']
print("\n" + "="*70)
print(f"TRAINING WITH {loss_name.upper()} LOSS")
print("="*70)
# Create dataloader (fresh for each training run)
dataset = TensorDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
print(f"\nTraining model for 100 epochs using {loss_name} loss...")
print(f"Plots will be saved to '{OUTPUT_DIR}/{loss_name}/' directory\n")
trained_model, metrics_history, loss_output_dir = train_model(
data=data,
labels=labels,
dataloader=dataloader,
num_epochs=100,
batch_size=32,
config=config,
loss_name=loss_name
)
# Save final metrics summary
print("\n" + "-"*60)
print(f"{loss_name} Training Complete!")
print("-"*60)
print(f"Total epochs: {len(metrics_history) - 1}")
print(f"Final loss: {metrics_history[-1]['loss']:.2f}")
print(f"Final mean intraclass distance: {metrics_history[-1]['mean_intraclass_distance']:.2f}")
print(f"Final mean intraclass variance: {metrics_history[-1]['mean_intraclass_variance']:.2f}")
print(f"\nInitial mean intraclass distance: {metrics_history[0]['mean_intraclass_distance']:.2f}")
print(f"Initial mean intraclass variance: {metrics_history[0]['mean_intraclass_variance']:.2f}")
print(f"\nAll plots saved to '{loss_output_dir}/' directory")
all_results.append({
'loss_name': loss_name,
'metrics_history': metrics_history,
'output_dir': loss_output_dir
})
# Print comparison summary
print("\n" + "="*70)
print("FINAL COMPARISON SUMMARY")
print("="*70)
print(f"{'Loss Function':<15} {'Initial Dist':<15} {'Final Dist':<15} {'Improvement':<15}")
print("-"*70)
for result in all_results:
initial_dist = result['metrics_history'][0]['mean_intraclass_distance']
final_dist = result['metrics_history'][-1]['mean_intraclass_distance']
improvement = initial_dist - final_dist
print(f"{result['loss_name']:<15} {initial_dist:<15.2f} {final_dist:<15.2f} {improvement:<15.2f}")
print("="*70)
print(f"\nAll results saved to '{OUTPUT_DIR}/' with subfolders for each loss function")
print("="*70)
if __name__ == "__main__":
main()