@dpressel
Created November 6, 2018 20:33
mead-train --config config/sst2-rawten.yml
rawten.py:

from baseline.reader import register_reader, SeqLabelReader
from baseline.vectorizers import register_vectorizer, create_vectorizer, Vectorizer
from collections import Counter
from baseline.data import DataFeed
from baseline.embeddings import register_embeddings
from baseline.tf.embeddings import TensorFlowEmbeddings
from baseline.utils import read_json
import numpy as np
import tensorflow as tf

class RawFeed(DataFeed):

    def __init__(self, x, y):
        self.shuffle = False
        self.x = x
        self.y = y
        self.steps = len(y)
        self.word_lengths = (self.x.shape[1], self.x.shape[2])

    def _batch(self, i):
        return {'word': self.x[i], 'word_lengths': self.word_lengths, 'y': self.y[i]}

    def __getitem__(self, i):
        return self._batch(i)

    def __iter__(self):
        shuffle = np.random.permutation(np.arange(self.steps)) if self.shuffle else np.arange(self.steps)
        for i in range(self.steps):
            si = shuffle[i]
            yield self._batch(si)

    def __len__(self):
        return self.steps

@register_reader(task="classify", name="raw")
class NPZSeqLabelReader(SeqLabelReader):
    """Read a bunch of tensors in from file and make the feed dict from this instead of words
    """
    def __init__(self, _, __, labels_json=None, **kwargs):
        self.labels = read_json(labels_json)

    def build_vocab(self, files, **kwargs):
        return {'word': {}}, self.labels

    def load(self, filename, index, batchsz, shuffle=True, sort_key=None, **kwargs):
        v = np.load(filename + '.npz')
        x = v['arr_0']
        print(x.shape)
        y = v['arr_1']
        return RawFeed(x, y)
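
# Aside (not part of the original gist): `load()` above expects an NPZ whose
# positional arrays are the features (arr_0) and labels (arr_1), i.e. the sort
# of file tensor-inputs.py is expected to emit. A hedged sketch with assumed,
# pre-batched shapes:
#
#     x = np.random.rand(10, 50, 40, 300).astype(np.float32)  # [num_batches, batchsz, mxlen, dsz]
#     y = np.random.randint(0, 2, size=(10, 50))               # one label per example per batch
#     np.savez('stsa.binary.train', x, y)                      # reads back as v['arr_0'], v['arr_1']
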
@register_vectorizer(name='raw')
class Raw2DVectorizer(Vectorizer):
    """Vectorizer that does nothing
    """
    def __init__(self, **kwargs):
        super(Raw2DVectorizer, self).__init__()
        self.mxlen = kwargs.get('mxlen', -1)
        self.max_seen = 0
        self.dsz = kwargs['dsz']

    def count(self, tokens):
        """We don't really bother counting here; there would be nothing to count

        :param tokens:
        :return:
        """
        seen = tokens.shape[0]
        self.dsz = tokens.shape[1]
        self.max_seen = max(self.max_seen, seen)
        return Counter()

    def run(self, tokens, vocab):
        """There is nothing to do here either, just forward along someone else's hard work

        :param tokens:
        :param vocab:
        :return:
        """
        return tokens, tokens.shape[1]
    def get_dims(self):
        """The dims here are the maximum length and the feature depth

        :return: a `(mxlen, dsz)` tuple
        """
        return self.mxlen, self.dsz

    def iterable(self, tokens):
        """TODO(dpressel): This looks like a method that is superfluous and shouldn't have been in the base class

        :param tokens:
        :return:
        """
        raise Exception('Not implemented/unused')

@register_embeddings(name='identity')
class IdentityEmbeddings(TensorFlowEmbeddings):
    """This is a pass-through operation where the input is presumed to be the same as the output

    We still need a placeholder to pass the tensors through, but the `encode()` function doesn't do any work anymore
    """
    DSZ = None

    @classmethod
    def create_placeholder(cls, name):
        return tf.placeholder(tf.float32, [None, None, IdentityEmbeddings.DSZ], name=name)

    def __init__(self, name, **kwargs):
        super(IdentityEmbeddings, self).__init__()
        self.vsz = kwargs.get('vsz')
        IdentityEmbeddings.DSZ = kwargs.get('dsz')
        self.finetune = kwargs.get('finetune', True)
        self.name = name

    def save_md(self, target):
        # Nothing to do
        pass

    def encode(self, x=None):
        if x is None:
            x = IdentityEmbeddings.create_placeholder(self.name)
        self.x = x
        return x

    def get_vsz(self):
        return self.vsz

    # Warning: `DSZ` is a class attribute that is only set once `__init__` has run
    def get_dsz(self):
        return IdentityEmbeddings.DSZ
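
As a quick sanity check, here is a usage sketch of the identity embeddings. This is a hypothetical snippet, assuming TF 1.x graph mode; the vsz/dsz values are illustrative and not from the gist:

embed = IdentityEmbeddings('word', vsz=0, dsz=300)
ph = embed.encode()            # creates a [None, None, 300] float32 placeholder
assert embed.get_dsz() == 300  # DSZ was set during construction
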
sst2-rawten.yml:

task: classify
modules:
- rawten
batchsz: 50
basedir: rawten
dataset: SST2
preproc:
  clean: true
loader:
  reader_type: raw
  labels_json: ../../data/stsa.binary.labels
unif: 0.25
model:
  model_type: default
  filtsz: [3,4,5]
  cmotsz: 100
  dropout: 0.5
features:
- name: word
  vectorizer:
    type: raw
    dsz: 300
  embeddings:
    type: identity
    dsz: 300
train:
  epochs: 2
  optim: adadelta
  eta: 1.0
  early_stopping_metric: acc
  verbose:
    console: True
    file: sst2-cm.csv

dpressel commented Nov 6, 2018

How to run this:

  1. Run tensor-inputs.py in api-examples.
  2. Either make a datasets.json that references the version of stsa.binary.* in baseline/data, or copy the NPZ files to your bl-cache (a sketch of one possible entry follows this list).
  3. Make sure your sst2-rawten.yml points at the correct location of the stsa.binary.labels JSON file that was emitted by tensor-inputs.py.
  4. Make sure that rawten.py is somewhere on your PYTHONPATH so it can be picked up.
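
For step 2, a rough sketch of what such a datasets.json entry might look like. The schema follows mead's dataset index format as best I can tell, every path here is illustrative rather than taken from the gist, and note that NPZSeqLabelReader.load() appends the .npz extension itself:

[
  {
    "label": "SST2",
    "train_file": "/path/to/stsa.binary.train",
    "valid_file": "/path/to/stsa.binary.dev",
    "test_file": "/path/to/stsa.binary.test"
  }
]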
