
@gabrieldernbach
Created November 27, 2024 21:21
mine.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class MINE(nn.Module):
    """
    Stub to implement Mutual Information Neural Estimation.
    See https://arxiv.org/pdf/1801.04062
    Quote:
        We argue that the estimation of mutual information
        between high dimensional continuous random variables can be
        achieved by gradient descent over neural networks. We present a
        Mutual Information Neural Estimator (MINE) that is linearly
        scalable in dimensionality as well as in sample
        size, trainable through back-prop, and strongly consistent.
    """

    def __init__(self):
        super().__init__()
        # statistics network T(x, y) for scalar-valued x and y
        self.fc1 = nn.Linear(1, 100)
        self.fc2 = nn.Linear(1, 100)
        self.fc3 = nn.Linear(100, 1)

    def predict(self, x, y):
        # critic score T(x, y)
        a = F.relu(self.fc1(x) + self.fc2(y))
        return self.fc3(a)

    def forward(self, x, y):
        # Estimates the mutual information between x and y via the
        # Donsker-Varadhan representation, eq. (5) of the paper:
        #   I(X; Y) >= E_p(x,y)[T] - log E_p(x)p(y)[exp(T)]
        py = y[torch.randperm(len(y))]  # shuffle y to sample the product of marginals
        xy = self.predict(x, y)
        xpy = self.predict(x, py)
        mi = torch.mean(xy) - torch.log(torch.mean(torch.exp(xpy)))
        return -mi  # negated so that minimizing this loss maximizes the MI bound
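Since the gist is only a stub, here is a minimal training sketch showing how the estimator might be used. The toy data (`y = x + noise`, correlated Gaussians) and the optimizer settings are illustrative assumptions, not part of the original gist; the class is repeated so the snippet runs standalone.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MINE(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 100)
        self.fc2 = nn.Linear(1, 100)
        self.fc3 = nn.Linear(100, 1)

    def predict(self, x, y):
        a = F.relu(self.fc1(x) + self.fc2(y))
        return self.fc3(a)

    def forward(self, x, y):
        py = y[torch.randperm(len(y))]
        xy = self.predict(x, y)
        xpy = self.predict(x, py)
        mi = torch.mean(xy) - torch.log(torch.mean(torch.exp(xpy)))
        return -mi


# Hypothetical toy data: y = x + noise, so x and y share information.
torch.manual_seed(0)
n = 512
x = torch.randn(n, 1)
y = x + 0.5 * torch.randn(n, 1)

model = MINE()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# Maximize the DV lower bound by minimizing the negated estimate.
for step in range(200):
    loss = model(x, y)
    opt.zero_grad()
    loss.backward()
    opt.step()

mi_estimate = -model(x, y).item()  # estimated mutual information in nats
```

Note that the shuffled batch in `forward` is only an approximate sample from the product of marginals, and the log-mean-exp term makes the naive minibatch gradient biased; the paper addresses this with an exponential moving average of the denominator, which this stub omits.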