@Stfort52
Created October 15, 2025 08:13
Rewrite of "Multi-Label Supervised Contrastive Learning" for more casual use
# From the paper "Multi-Label Supervised Contrastive Learning"
# https://doi.org/10.1609/aaai.v38i15.29619
# https://github.com/williamzhangsjtu/MulSupCon
from typing import Callable, Literal, overload
import torch
import torch.nn.functional as F
from torch import Tensor, nn
OUTPUT_FUNC_T = Callable[[Tensor, Tensor, Tensor], tuple[Tensor, ...]]
class MulSupConLoss(nn.Module):
def __init__(
self,
method: Literal["all", "any", "msc", "mscw"] = "msc",
temperature: float = 0.1,
):
"""
Multi-label Supervised Contrastive Loss from https://doi.org/10.1609/aaai.v38i15.29619
Parameters:
method: how to generate mask and score
- all: labels exactly matched with the anchor
- any: labels with at least one common label with the anchor
- msc: treat each of anchor's label separately (Eq. 4 in the paper)
- mscw: same as msc but weight the loss by 1/|y| (Eq. 6 in the paper)
temperature: temperature for scaling the logits
"""
        super().__init__()
        self.temperature = temperature
        self.method = method
        match method:
            case "all" | "any":
                self.inner_forward = self._unweighted_forward
                self.output_func = _get_output_func(pattern=method, with_weight=False)
            case "msc" | "mscw":
                self.inner_forward = self._weighted_forward
                self.output_func = _get_output_func(
                    pattern="msc", with_weight=(method == "mscw")
                )
            case _:
                raise NotImplementedError(f"Unknown method: {method}")

    def forward(
        self,
        batch_logits: Tensor,
        batch_labels: Tensor,
        ref_logits: Tensor,
        ref_labels: Tensor,
    ) -> Tensor:
        """
        Forward pass to compute the MulSupCon loss.

        Parameters:
            batch_logits: B x D tensor, embeddings of the anchors
            batch_labels: B x C tensor, multi-hot labels of the anchors
            ref_logits: Q x D tensor, embeddings of the reference set (e.g. the queue in MoCo)
            ref_labels: Q x C tensor, multi-hot labels of the reference set
        Returns:
            loss: the computed loss
        """
        batch_logits = F.normalize(batch_logits, p=2, dim=-1)  # B x D
        ref_logits = F.normalize(ref_logits, p=2, dim=-1)  # Q x D
        scores = batch_logits @ ref_logits.T  # cosine similarities, B x Q
        output = self.output_func(batch_labels.float(), ref_labels.float(), scores)
        return self.inner_forward(*output)

    def _unweighted_forward(self, score: Tensor, mask: Tensor) -> Tensor:
        # Average the log-softmax over each anchor's positives, then over anchors.
        # Note: an anchor with no positives gives num_pos == 0 and a NaN loss.
        num_pos = mask.sum(1)
        loss = (
            -(F.log_softmax(score / self.temperature, dim=1) * mask).sum(1) / num_pos
        )
        return loss.mean()

    def _weighted_forward(self, score: Tensor, mask: Tensor, weight: Tensor) -> Tensor:
        # Same per-anchor term, combined with the weights produced by the output func.
        num_pos = mask.sum(1)
        loss = (
            -(F.log_softmax(score / self.temperature, dim=1) * mask).sum(1) / num_pos
        )
        return (loss * weight).sum()


@overload
def _get_output_func(
    pattern: Literal["all", "any"], with_weight: Literal[False] = False
) -> Callable[[Tensor, Tensor, Tensor], tuple[Tensor, Tensor]]: ...


@overload
def _get_output_func(
    pattern: Literal["msc"], with_weight: bool = False
) -> Callable[[Tensor, Tensor, Tensor], tuple[Tensor, Tensor, Tensor]]: ...


def _get_output_func(
    pattern: Literal["all", "any", "msc"] = "msc", with_weight: bool = False
) -> OUTPUT_FUNC_T:
"""
Paremeters:
pattern: how to generate mask and score
- all: labels exactly matched with the anchor
- any: labels with at least one common label with the anchor
- msc: treat each of anchor's label separately
with_weight: argument for sep pattern, whether to use 1/|y| to weight the loss
Returns:
a function that takes in (batch_labels, ref_labels, scores) and returns (scores, mask, weight (if msc))
"""
    if pattern != "msc" and with_weight:
        raise ValueError("with_weight can only be True when pattern is 'msc'")

    def generate_output_MulSupCon(
        batch_labels: Tensor, ref_labels: Tensor, scores: Tensor
    ):
        """
        MulSupCon: expand each anchor into one sub-anchor per active label.

        Parameters:
            batch_labels: B x C tensor, labels of the anchors
            ref_labels: Q x C tensor, labels of samples from the queue
            scores: B x Q tensor, cosine similarity between anchors and queue samples
        """
        B = len(batch_labels)
        # One (sample, label) pair per active label; P = total number of pairs.
        indices = torch.where(batch_labels == 1)
        scores = scores[indices[0]]  # P x Q
        # Build one-hot labels for the expanded sub-anchors.
        labels = batch_labels.new_zeros(scores.shape[0], ref_labels.shape[1])  # P x C
        rows = torch.arange(labels.shape[0], device=labels.device)
        labels[rows, indices[1]] = 1
        masks = (labels @ ref_labels.T).to(torch.bool)  # P x Q
        n_score_per_sample = batch_labels.sum(dim=1).to(torch.int16).tolist()
        if with_weight:
            # mscw (Eq. 6): each sub-anchor weighted by 1 / (|y| * B).
            weights_per_sample = [
                1 / (n * B) for n in n_score_per_sample for _ in range(n)
            ]
        else:
            # msc (Eq. 4): uniform weight 1 / P over all sub-anchors.
            weights_per_sample = [
                1 / len(scores) for n in n_score_per_sample for _ in range(n)
            ]
        weights_per_sample = torch.tensor(
            weights_per_sample, device=scores.device, dtype=torch.float32
        )
        return scores, masks.to(torch.long), weights_per_sample

    def generate_output_all(batch_labels: Tensor, ref_labels: Tensor, scores: Tensor):
        """
        Positives: samples whose labels exactly match the anchor's.
        """
        # Pairwise label-intersection sizes, B x Q.
        mul_matrix = (batch_labels @ ref_labels.T).to(torch.int16)
        # Exact match iff |y_anchor| == |intersection| == |y_ref|.
        mask1 = (
            torch.sum(batch_labels, dim=1).unsqueeze(1).to(torch.int16) == mul_matrix
        )
        mask2 = (
            torch.sum(ref_labels, dim=1).unsqueeze(1).to(torch.int16) == mul_matrix.T
        )
        mask = mask1 & mask2.T
        return scores, mask.to(torch.long)

    def generate_output_any(batch_labels: Tensor, ref_labels: Tensor, scores: Tensor):
        """
        Positives: samples sharing at least one label with the anchor.
        """
        mul_matrix = batch_labels @ ref_labels.T
        return scores, (mul_matrix > 0).to(torch.long)

    if pattern == "all":
        return generate_output_all
    elif pattern == "any":
        return generate_output_any
    elif pattern == "msc":
        return generate_output_MulSupCon
    else:
        raise NotImplementedError(f"Unknown pattern: {pattern}")
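

# ----------------------------------------------------------------------
# Minimal usage sketch (not part of the original gist). The shapes, seed,
# and label density below are illustrative assumptions only. The batch is
# included in the reference set so every anchor has at least one positive
# under every method; otherwise "all" can produce NaN on random labels.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    B, Q, D, C = 4, 16, 32, 5  # batch size, queue size, embedding dim, classes

    batch_logits = torch.randn(B, D)
    batch_labels = (torch.rand(B, C) < 0.4).long()
    batch_labels[:, 0] |= (batch_labels.sum(1) == 0).long()  # no empty label sets

    # Reference set (e.g. a MoCo queue), prefixed with the batch itself.
    ref_logits = torch.cat([batch_logits, torch.randn(Q - B, D)])
    ref_labels = torch.cat([batch_labels, (torch.rand(Q - B, C) < 0.4).long()])
    ref_labels[:, 0] |= (ref_labels.sum(1) == 0).long()

    for method in ("all", "any", "msc", "mscw"):
        criterion = MulSupConLoss(method=method, temperature=0.1)
        loss = criterion(batch_logits, batch_labels, ref_logits, ref_labels)
        print(f"{method:>4}: {loss.item():.4f}")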