Last active
November 13, 2018 18:27
-
-
Save Hsankesara/7d9977ab629631c1a215800c519111b7 to your computer and use it in GitHub Desktop.
RNN Stacked Network with batchwise implementation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "name": "RNN_Stacked_v2.ipynb", | |
| "version": "0.3.2", | |
| "views": {}, | |
| "default_view": {}, | |
| "provenance": [], | |
| "collapsed_sections": [] | |
| }, | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "accelerator": "GPU" | |
| }, | |
| "cells": [ | |
| { | |
| "metadata": { | |
| "id": "-64Ln4xoEnUU", | |
| "colab_type": "text" | |
| }, | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## Stacked RNN implementation from scratch in TensorFlow" | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "SrBj_rffoMyV", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| } | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "import pandas as pd\n", | |
| "import numpy as np" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "trP-g463ocib", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| } | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "poetry = pd.read_csv('poetry.csv')" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "iFw4SeSdvqxz", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| } | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "num_poems = 3" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "iV9_HxKcofkx", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| } | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "poem = poetry['content'][:num_poems]" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "V872HNj-py5B", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| } | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "temp = ''\n", | |
| "for i in range(num_poems):\n", | |
| " temp += poem[i] + '\\r\\n'\n", | |
| "poem = temp" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "qKbmBG2Voie-", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| } | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "poem = poem.replace('\\r\\n\\r\\n', '\\r\\n')" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "4K6E1qsSorpd", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| } | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "poem = poem.replace('\\r\\n', ' ')" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "wQKWX4mKo3O3", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| } | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "poem = poem.replace('\\'', '')" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "weBF_W_XEnWI", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| } | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "import re" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "bZ1_Zx99EnWi", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| } | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "poem = re.sub(' +',' ',poem)" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "7SqYLgMvpng5", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| } | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "poem = poem.split()" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "MCKwmGYppvj1", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| } | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "words = list(set(poem))" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "BrQoDxwdEnXD", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| } | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "vocab_size = len(words)" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "H2KOrF8xEnXR", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| } | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "from scipy.sparse import csr_matrix\n", | |
| "from scipy.sparse import vstack" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "RkyA-h_q5m1K", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| } | |
| }, | |
| "cellView": "code" | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "strdict = {}\n", | |
| "revdict = {}\n", | |
| "i = 0\n", | |
| "for word in words:\n", | |
| " row = [0]\n", | |
| " col = [i]\n", | |
| " data = [1]\n", | |
| " strdict[word] =csr_matrix((data, (row, col)), shape=(1, vocab_size))\n", | |
| " revdict[i] = word\n", | |
| " i += 1" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "ktb11a_j5qne", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| } | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "def convert_to_df(start, size, seq, vsize, batch_size):\n", | |
| " word_count = len(seq)\n", | |
| " inp = np.array([])\n", | |
| " out = np.array([])\n", | |
| " for i in range(batch_size):\n", | |
| " if start >= word_count - size:\n", | |
| " break\n", | |
| " ones = seq[start:start + size]\n", | |
| " inp_vector = vstack([strdict[x] for x in ones])\n", | |
| " out_vec = strdict[seq[start + size]]\n", | |
| " if i == 0:\n", | |
| " inp = inp_vector.toarray()\n", | |
| " out = out_vec.toarray()\n", | |
| " else:\n", | |
| " inp = np.dstack((inp, inp_vector.toarray()))\n", | |
| " out = np.vstack((out, out_vec.toarray()))\n", | |
| " start += 1\n", | |
| " inp = np.swapaxes(inp, 2, 0)\n", | |
| " inp = np.swapaxes(inp, 1, 2)\n", | |
| " return inp, out" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "I47QFPnyqTtG", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| } | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "import tensorflow as tf" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "KcBhjDxhxfy3", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| } | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "def rnn_cell(W,Wo, U, b, prev_cell, curr):\n", | |
| " h = tf.add(tf.matmul(curr, U), tf.matmul(prev_cell, W)) + b\n", | |
| " out = tf.matmul(h, Wo)\n", | |
| " return h, out" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "LRkDKYtYrYrT", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| } | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "def rnn_layer(xin, We, Wo, Ue, be, num_inp):\n", | |
| " layer = tf.zeros((1, vocab_size))\n", | |
| " next_inp = []\n", | |
| " for i in range(num_inp):\n", | |
| " curr = xin[:, i]\n", | |
| " curr = tf.cast(curr, tf.float32)\n", | |
| " layer = tf.cast(layer, tf.float32)\n", | |
| " layer, out = rnn_cell(We, Wo, Ue, be, layer, curr)\n", | |
| " next_inp.append(out)\n", | |
| " next_inp = tf.stack(next_inp)\n", | |
| " next_inp = tf.transpose(next_inp, [1, 0, 2])\n", | |
| " return next_inp" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "1NMOgovPqnzr", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| } | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "learning_rate = 0.0001\n", | |
| "training_iters = 200\n", | |
| "display_step = 20\n", | |
| "num_inp = 4\n", | |
| "n_hidden = 3\n", | |
| "m = len(poem)\n", | |
| "batch_size = 64" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "0tqele7awk0O", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| } | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "x = tf.placeholder(tf.float64, [None, num_inp, vocab_size])\n", | |
| "y = tf.placeholder(tf.float64, [None, vocab_size])" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "FhqUFQyMw3SG", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| } | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "W = {}\n", | |
| "U = {}\n", | |
| "WO = {}\n", | |
| "b = {}\n", | |
| "for i in range(1, n_hidden + 1):\n", | |
| " W[i] = tf.Variable(tf.random_normal([vocab_size, vocab_size]))\n", | |
| " WO[i] = tf.Variable(tf.random_normal([vocab_size, vocab_size]))\n", | |
| " U[i] = tf.Variable(tf.random_normal([vocab_size, vocab_size]))\n", | |
| " b[i] = tf.Variable(tf.zeros([1, vocab_size]))" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "BDpm960uxe74", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| }, | |
| "base_uri": "https://localhost:8080/", | |
| "height": 35 | |
| }, | |
| "outputId": "4781e593-c48a-4819-ab03-f490086071eb", | |
| "executionInfo": { | |
| "status": "ok", | |
| "timestamp": 1528389829106, | |
| "user_tz": -330, | |
| "elapsed": 2171, | |
| "user": { | |
| "displayName": "Heet Sankesara", | |
| "photoUrl": "//lh5.googleusercontent.com/-mO3PS3oBtRQ/AAAAAAAAAAI/AAAAAAAAADs/_ic_rsddWU4/s50-c-k-no/photo.jpg", | |
| "userId": "112240562069909648160" | |
| } | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "tf.device(\"/device:GPU:0\")" | |
| ], | |
| "execution_count": 23, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "<contextlib._GeneratorContextManager at 0x7f0ce395f4a8>" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 23 | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "6xKblWN4yRsq", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| } | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "prev_inp = x\n", | |
| "for i in range(1, n_hidden + 1):\n", | |
| " prev_inp = rnn_layer(prev_inp, W[i],WO[i], U[i], b[i], num_inp)\n", | |
| "dense_layer1 = prev_inp[:, -1]" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "zeY4biq1I5Vz", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| } | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "pred = tf.nn.softmax(dense_layer1)" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "tbSNM77mzCL1", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| } | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=dense_layer1, labels=y))\n", | |
| "optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "K-4QkAIL2RxA", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| } | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "# Model evaluation\n", | |
| "correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))\n", | |
| "accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "Pz0OUV1H4oms", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| } | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "saver = tf.train.Saver()\n", | |
| "init = tf.global_variables_initializer()" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "e3AkyEMS4sha", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| }, | |
| "base_uri": "https://localhost:8080/", | |
| "height": 287 | |
| }, | |
| "outputId": "6638ef1d-31fe-4385-e011-85e989467130", | |
| "executionInfo": { | |
| "status": "ok", | |
| "timestamp": 1528389921849, | |
| "user_tz": -330, | |
| "elapsed": 78022, | |
| "user": { | |
| "displayName": "Heet Sankesara", | |
| "photoUrl": "//lh5.googleusercontent.com/-mO3PS3oBtRQ/AAAAAAAAAAI/AAAAAAAAADs/_ic_rsddWU4/s50-c-k-no/photo.jpg", | |
| "userId": "112240562069909648160" | |
| } | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "with tf.Session() as sess:\n", | |
| " sess.run(init)\n", | |
| " window_size = m - num_inp\n", | |
| " for epoch in range(training_iters):\n", | |
| " cst_total = 0\n", | |
| " acc_total = 0\n", | |
| " total_batches = int(np.ceil(window_size / batch_size))\n", | |
| " for i in range(total_batches):\n", | |
| " df_x, df_y = convert_to_df(i, num_inp, poem, vocab_size, batch_size)\n", | |
| " _, cst, acc = sess.run([optimizer, cost, accuracy], feed_dict = {x : df_x, y : df_y})\n", | |
| " cst_total += cst\n", | |
| " acc_total += acc\n", | |
| " if (epoch + 1) % display_step == 0:\n", | |
| " print('After ', (epoch + 1), 'iterations: Cost = ', cst_total / total_batches, 'and Accuracy: ', acc_total * 100 / total_batches, '%' )\n", | |
| " print('Optimiation finished!!!')\n", | |
| " save_path = saver.save(sess, \"../working/model.ckpt\")\n", | |
| " print(\"Model saved in path: %s\" % save_path)\n", | |
| " print(\"Lets test\")\n", | |
| " sentence = 'If my imagination from'\n", | |
| " sent = sentence.split()\n", | |
| " sent = vstack([strdict[s] for s in sent])\n", | |
| " sent = sent.toarray()\n", | |
| " for q in range(64):\n", | |
| " one_hot = sess.run(pred, feed_dict={x : sent.reshape((1, num_inp, vocab_size))})\n", | |
| " index = int(tf.argmax(one_hot, 1).eval())\n", | |
| " sentence = \"%s %s\" % (sentence,revdict[index])\n", | |
| " sent = sent[1:]\n", | |
| " sent = np.vstack((sent, one_hot))\n", | |
| " print(sentence)\n", | |
| " " | |
| ], | |
| "execution_count": 29, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "After 20 iterations: Cost = 1765896135.1111112 and Accuracy: 82.29166666666667 %\n", | |
| "After 40 iterations: Cost = 440820942.2222222 and Accuracy: 86.97916666666667 %\n", | |
| "After 60 iterations: Cost = 192934734.2222222 and Accuracy: 90.27777777777777 %\n", | |
| "After 80 iterations: Cost = 91874517.33333333 and Accuracy: 91.14583333333333 %\n", | |
| "After 100 iterations: Cost = 57729813.333333336 and Accuracy: 92.53472222222223 %\n", | |
| "After 120 iterations: Cost = 29111367.111111112 and Accuracy: 94.44444444444444 %\n", | |
| "After 140 iterations: Cost = 20999146.666666668 and Accuracy: 94.79166666666667 %\n", | |
| "After 160 iterations: Cost = 14925162.666666666 and Accuracy: 95.83333333333333 %\n", | |
| "After 180 iterations: Cost = 14771164.444444444 and Accuracy: 96.35416666666667 %\n", | |
| "After 200 iterations: Cost = 14433827.555555556 and Accuracy: 97.04861111111111 %\n", | |
| "Optimiation finished!!!\n", | |
| "Model saved in path: ../working/model.ckpt\n", | |
| "Lets test\n", | |
| "If my imagination from neither, been writing double humble no Both thoughts posterity: sable rest, And Whereupon tree threne service mutual trumpet uncontrolled That lay sound Queen; Be Arabian But see tell and Foul well defunctive Two yet Twixt be fiend, I what Turtle fevers soft, self be; troop the to Simple near. lay troop Distance interdict Arabian near. of tyrant and interdict the eagle, Two tyrant Keep\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "id": "iuOKuKphEnbw", | |
| "colab_type": "code", | |
| "colab": { | |
| "autoexec": { | |
| "startup": false, | |
| "wait_interval": 0 | |
| } | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| } | |
| ] | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment