Last active
September 27, 2022 00:23
-
-
Save tomonori-masui/144c2057a64ec892a0a88066607eb3d2 to your computer and use it in GitHub Desktop.
Link predictor training and evaluation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from sklearn.metrics import roc_auc_score | |
| from torch_geometric.utils import negative_sampling | |
class Net(torch.nn.Module):
    """Two-layer GCN encoder with a dot-product link decoder.

    ``encode`` maps node features to embeddings through two ``GCNConv`` layers
    with a ReLU in between. ``decode`` and ``decode_all`` score node pairs by
    the dot product of their embeddings; the scores are raw logits (no
    sigmoid is applied here).
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        """Return node embeddings produced by message passing over ``edge_index``."""
        hidden = torch.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

    def decode(self, z, edge_label_index):
        """Return one logit per candidate edge: the dot product of its endpoints.

        ``edge_label_index[0]`` indexes the first endpoint of each pair and
        ``edge_label_index[1]`` the second; the result has one entry per pair.
        """
        src = edge_label_index[0]
        dst = edge_label_index[1]
        return (z[src] * z[dst]).sum(dim=-1)

    def decode_all(self, z):
        """Return a ``[2, num_pairs]`` index tensor of all pairs with a positive logit.

        A positive dot product is the same as a sigmoid probability above 0.5.
        Ordered pairs in both directions are returned. Self-pairs ``(i, i)`` are
        included whenever node ``i`` has a non-zero embedding, since
        ``z_i . z_i > 0`` in that case.
        """
        logits = torch.matmul(z, z.t())
        return torch.nonzero(logits > 0).t()
def train_link_predictor(
    model, train_data, val_data, optimizer, criterion, n_epochs=100, log_every=10
):
    """Train ``model`` for link prediction, resampling negatives every epoch.

    Each epoch does the following:

    * encodes the training graph;
    * draws negative (non-edge) node pairs, requesting as many as there are
      positive supervision edges in ``train_data.edge_label_index``;
    * takes one optimizer step on ``criterion`` applied to the decoder
      logits, with positives labelled by ``train_data.edge_label`` and
      negatives labelled 0.

    Args:
        model: Object exposing ``encode(x, edge_index)`` and
            ``decode(z, edge_label_index)``, plus ``train()``.
        train_data: Graph with ``x``, ``edge_index``, ``edge_label_index``,
            ``edge_label`` and ``num_nodes``.
        val_data: Graph passed to ``eval_link_predictor`` for the progress log.
        optimizer: Optimizer stepping the model's parameters.
        criterion: Loss called as ``criterion(logits, labels)``. Presumably
            ``BCEWithLogitsLoss``, since ``decode`` yields raw logits; confirm
            at the call site.
        n_epochs: Number of training epochs.
        log_every: Print train loss and validation AUC every ``log_every``
            epochs (positive int). ``0``/``None`` disables logging, and with
            it the validation passes.

    Returns:
        The same ``model`` object, trained in place.
    """
    for epoch in range(1, n_epochs + 1):
        model.train()
        optimizer.zero_grad()
        z = model.encode(train_data.x, train_data.edge_index)

        # Sample fresh training negatives every epoch so the model sees new non-edges.
        neg_edge_index = negative_sampling(
            edge_index=train_data.edge_index,
            num_nodes=train_data.num_nodes,
            num_neg_samples=train_data.edge_label_index.size(1),
            method="sparse",
        )
        edge_label_index = torch.cat(
            [train_data.edge_label_index, neg_edge_index],
            dim=-1,
        )
        # Size the zero labels from the sampled tensor rather than the request:
        # the sampler's count is only approximate.
        edge_label = torch.cat(
            [
                train_data.edge_label,
                train_data.edge_label.new_zeros(neg_edge_index.size(1)),
            ],
            dim=0,
        )

        out = model.decode(z, edge_label_index).view(-1)
        loss = criterion(out, edge_label)
        loss.backward()
        optimizer.step()

        # Validation AUC only feeds the log line, so compute it only on logging
        # epochs instead of paying a full validation pass every epoch.
        if log_every and epoch % log_every == 0:
            val_auc = eval_link_predictor(model, val_data)
            print(
                f"Epoch: {epoch:03d}, Train Loss: {loss.item():.3f}, "
                f"Val AUC: {val_auc:.3f}"
            )
    return model
@torch.no_grad()
def eval_link_predictor(model, data):
    """Return the ROC-AUC of the model's link scores on ``data``.

    The model is switched to eval mode and run without gradient tracking.
    Nodes are encoded over ``data.edge_index``, the pairs in
    ``data.edge_label_index`` are scored, the logits are squashed with a
    sigmoid, and the scores are compared against ``data.edge_label``.

    Note: sklearn's ``roc_auc_score`` raises ``ValueError`` if
    ``data.edge_label`` contains only one class.
    """
    model.eval()
    embeddings = model.encode(data.x, data.edge_index)
    logits = model.decode(embeddings, data.edge_label_index)
    scores = torch.sigmoid(logits.view(-1))
    y_true = data.edge_label.cpu().numpy()
    y_score = scores.cpu().numpy()
    return roc_auc_score(y_true, y_score)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.