Skip to content

Instantly share code, notes, and snippets.

@cedias
Last active October 16, 2017 12:53
Show Gist options
  • Select an option

  • Save cedias/946a380807b7e1bf92d738268b71415a to your computer and use it in GitHub Desktop.

Select an option

Save cedias/946a380807b7e1bf92d738268b71415a to your computer and use it in GitHub Desktop.
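# memerr.py -- minimal script for reproducing the "Momentum/sparse error":
# SGD with momentum training a (by default sparse) nn.Embedding on CUDA.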
import sys
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from random import randint
class MiniNet(nn.Module):
    """Tiny bag-of-embeddings classifier.

    Each token index is looked up in an embedding table, the embeddings are
    summed over the second-to-last dimension, and the summed vector is fed to
    a linear layer producing 4 scores.
    """

    def __init__(self, ntoken, sparse):
        """Build the embedding table and the linear output layer.

        Args:
            ntoken: largest token index the model must accept; the table is
                allocated with ``ntoken + 1`` rows so index ``ntoken`` is valid.
            sparse: forwarded to ``nn.Embedding(sparse=...)``.
        """
        super(MiniNet, self).__init__()
        vocab_rows = ntoken + 1
        self.embed = nn.Embedding(vocab_rows, 100, sparse=sparse)
        self.linear = nn.Linear(100, 4)

    def forward(self, input):
        """Return the linear layer's output for the summed token embeddings.

        Args:
            input: integer tensor of token indices.
        """
        embedded = self.embed(input)
        # Collapse the token dimension (second-to-last) into one vector.
        pooled = embedded.sum(dim=-2)
        scores = self.linear(pooled)
        return scores
def train(epoch, net, optimizer, dataset, criterion):
    """Run one pass over ``dataset``, taking one optimizer step per batch.

    Args:
        epoch: index of the current epoch (accepted but not used here).
        net: model called on each batch of inputs.
        optimizer: optimizer whose gradients are zeroed before, and stepped
            after, every backward pass.
        dataset: sized iterable yielding ``(data, label)`` batches; its
            length sets the progress-bar total.
        criterion: loss callable invoked as ``criterion(output, label)``.

    Note:
        Every batch is moved to the GPU with ``.cuda()``, so a CUDA device
        is required.
    """
    with tqdm(total=len(dataset), desc="Training") as progress:
        for batch in dataset:
            tokens, targets = batch

            # Move to the GPU, drop dimension 1 of the inputs, and cast both
            # inputs and targets to long before wrapping them.
            tokens = Variable(tokens.cuda().squeeze(1).long())
            targets = Variable(targets.cuda().long())

            optimizer.zero_grad()
            predictions = net(tokens)
            batch_loss = criterion(predictions, targets)
            batch_loss.backward()
            optimizer.step()

            # One tick per processed batch.
            progress.update(1)
def main(args):
    """Build a random dataset and train a ``MiniNet`` on it on the GPU.

    Args:
        args: parsed command-line namespace providing ``num_emb``,
            ``not_sparse``, ``num_ex``, ``b_size``, ``momentum`` and
            ``epochs``.
    """
    criterion = torch.nn.CrossEntropyLoss()
    # args.not_sparse is handed to MiniNet as its ``sparse`` argument; the
    # flag is declared with action="store_false", so it is True by default.
    net = MiniNet(args.num_emb, args.not_sparse)

    # Each example is a (1, 10) tensor of random token ids in
    # [0, args.num_emb] paired with a random class label in [0, 3].
    examples = [
        (
            torch.Tensor(
                [randint(0, args.num_emb) for _ in range(10)]
            ).unsqueeze(0),
            randint(0, 3),
        )
        for _ in range(args.num_ex)
    ]

    loader = DataLoader(
        examples,
        batch_size=args.b_size,
        shuffle=True,
        num_workers=2,
        pin_memory=False,
    )

    net.cuda()
    sgd = optim.SGD(net.parameters(), lr=0.01, momentum=args.momentum)

    epoch = 0
    while epoch < args.epochs:
        train(epoch, net, sgd, loader, criterion)
        epoch += 1
if __name__ == '__main__':
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser(description='Momentum/sparse error')

    # Typed options with defaults, registered in their original order.
    for flag, kind, fallback in (
        ("--momentum", float, 0.9),
        ("--num-emb", int, 2048),
        ("--num-ex", int, 10000),
        ("--b-size", int, 128),
    ):
        parser.add_argument(flag, type=kind, default=fallback)

    # store_false: args.not_sparse is True unless --not-sparse is passed,
    # in which case it becomes False.
    parser.add_argument("--not-sparse", action="store_false")
    parser.add_argument("--epochs", type=int, default=100)

    args = parser.parse_args()
    main(args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment