Rewrite of "Multi-Label Supervised Contrastive Learning" for more casual use
# From the paper "Multi-Label Supervised Contrastive Learning"
# https://doi.org/10.1609/aaai.v38i15.29619
# https://github.com/williamzhangsjtu/MulSupCon
from typing import Callable, Literal, overload

import torch
import torch.nn.functional as F
from torch import Tensor, nn

OUTPUT_FUNC_T = Callable[[Tensor, Tensor, Tensor], tuple[Tensor, ...]]
class MulSupConLoss(nn.Module):
    def __init__(
        self,
        method: Literal["all", "any", "msc", "mscw"] = "msc",
        temperature: float = 0.1,
    ):
        """
        Multi-label Supervised Contrastive Loss from https://doi.org/10.1609/aaai.v38i15.29619

        Parameters:
            method: how to build the positive mask and scores
                - all: positives are samples whose label sets exactly match the anchor's
                - any: positives are samples sharing at least one label with the anchor
                - msc: treat each of the anchor's labels separately (Eq. 4 in the paper)
                - mscw: same as msc, but weight the loss by 1/|y| (Eq. 6 in the paper)
            temperature: temperature for scaling the logits
        """
        super().__init__()
        self.temperature = temperature
        self.method = method
        match method:
            case "all" | "any":
                self.inner_forward = self._unweighted_forward
                self.output_func = _get_output_func(pattern=method, with_weight=False)
            case "msc" | "mscw":
                self.inner_forward = self._weighted_forward
                self.output_func = _get_output_func(
                    pattern="msc", with_weight=(method == "mscw")
                )
            case _:
                raise NotImplementedError(f"Unknown method: {method}")
    def forward(
        self,
        batch_logits: Tensor,
        batch_labels: Tensor,
        ref_logits: Tensor,
        ref_labels: Tensor,
    ) -> Tensor:
        """
        Compute the MulSupCon loss.

        Parameters:
            batch_logits: B x D tensor, embeddings of the anchors
            batch_labels: B x C tensor, multi-hot labels of the anchors
            ref_logits: Q x D tensor, embeddings of the reference set (e.g. the queue in MoCo)
            ref_labels: Q x C tensor, multi-hot labels of the reference set

        Returns:
            loss: the computed scalar loss
        """
        batch_logits = F.normalize(batch_logits, p=2, dim=-1)  # B x D
        ref_logits = F.normalize(ref_logits, p=2, dim=-1)  # Q x D
        scores = batch_logits @ ref_logits.T  # B x Q cosine similarities
        output = self.output_func(batch_labels.float(), ref_labels.float(), scores)
        return self.inner_forward(*output)
    def _unweighted_forward(self, score: Tensor, mask: Tensor) -> Tensor:
        # Mean negative log-likelihood over each anchor's positives, averaged
        # over anchors. log_softmax is the numerically stable equivalent of
        # log(softmax(...)). Note: an anchor with no positives in the
        # reference set yields num_pos == 0 and a NaN loss.
        num_pos = mask.sum(1)
        loss = -(F.log_softmax(score / self.temperature, dim=1) * mask).sum(1) / num_pos
        return loss.mean()

    def _weighted_forward(self, score: Tensor, mask: Tensor, weight: Tensor) -> Tensor:
        # Same per-row loss as above, but combined with the per-sub-anchor
        # weights from generate_output_MulSupCon instead of a plain mean.
        num_pos = mask.sum(1)
        loss = -(F.log_softmax(score / self.temperature, dim=1) * mask).sum(1) / num_pos
        return (loss * weight).sum()
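# A small worked example of the "msc" expansion (illustrative values, not from
# the gist): with batch_labels = [[1, 0, 1], [0, 1, 0]], anchor 0 is expanded
# into two sub-anchors (one per positive label) and anchor 1 into one, so
# P = 3. Each sub-anchor reuses its anchor's row of `scores` and treats as
# positive any reference sample carrying that single label. Uniform 1/P
# weights give the plain "msc" average (Eq. 4); 1/(|y| * B) weights give the
# "mscw" variant (Eq. 6).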
@overload
def _get_output_func(
    pattern: Literal["all", "any"], with_weight: Literal[False] = False
) -> Callable[[Tensor, Tensor, Tensor], tuple[Tensor, Tensor]]: ...
@overload
def _get_output_func(
    pattern: Literal["msc"], with_weight: bool = False
) -> Callable[[Tensor, Tensor, Tensor], tuple[Tensor, Tensor, Tensor]]: ...
def _get_output_func(
    pattern: Literal["all", "any", "msc"] = "msc", with_weight: bool = False
) -> OUTPUT_FUNC_T:
    """
    Parameters:
        pattern: how to build the positive mask and scores
            - all: positives are samples whose label sets exactly match the anchor's
            - any: positives are samples sharing at least one label with the anchor
            - msc: treat each of the anchor's labels separately
        with_weight: only valid for the "msc" pattern; whether to weight each
            sub-anchor's loss by 1/|y|

    Returns:
        a function that takes (batch_labels, ref_labels, scores) and returns
        (scores, mask) or, for "msc", (scores, mask, weight)
    """
    if pattern != "msc" and with_weight:
        raise ValueError("with_weight can only be True when pattern is 'msc'")
    def generate_output_MulSupCon(
        batch_labels: Tensor, ref_labels: Tensor, scores: Tensor
    ):
        """
        MulSupCon: expand each anchor into one sub-anchor per positive label.

        Parameters:
            batch_labels: B x C tensor, labels of the anchors
            ref_labels: Q x C tensor, labels of samples from the queue
            scores: B x Q tensor, cosine similarity between the anchors and samples from the queue
        """
        B = len(batch_labels)
        # (anchor index, label index) pairs over all positive labels;
        # their total count P is the number of sub-anchors
        indices = torch.where(batch_labels == 1)
        scores = scores[indices[0]]  # P x Q, rows repeat the anchor's scores
        labels = batch_labels.new_zeros(scores.shape[0], ref_labels.shape[1])  # P x C
        rows = torch.arange(labels.shape[0], device=labels.device)
        labels[rows, indices[1]] = 1  # each sub-anchor keeps exactly one label
        masks = (labels @ ref_labels.T).to(torch.bool)  # P x Q positive mask
        n_score_per_sample = batch_labels.sum(dim=1).to(torch.int16).tolist()
        if with_weight:
            # mscw (Eq. 6): weight each sub-anchor by 1 / (|y| * B)
            weights_per_sample = [
                1 / (n * B) for n in n_score_per_sample for _ in range(n)
            ]
        else:
            # msc (Eq. 4): uniform average over all P sub-anchors
            weights_per_sample = [
                1 / len(scores) for n in n_score_per_sample for _ in range(n)
            ]
        weights_per_sample = torch.tensor(
            weights_per_sample, device=scores.device, dtype=torch.float32
        )
        return scores, masks.to(torch.long), weights_per_sample
    def generate_output_all(batch_labels: Tensor, ref_labels: Tensor, scores: Tensor):
        """
        positives: samples whose label sets exactly match the anchor's
        """
        mul_matrix = (batch_labels @ ref_labels.T).to(torch.int16)  # B x Q intersection sizes
        # exact match iff the intersection size equals both label-set sizes
        mask1 = (
            torch.sum(batch_labels, dim=1).unsqueeze(1).to(torch.int16) == mul_matrix
        )
        mask2 = (
            torch.sum(ref_labels, dim=1).unsqueeze(1).to(torch.int16) == mul_matrix.T
        )
        mask = mask1 & mask2.T
        return scores, mask.to(torch.long)
    def generate_output_any(batch_labels: Tensor, ref_labels: Tensor, scores: Tensor):
        """
        positives: samples sharing at least one label with the anchor
        """
        mul_matrix = batch_labels @ ref_labels.T  # B x Q intersection sizes
        return scores, (mul_matrix > 0).to(torch.long)
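    # Illustrative contrast between the two masks (made-up labels): a reference
    # sample labeled [1, 1, 0] against an anchor labeled [1, 1, 0] is positive
    # under both "all" (intersection 2 equals both set sizes) and "any"; a
    # sample labeled [1, 0, 1] is positive under "any" (one shared label) but
    # not under "all".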
    if pattern == "all":
        return generate_output_all
    elif pattern == "any":
        return generate_output_any
    elif pattern == "msc":
        return generate_output_MulSupCon
    else:
        raise NotImplementedError(f"Unknown pattern: {pattern}")
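A minimal usage sketch, not part of the original gist: the sizes and random inputs below are made-up placeholders, and in a MoCo-style setup ref_logits / ref_labels would come from the momentum encoder's queue rather than torch.randn.

if __name__ == "__main__":
    torch.manual_seed(0)
    B, Q, D, C = 8, 128, 64, 5  # hypothetical batch / queue / embedding / label sizes
    criterion = MulSupConLoss(method="mscw", temperature=0.1)
    batch_logits = torch.randn(B, D)  # anchor embeddings (stand-in for an encoder)
    ref_logits = torch.randn(Q, D)  # reference-set embeddings
    # random multi-hot labels; force at least one positive label per sample
    batch_labels = (torch.rand(B, C) < 0.4).long()
    batch_labels[batch_labels.sum(1) == 0, 0] = 1
    ref_labels = (torch.rand(Q, C) < 0.4).long()
    ref_labels[ref_labels.sum(1) == 0, 0] = 1
    loss = criterion(batch_logits, batch_labels, ref_logits, ref_labels)
    print(loss)  # scalar tensor; call loss.backward() in a real training loop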