Created
November 6, 2018 20:33
-
-
Save dpressel/a81343a938e5214f0189d94b96914e3e to your computer and use it in GitHub Desktop.
mead-train --config config/sst2-rawten.yml
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from baseline.reader import register_reader, SeqLabelReader | |
| from baseline.vectorizers import register_vectorizer, create_vectorizer, Vectorizer | |
| from collections import Counter | |
| from baseline.data import DataFeed | |
| from baseline.embeddings import register_embeddings | |
| from baseline.tf.embeddings import TensorFlowEmbeddings | |
| from baseline.utils import read_json | |
| import numpy as np | |
| import tensorflow as tf | |
class RawFeed(DataFeed):
    """Feed that serves already-materialized tensors directly from memory.

    Each step ``i`` produces a dict holding ``x[i]``, ``y[i]`` and a
    ``word_lengths`` tuple taken from dimensions 1 and 2 of ``x``.
    """

    def __init__(self, x, y):
        """Wrap the tensors.

        :param x: Array-like exposing ``.shape`` with at least three dimensions
        :param y: Array-like of labels; its length is the number of steps
        """
        self.x = x
        self.y = y
        # One step per leading entry of `y`
        self.steps = len(y)
        # Order is preserved unless a caller flips this flag afterwards
        self.shuffle = False
        dims = self.x.shape
        self.word_lengths = (dims[1], dims[2])

    def _batch(self, i):
        """Build the feed dict for step ``i``."""
        batch = {
            'word': self.x[i],
            'word_lengths': self.word_lengths,
            'y': self.y[i],
        }
        return batch

    def __getitem__(self, i):
        return self._batch(i)

    def __iter__(self):
        """Yield every step once, in random order when ``self.shuffle`` is set."""
        order = np.arange(self.steps)
        if self.shuffle:
            order = np.random.permutation(order)
        for step_index in order:
            yield self._batch(step_index)

    def __len__(self):
        return self.steps
@register_reader(task="classify", name="raw")
class NPZSeqLabelReader(SeqLabelReader):
    """Read a bunch of tensors in from file and make the feed dict from this instead of words

    The label set is read once from ``labels_json``; the tensors themselves come
    from ``<filename>.npz`` archives whose positional entries ``arr_0`` and
    ``arr_1`` hold the inputs and the labels respectively.
    """

    def __init__(self, _, __, labels_json=None, **kwargs):
        """Load the label set.

        :param _: Unused positional argument
        :param __: Unused positional argument
        :param labels_json: Path handed to ``read_json`` to obtain the labels
        :param kwargs: Ignored
        """
        self.labels = read_json(labels_json)

    def build_vocab(self, files, **kwargs):
        """Return an empty ``word`` vocabulary and the labels read at construction.

        :param files: Ignored, nothing is counted
        :return: A tuple of ``({'word': {}}, labels)``
        """
        return {'word': {}}, self.labels

    def load(self, filename, index, batchsz, shuffle=True, sort_key=None, **kwargs):
        """Load ``filename + '.npz'`` and wrap its tensors in a `RawFeed`.

        NOTE(review): ``index``, ``batchsz``, ``shuffle`` and ``sort_key`` are
        accepted but not used here -- confirm that is intended.

        :param filename: Path of the archive, without the ``.npz`` suffix
        :return: A `RawFeed` over ``arr_0`` (inputs) and ``arr_1`` (labels)
        """
        # An NpzFile keeps its underlying file open until closed; the arrays are
        # read on item access, so pull them out before leaving the `with`
        with np.load(filename + '.npz') as archive:
            x = archive['arr_0']
            print(x.shape)
            y = archive['arr_1']
        return RawFeed(x, y)
@register_vectorizer(name='raw')
class Raw2DVectorizer(Vectorizer):
    """Vectorizer that does nothing

    The incoming ``tokens`` are already tensors, so they are passed through
    untouched; this class only tracks the sizes it has seen.
    """

    def __init__(self, **kwargs):
        """Set up the size bookkeeping.

        :param kwargs: ``dsz`` is required; ``mxlen`` is optional and defaults to -1
        """
        super(Raw2DVectorizer, self).__init__()
        self.mxlen = kwargs.get('mxlen', -1)
        self.max_seen = 0
        self.dsz = kwargs['dsz']

    def count(self, tokens):
        """We dont bother really counting here, there would be nothing to count

        Records ``tokens.shape[1]`` as ``dsz`` and tracks the largest
        ``tokens.shape[0]`` seen so far in ``max_seen``.

        :param tokens: Array-like exposing a ``.shape`` with at least two dimensions
        :return: An empty `Counter`
        """
        seen = tokens.shape[0]
        self.dsz = tokens.shape[1]
        self.max_seen = max(self.max_seen, seen)
        return Counter()

    def run(self, tokens, vocab):
        """There is nothing to do here either, just forward along someone else's hard work

        :param tokens: Array-like exposing a ``.shape`` with at least two dimensions
        :param vocab: Ignored
        :return: A tuple of ``tokens`` itself and ``tokens.shape[1]``
        """
        return tokens, tokens.shape[1]

    def get_dims(self):
        """The dims here are the configured ``mxlen`` and the current ``dsz``

        :return: A tuple of ``(mxlen, dsz)``
        """
        # The original evaluated this tuple without returning it, so callers got None
        return self.mxlen, self.dsz

    def iterable(self, tokens):
        """TODO(dpressel): This looks like a method that is superfluous and shouldnt have been in the base class

        :param tokens:
        :return:
        """
        raise Exception('Not implemented/unused')
@register_embeddings(name='identity')
class IdentityEmbeddings(TensorFlowEmbeddings):
    """Pass-through embeddings: whatever tensor comes in is what goes out

    A placeholder is still required to feed the tensors into the graph, but
    `encode()` performs no lookup or transformation on them.
    """

    # Shared by every instance and overwritten by each `__init__`; it stays
    # None until the first instance has been constructed
    DSZ = None

    @classmethod
    def create_placeholder(cls, name):
        """Create a float32 placeholder whose last dimension is the current `DSZ`

        :param name: Name given to the placeholder
        :return: The new placeholder
        """
        shape = [None, None, IdentityEmbeddings.DSZ]
        return tf.placeholder(tf.float32, shape, name=name)

    def __init__(self, name, **kwargs):
        """Record the settings and publish ``dsz`` on the class

        :param name: Name used for the placeholder created by `encode()`
        :param kwargs: ``vsz`` and ``dsz`` (both default to None) and ``finetune`` (default True)
        """
        super(IdentityEmbeddings, self).__init__()
        self.name = name
        self.vsz = kwargs.get('vsz')
        self.finetune = kwargs.get('finetune', True)
        IdentityEmbeddings.DSZ = kwargs.get('dsz')

    def save_md(self, target):
        """Nothing to save for an identity mapping"""

    def encode(self, x=None):
        """Return the input unchanged, creating a placeholder when none is given

        :param x: Input tensor, or None to have a placeholder created
        :return: The tensor that was passed in or the new placeholder
        """
        self.x = IdentityEmbeddings.create_placeholder(self.name) if x is None else x
        return self.x

    def get_vsz(self):
        return self.vsz

    def get_dsz(self):
        """Return the class-level `DSZ`, which is None until an instance has been constructed"""
        return IdentityEmbeddings.DSZ
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# mead config for the raw-tensor SST2 classifier.
# NOTE(review): nesting was reconstructed from the key order -- confirm against the original file.
task: classify
# Module providing the `raw` reader, `raw` vectorizer and `identity` embeddings
modules:
  - rawten
batchsz: 50
basedir: rawten
dataset: SST2
preproc:
  clean: true
loader:
  reader_type: raw
  labels_json: ../../data/stsa.binary.labels
unif: 0.25
model:
  model_type: default
  filtsz: [3,4,5]
  cmotsz: 100
  dropout: 0.5
features:
  - name: word
    vectorizer:
      type: raw
      dsz: 300
    embeddings:
      type: identity
      dsz: 300
train:
  epochs: 2
  optim: adadelta
  eta: 1.0
  early_stopping_metric: acc
verbose:
  console: True
  file: sst2-cm.csv
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
How to run this:
- Run `tensor-inputs.py` in `api-examples`
- Put the `stsa.binary.*` files in `baseline/data` or copy the NPZ files to your `bl-cache`
- Make sure `sst2-rawten.yml` is pointing at the correct location of the `stsa.binary.labels` json file that was emitted by `tensor-inputs.py`
- Make sure `rawten.py` is somewhere in your `PYTHONPATH` so it can be picked up