gabrieldernbach/b3674abc77e560c4b25c1bd35586226d
Created November 27, 2024 21:21
mine.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MINE(nn.Module):
    """
    Stub to implement Mutual Information Neural Estimation.
    See https://arxiv.org/pdf/1801.04062
    Quote:
    We argue that the estimation of mutual information
    between high dimensional continuous random variables can be
    achieved by gradient descent over neural networks. We present a
    Mutual Information Neural Estimator (MINE) that is linearly
    scalable in dimensionality as well as in sample
    size, trainable through back-prop, and strongly consistent.
    """

    def __init__(self, x_dim=1, y_dim=1, hidden=100):
        super().__init__()
        # Statistics network T(x, y) with one hidden layer; summing two
        # input projections equals one linear layer on concat(x, y).
        self.fc1 = nn.Linear(x_dim, hidden)
        self.fc2 = nn.Linear(y_dim, hidden)
        self.fc3 = nn.Linear(hidden, 1)

    def predict(self, x, y):
        # Score T(x, y) for each sample in the batch, shape (n, 1).
        a = F.relu(self.fc1(x) + self.fc2(y))
        return self.fc3(a)

    def forward(self, x, y):
        # Estimates the mutual information between x and y
        # via eq (5), the Donsker-Varadhan representation.
        # Shuffling y breaks the pairing, so (x, py) ~ p(x)p(y).
        py = y[torch.randperm(len(y))]
        xy = self.predict(x, y)
        xpy = self.predict(x, py)
        # log(mean(exp(xpy))) computed stably as logsumexp - log(n).
        log_mean_exp = torch.logsumexp(xpy, dim=0).squeeze() - math.log(len(xpy))
        mi = torch.mean(xy) - log_mean_exp
        # Negated so that minimizing this loss maximizes the MI lower bound.
        return -mi
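For reference, the quantity that `forward` optimizes can be written out as follows. This is a sketch in standard notation: the symbols T_θ (the network's `predict`), n (batch size) and the permutation σ (the `torch.randperm` shuffle) are introduced here for illustration. Mutual information is a KL divergence, and the Donsker-Varadhan representation turns it into a supremum; restricting T to a neural network family gives a lower bound, and the batch version replaces expectations with sample means.

```latex
I(X;Y) = D_{\mathrm{KL}}\big(\mathbb{P}_{XY} \,\|\, \mathbb{P}_X \otimes \mathbb{P}_Y\big)
       \;\ge\; \sup_{\theta}\; \mathbb{E}_{\mathbb{P}_{XY}}\big[T_\theta\big]
       - \log \mathbb{E}_{\mathbb{P}_X \otimes \mathbb{P}_Y}\big[e^{T_\theta}\big]

\widehat{I}_n(\theta) = \frac{1}{n}\sum_{i=1}^{n} T_\theta(x_i, y_i)
       - \log \frac{1}{n}\sum_{i=1}^{n} e^{T_\theta(x_i,\, y_{\sigma(i)})},
\qquad
\log \frac{1}{n}\sum_{i=1}^{n} e^{t_i} = \operatorname{logsumexp}(t) - \log n
```

The last identity is why the log-mean-exp term can be evaluated with `torch.logsumexp` without overflowing `exp`.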
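A minimal usage sketch, assuming PyTorch is installed: it trains a copy of the estimator on correlated 1-D Gaussians, where the true mutual information is known in closed form, I(X;Y) = -½ log(1 - ρ²). The helpers `sample_correlated_gaussians` and `train_mine`, and the hyperparameters (ρ = 0.8, 2000 Adam steps, batch 512), are hypothetical choices for illustration, not part of the original gist.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MINE(nn.Module):
    # Compact copy of the MINE estimator (stable log-mean-exp variant),
    # repeated here so this snippet runs on its own.
    def __init__(self, x_dim=1, y_dim=1, hidden=100):
        super().__init__()
        self.fc1 = nn.Linear(x_dim, hidden)
        self.fc2 = nn.Linear(y_dim, hidden)
        self.fc3 = nn.Linear(hidden, 1)

    def predict(self, x, y):
        return self.fc3(F.relu(self.fc1(x) + self.fc2(y)))

    def forward(self, x, y):
        py = y[torch.randperm(len(y))]  # shuffled y approximates p(x)p(y)
        joint = self.predict(x, y).mean()
        log_mean_exp = torch.logsumexp(self.predict(x, py), dim=0).squeeze() - math.log(len(y))
        return -(joint - log_mean_exp)  # negative DV bound, used as the loss


def sample_correlated_gaussians(n, rho):
    # Hypothetical helper: standard bivariate normal (x, y) with correlation rho.
    x = torch.randn(n, 1)
    noise = torch.randn(n, 1)
    y = rho * x + math.sqrt(1.0 - rho ** 2) * noise
    return x, y


def train_mine(rho=0.8, steps=2000, batch=512, lr=1e-3, seed=0):
    # Hypothetical training loop: minimize the negative bound with Adam.
    torch.manual_seed(seed)
    model = MINE()
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(steps):
        x, y = sample_correlated_gaussians(batch, rho)
        loss = model(x, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
    # Evaluate the bound on a large fresh batch without tracking gradients.
    with torch.no_grad():
        x, y = sample_correlated_gaussians(20000, rho)
        estimate = -model(x, y).item()
    true_mi = -0.5 * math.log(1.0 - rho ** 2)  # analytic MI in nats
    return estimate, true_mi


if __name__ == "__main__":
    est, true = train_mine()
    print(f"MINE estimate: {est:.3f} nats, analytic MI: {true:.3f} nats")
```

Because the Donsker-Varadhan objective is a lower bound, the estimate tends to settle at or somewhat below the analytic value. The MINE paper additionally corrects the bias of the mini-batch gradient with a moving average of the denominator; this sketch omits that for brevity.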