@dpressel
Last active June 8, 2020 17:18
Highway layer using PyTorch
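For reference, here is the computation that the code in this gist performs, written as equations. The symbol names are my own labels and do not appear in the gist: `self.proj` supplies $W_H, b_H$ and `self.transform` supplies $W_T, b_T$. The carry gate is $1 - T(x)$.

```latex
\begin{aligned}
H(x) &= \mathrm{ReLU}(W_H x + b_H) \\
T(x) &= \sigma(W_T x + b_T) \\
y &= T(x) \odot H(x) + \bigl(1 - T(x)\bigr) \odot x
\end{aligned}
```

The code initializes $b_T$ to $-2$, and $\sigma(-2) \approx 0.12$. Ignoring the $W_T x$ term, this means the layer starts out passing most of its input through unchanged.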
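```python
import torch
import torch.nn as nn


class Highway(nn.Module):
    """A single highway layer that gates between a transformed and an untouched input."""

    def __init__(self, input_size):
        super(Highway, self).__init__()
        # H(x): nonlinear projection; output size equals input size
        self.proj = nn.Linear(input_size, input_size)
        # T(x): transform gate
        self.transform = nn.Linear(input_size, input_size)
        # Negative gate bias pushes the layer toward carrying its input at init
        self.transform.bias.data.fill_(-2.0)

    def forward(self, x):
        proj_result = torch.relu(self.proj(x))
        # torch.sigmoid replaces the deprecated nn.functional.sigmoid
        proj_gate = torch.sigmoid(self.transform(x))
        # Gate mixes the projection with the unchanged input (carry = 1 - gate)
        gated = (proj_gate * proj_result) + ((1 - proj_gate) * x)
        return gated
```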
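The following is a minimal usage sketch. It assumes a 2-D `(batch, features)` input and repeats the class so the snippet runs on its own. A highway layer preserves the feature size, so several layers can be chained directly. The helper `make_highway_stack` and its `num_layers` parameter are hypothetical and not part of the gist. The second half checks the carry path: with a zeroed gate weight and a very negative gate bias, $T(x)$ underflows to 0, so the layer returns its input unchanged.

```python
import torch
import torch.nn as nn


class Highway(nn.Module):
    """Highway layer: y = T(x) * relu(H(x)) + (1 - T(x)) * x."""

    def __init__(self, input_size):
        super().__init__()
        self.proj = nn.Linear(input_size, input_size)
        self.transform = nn.Linear(input_size, input_size)
        nn.init.constant_(self.transform.bias, -2.0)

    def forward(self, x):
        h = torch.relu(self.proj(x))
        t = torch.sigmoid(self.transform(x))
        return t * h + (1 - t) * x


def make_highway_stack(input_size, num_layers):
    # Hypothetical helper: highway layers keep the feature size,
    # so any number of them can be chained with nn.Sequential.
    return nn.Sequential(*[Highway(input_size) for _ in range(num_layers)])


torch.manual_seed(0)
net = make_highway_stack(input_size=16, num_layers=3)
x = torch.randn(8, 16)  # (batch, features)
y = net(x)
print(y.shape)  # torch.Size([8, 16])

# Force the gate closed: sigmoid(-1e4) underflows to 0 in float32,
# so the output is exactly the input (pure carry path).
layer = Highway(16)
with torch.no_grad():
    layer.transform.weight.zero_()
    layer.transform.bias.fill_(-1e4)
print(torch.allclose(layer(x), x))  # True
```

The negative gate bias is what makes deep stacks of these layers trainable from the start. Each layer begins close to an identity mapping and only gradually learns to open its transform gate.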