Skip to content

Instantly share code, notes, and snippets.

@tanutarou
Created September 27, 2018 18:58
Show Gist options
  • Select an option

  • Save tanutarou/bbec3a80baf5fe7d21ebf3d92b7966d7 to your computer and use it in GitHub Desktop.

Select an option

Save tanutarou/bbec3a80baf5fe7d21ebf3d92b7966d7 to your computer and use it in GitHub Desktop.
[WIP] Bayesian polynomial fitting example using TensorFlow Probability
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from IPython.core.pylabtools import figsize\n",
"from matplotlib import pyplot as plt\n",
"%matplotlib inline\n",
"import numpy as np\n",
"# Default matplotlib figure size (width, height) for plots in this notebook.\n",
"figsize(12.5, 4)\n",
"import seaborn as sns\n",
"# Seaborn theme: 'paper' context, dark grid style, white figure background, fonts scaled by 1.2.\n",
"sns.set(context='paper', style='darkgrid', rc={'figure.facecolor':'white'}, font_scale=1.2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# データの準備"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True parameters\n",
"w=[ 7.02419783 1.10560521 -0.72663653]\n",
"b=0.5\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.collections.PathCollection at 0x7ff3ab3e6198>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD+CAYAAADfwXXpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAEtFJREFUeJzt3VFsZNV9x/HfzNi7eJZFY1YeNoS1QQ39i0CjAJE2FCJRqQV206aINA8prSoVpCpRozwkSlQ1VfOQiqpVRLqpQiJBFNKmPASRVApeUPpQQRN2JRTIamHzpzTsLtuyGWfjacx6qL2e6YPHideZse94zsyde+73IyHsu9d3zsH457P/e+5/Cq1WSwCAOBXTHgAAYHAIeQCIGCEPABEj5AEgYoQ8AESMkAeAiBHyABAxQh4AIkbIA0DECHkAiNhY0hPNbJ+kr0m6UlJT0pOSPu3uLTP7uKSPtU895O6HEl6WngoAsD2FRCcl7V1jZm+T9HZ3f97Mdkj6rqRDko5JmpV0U/vUH0g64O6vJrhsa25uIdHrZ02lUla9vpj2MIYmT/PN01wl5juKpqZ2SwlDPnG5xt3fcPfn2x8vaTXc90m6R9I33X3B3RckPd4+BgBIWeJyzXpmtkfS3ZLukPRRSS+v++PTkt6Z9FqVSnk7Qxh5pVIx2rl1kqf55mmuEvPNup5D3sx2anW1/gV3P2Fmif7K0M2o/7Vou7LwV76Q8jTfPM1VYr6jqF2uSaSn3TVmVpL0DUkvuPvn24df12rZZs20pDO9XBcAMBi9ruS/ImlB0ifWHfuWpCfN7IH25x+SdDDA2AAAfeplC+Wtku6TdFzSC2YmSV9190Nm9iVJL7ZP/aK7vxJ8pAAQgVq9odr8oqqTZVUrEwN/vcQh7+7fU5ctO+7+oKQHQw0KAGJ0+OgpzT53Ss1mS8ViQQdvmdGB/TMDfU2eeAWAIajVG5p97pTOv3VBjaUVnX/rgmaPnFat3hjo6xLyADAEtflFNZsXP3zabDY1N0/IA0DmVSfLKhYvrngXi0VNTQ62Lk/IA8AQVCsTOnjLjHZNjGtiZ0m7Jsb1/vdOD/zm67aeeAUA9O7A/hndbFXNzTc0NTkxWrtrAAD9q1aGE+5rKNcAQMQIeQCIGCEPABEj5AEgYoQ8AESMkAeAiBHyABAxQh4AIkbIA0DECHkAiBghDwARI+QBIGKEPABEjJAHgIgR8gAQMUIeACJGyANAxAh5AIgYIQ8AESPkASBihDwARIyQB4CIjSU90cwekvQBSVe6e6F97HZJ35H0avu0V939D0IPEgCwPYlDXtJjkj4r6eyG48+7++2hBgQACCdxyLv7M5JkZoMbDQAgqF5W8t2828xekPSmpM+5+9MBrgkACKDQarV6+gIza62ryV8mSe7+czN7l6TDkm5z99cSXq61vLzS0+tnRalU1MpKM+1hDE2e5punuUrMdxSNj5ckqZDk3L5W8u7+83UfHzOz5yTdJClpyKteX+xnCCOrUilHO7dO8jTfPM1VYr6jaGpqd+Jz+9pCaWZvM7O1Vf1VkvZLeqmfawIAwullC+XDku5qf3xG0lOSXpT0ETNb1upfHf7C3X80iIECAHrXy+6a+7v80T8GGgsAjJRavaHa/KKqk2VVKxNpD2dbQuyuAYDoHD56SrPPnVKz2VKxWNDBW2Z0YP9M2sPqGW0NAGCDWr2h2edO6fxbF9RYWtH5ty5o9shp1eqNtIfWM0IeADaozS+q2bx4e3mz2dTcPCEPAJlXnSyrWLx4G3qxWNTUZPbq8oQ8AGxQrUzo4C0z2jUxromdJe2aGNf73zudyZuv3HgFgA4O7J/RzVbV3HxDU5MTmQx4iZAHgK6qleyG+xrKNQAQMUIeACJGyANAxAh5AIgYIQ8AESPkASBihDwARIyQB4CIEfIAcq1Wb+j4a+cy2WEyCZ54BZBbnXrGf/jO69IeVlCs5A
HkUree8WfPnU97aEER8gByqVvP+LM/W0xpRINByAPIpW494/deXk5pRINByAPIpW494/fu2ZX20ILixiuA3IqlZ/xmCHkAQdXqDdXmF1WdLGciNGPoGb8ZQh5AMJ22JB7YP5P2sHKNmjyAILptSYz1IaOsIOQBBNFtS+LcPCGfJkIeQBDdtiROTcZb784CQh5AEN22JMZ8UzMLuPEKoKtavaGTtfMq7ygmCus8bEnMmsQhb2YPSfqApCvdvbDu+Mclfaz96SF3PxR2iADSsLZTpiWpICXeKRP7lsSs6aVc85ikm9YfMLNrJf25pBvb/3zMzN4RbngA0rB+p8ziWxfYKZNhiUPe3Z9x959sOHyPpG+6+4K7L0h6vH0MQIaxUyYe/dbkr5L08rrPT0t6Zy8XqFTiaga0plQqRju3TvI03zzM9R3TLY2NFaWllV8cGysV9WvTk9HPPbbvb78hX9j6lM3V63G19VxTqZSjnVsneZpvHuZ6Samgu/ZPa/bIabVaLRUKBR3YP61LSoXo556F7+/U1O7E5/Yb8q9L2rfu82lJZ/q8JoARsLZTZnG5qfL4amX3+GvnMtOTBqv6DflvSXrSzB5of/4hSQf7vCaAEVGtTKhSKeuxp0/QkyajetlC+bCku9ofn5H0lLvfb2ZfkvRi+7Qvuvsr4YcJIC1nz53/xU6bNbNHTutmq7Kiz4DEIe/u93c5/qCkB4ONCMBIOXuu+04bQn700dYAwKb27qEnTZYR8gA2tXfPLnrSZBi9awBsiZ402UXIA0iEnjTZRLkGACJGyANAxAh5AIgYIQ8AESPkASBihDyQE7V6Q8dfO8cbf+QMWyiBHFh7Kz8ajOUPK3kgcuvfyq+xtMJb+eUMIQ9EjrfyyzdCHohcdZIGY3lGyAORq1YmaDCWY9x4BXKABmP5RcgDOUGDsXyiXAMAESPkASBihDwARIyQBzKo1xYFtDTIL268AhnTa4uCUC0NavWGavOLqk6WuYGbIYQ8kCHrWxSsmT1yWjdbtWPw9np+N/S+yS7KNUCG9NqiIERLg7PnztP7JsMIeSBDem1REKKlwdlz9L7JMkIeyJBeWxSEaGmwdw+9b7KMmjyQMb22KOi3pcHePbt08JYZzR45rWazqWKxSO+bDCHkgQzqtUVBvy0N6H2TXYQ8gETofZNNwULezE5KWpS01D70h+7+cqjrAwB6F3olf9DdTwa+JgBgm9hdAwARK7Rara3PSqBdrqlLKkj6jqTPuvvyFl/WWl5eCfL6o6ZUKmplpZn2MIYmT/PN01wl5juKxsdL0mrWbilkueY2dz9jZpdK+idJn5T0wFZfVK8vBhzC6KhUytHOrZM8zbffuWatB0yevrdSNuY7NbU78bnBQt7dz7T//aaZPSLpz0Jdu5Os/aAAEj1gMHxBQt7MdkkqufvPzWxM0gclHQtx7U74QUEWhWoWBvQi1I3XKyQ9Y2bHJP1Q0oqkvwl07Yus/0GhWRKyJESzMKBXQVby7v5jSe8Oca2tbPaDwmoIadqqhBiiWRjQq8w98coPCkZRkhLiWrOwfnvAJLkfxT0rrMleyAf6QQFC6aXW3m8PmCS/TLhnhfUyF/ISzZIwWnotIW63B0ySXybc3MVGmX3itVqZ0PXXXM7/uEjdsEqISW7ccnMXG2U25IFREeKNORK9ToJfJtyzwkaZLNcAo2a7JcRebpAmuR/FPStsFKx3zTa15uYW0nz9gcnCo9Eh5Wm+oea63RuktXrjol8mnX5RbDynH3n63krZmG+7rcHQe9cASKifG6Trb9x2+0XBG3xgDTV5IAUhbpDy9DeSIOSBFIS4QfrK63Utb2iJy04abETIAynod0fO4aOn9Nh3X9HS8sUhz04abERNHkhJPztyZp87pcbSxW+4U945xk4a/ApCHkjRdm6Qdqrn7xgv6sO/fa1u/Y23hRweIkC5BsiYTvX88bGSrt1XSWlEGGWEPJAxw3rCFnGgXANkUKd6Pu2F0Q
khD2RUkoeiAMo1QMbxUBQ2Q8gDGUd7YWyGkAcyjvbC2AwhD2Qcu22wGW68AhHgLTHRTe5Cnm1miBXthdFJrkKebWYA8iY3NXm2mQHIo/yEPNvMkIJavaHjr51jMYHU5KZcwzYzDBvlQYyC3Kzk2WaGYaI8iFERbCVvZjdI+rqk3ZJOSLrX3RdCXT8Etplh0NZ2b9XfXOpaHuT/OwxTyHLNlyV9xt1nzezvJH1K0l8FvH4QbDPDoKwvz0jSyoaQpzyINAQp15jZFZKucffZ9qFHJH0wxLWBLNhYnmksrail1bfkozyINIVayV8l6cy6z09L2pfkCyuVcqAhjJZSqRjt3DrJ03w7zfVk7bxaG84bHyvqvt+7XpOXXaK9l5e1d8+u4Q0yoDx9b6X45hsq5Atbn9JZvb4YaAijpVIpRzu3TvI0305zLe8o/soPQaFQ0Nv3/PLJ6qz+98nT91bKxnynpnYnPjfU7pozWl3Nr5nWxSt7IGrs3sKoCrKSd/ezZnbSzA626/L3SXoixLWBrGD3FkZRyN01H5H0qJn9gySXdG/AawOZwO4tjJpgIe/uxyTdGOp6AID+5eaJVwDII0IeACJGyANAxAh5YAu0C0aW5abVMLAdndoFf/jO69IeFpAYK3mgbeOKvVu74LPnzqc8UiA5VvKAOq/Y91Uv7dgu+OzPFnX1VDb70CB/WMkj97qt2EulYsd3E9t7eTzNqxA/Qh651+39f5srrY79aLLaTRL5RLkGubfZ+/9ef83l9KNBprGSR+5t1UGyWlkNewIeWcRKHhAdJBEvQh5oo4MkYkS5BgAiRsgDQMQIeQCIGCGP4GjoBYwObrwiqE7tAQ7sn0l7WEBusZJHMN3aA7CiB9JDyCOYbu0B5uYJeSAthDyC2aw9AIB0EPIIZqv2AACGjxuvSKRWb6g2v6jqZHnT0KY9ADBaCHlsqdcdM7QHAEYH5Rpsih0zQLYR8tgUO2aAbCPksSl2zADZRshjU+yYAbKNG6/YEjtmgOzqK+TN7GpJr0h6uX1o0d1/s99BYfSwYwbIphAr+f9x93cHuA4AIDBq8gAQsUKr1dr6rC7a5ZoTWi3XrEg65O7/3MMlWsvLK9t+/VFWKhW1stJMexhDk6f55mmuEvMdRePjJUkqbHWelCDkzewHkqY7/NHTkv5U0m53/6mZzUj6N0l/4u7fTzjW1tzcQsJTs6VSKateX/zF50nbAmTVxvnGLE9zlZjvKJqa2i0lDPkta/LuftMWp/xf+7xTZvavkm6RlDTkc4E30gCQlr5q8mZWNbOx9seTku6QdCzEwGJBWwAAaer3xuttkl40sx9KelbSo+7+3f6HFY+8tAU4e+487+sKjKC+tlC6+xOSngg0lijloS3A4aOn9NTR07pwoUk5ChgxbKEcsNjbAqyVoxYWlylHASOItgZDEHNbgM3KUTHNE8gqQn5IYm0LkIdyFJBllGvQl7Vy1O5ynOUoIOtYyaNvB/bP6LfeM63/Oj0fXTkKyDpCHkHs3bNLl5QSPYAHYIgo1wBAxAj5HKnVGzywBOQM5ZqcOHz0lL7zvZNaabZUKhb0u7dezQNLQA6wks+BWr2hbz/7YzWWVrR0oanG0oq+/exrrOiBHCDkc+CV1+e1fOHiB5aWLzT1n6/XUxoRgGEh5HOBXS9AXhHyOfDr+yraMXbxt3rHWFHX7qukNCIAw0LI50C1MqHff981Ku8c047xoso7x3T3+67hoSUgB9hdkxMxN0kD0B0hnyOxNkkD0B3lGgCIGCEPABEj5AEgYoQ8AESMkAeAiBVardbWZw1Oqi8OABmW6FH2tLdQ8rw9AAwQ5RoAiBghDwARI+QBIGKEPABEjJAHgIgR8gAQMUIeACKW9j75aJnZ70h6QNIlkpqS/tbd/yXdUQ2OmV0m6QlJ75H0orvfnu6IBsPMbpD0dUm7JZ2QdK+7L6Q7qsEws4ckfUDSle4e9TMtZrZP0tckXanVn9cnJX3a3TP/wC
Yr+cH5qaS73f0GSXdKetDM9qY8pkFakvQ5SX+U9kAG7MuSPuPu10r6kaRPpTyeQXpM0k1pD2JILmg11K+TdKOk/ZLuSXdIYRDyA+LuL7j7mfbHb0iqSYo25N39LXf/d0lvpj2WQTGzKyRd4+6z7UOPSPpgikMaKHd/xt1/kvY4hsHd33D359sfL0k6JmlfuqMKg3LNEJjZbZJ2SXop7bGgL1dJOrPu89OKJAjwS2a2R9Ldku5IeywhEPJ9MLMfSJru8EdPu/u97XNmtFrD/WN3Xx7m+EJLMt/IRV2XhmRmOyU9LukL7n4i7fGEQMj3wd03rVeaWVXSYUmfdPfvDWdUg7PVfHPgjFZX82umdfHKHhlmZiVJ35D0grt/Pu3xhEJNfkDau00OS/p7d38i7fGgf+5+VtJJMzvYPnSfVncUIQ5fkbQg6RNpDySktPvJR8vM/lLSZyT5usMfdffvpzSkgTOzlyRVJV0maU7SX7v7I+mOKiwze5ekRyVdqtXv7b3u/r/pjmowzOxhSXdJeruk/5b0lLvfn+6oBsPMbpX0H5KOS1ppH/6qux9Kb1RhEPIAEDHKNQAQMUIeACJGyANAxAh5AIgYIQ8AESPkASBihDwAROz/AcjAr+n2l9c5AAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"np.random.seed(39)\n",
"\n",
"def build_toy_dataset(N, w, b, noise_std=2.0):\n",
"    \"\"\"Generate a noisy 1-D polynomial regression dataset.\n",
"\n",
"    Inputs x are N standard-normal draws sorted in ascending order; targets are\n",
"    y = sum_{i=1}^{D} w[i-1] * x**i + b + eps, with eps ~ Normal(0, noise_std**2)\n",
"    and D = len(w). Draws from NumPy's global random state.\n",
"\n",
"    Returns:\n",
"        (x, y): two float arrays of shape (N,).\n",
"    \"\"\"\n",
"    D = len(w)\n",
"    x = np.sort(np.random.randn(N))\n",
"    # Column i-1 holds x**i for i = 1..D (no constant column; the bias b is added separately).\n",
"    xx = x[:, np.newaxis] ** np.arange(1, D + 1)\n",
"    y = np.dot(xx, w) + b + np.random.normal(0, noise_std, size=N)\n",
"    return x, y\n",
"\n",
"N = 30  # number of data points\n",
"D = 3  # number of polynomial weights of the true curve\n",
"\n",
"# True parameters used to generate the data\n",
"w_true = 5*np.random.randn(D)\n",
"b_true = 0.5\n",
"print(\"True parameters\")\n",
"print(\"w={}\".format(w_true))\n",
"print(\"b={}\".format(b_true))\n",
"X_train, y_train = build_toy_dataset(N, w_true, b_true)\n",
"\n",
"# Plot the generated data; the trailing semicolon suppresses the return-value repr.\n",
"fig, ax = plt.subplots()\n",
"ax.scatter(X_train, y_train)\n",
"ax.set(xlabel=\"x\", ylabel=\"y\", title=\"Toy polynomial dataset\");"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"ベイズ線形回帰モデルで多項式あてはめを行う。\n",
"\n",
"以下のようにモデル化を行う。\n",
"\n",
"$\\mathbf{w}$の事前分布: $$p(\\mathbf{w}) = \\mathrm{Normal}(\\mathbf{w} \\mid \\mathbf{0}, \\sigma_w^2 \\mathbf{I})$$\n",
"\n",
"$b$の事前分布: $$p(b) = \\mathrm{Normal}(b \\mid 0, \\sigma_b^2)$$\n",
"\n",
"尤度: 各観測値$y_n$は平均$\\sum_{j=1}^D w_j x_n^j+b$、分散$\\sigma_y^2$の正規分布からサンプルされたと考え、各サンプルが独立に得られたと仮定すれば、サンプルの同時確率(尤度)は以下のようになる。 $$ p(\\mathbf{y} \\mid \\mathbf{w}, b, \\mathbf{X}) = \\prod_{n=1}^N \\mathrm{Normal}(y_n \\mid \\sum_{j=1}^D w_j x_n^j+b, \\sigma_y^2) $$\n",
"\n",
"ここで、ハイパーパラメータ$\\sigma_w^2, \\sigma_b^2, \\sigma_y^2$は既知であるとする。$\\sigma_w^2=1, \\sigma_b^2=1, \\sigma_y^2=4$(ノイズの標準偏差$\\sigma_y=2$。データ生成時の`noise_std=2.0`と同じ値)としてモデルを作ろう。"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow_probability as tfp\n",
"import tensorflow as tf\n",
"from tensorflow_probability import edward2 as ed\n",
"\n",
"def poly_model(MD, X, noise_scale=2.0):\n",
"    \"\"\"Bayesian polynomial regression model written with Edward2.\n",
"\n",
"    Priors:     w ~ Normal(0, 1) elementwise (MD weights),  b ~ Normal(0, 1)\n",
"    Likelihood: y_n ~ Normal(mean=sum_j X[n, j] * w[j] + b, std=noise_scale)\n",
"\n",
"    Args:\n",
"        MD: model dimension, i.e. the number of polynomial weights (columns of X).\n",
"        X: design matrix of shape (num_samples, MD), typically with row n equal to\n",
"           (x_n, x_n**2, ..., x_n**MD). Cast to float32 to match the priors.\n",
"        noise_scale: standard deviation of the observation noise (default 2.0).\n",
"\n",
"    Returns:\n",
"        The Edward2 random variable `y`, with one entry per row of X.\n",
"    \"\"\"\n",
"    X = tf.cast(X, tf.float32)\n",
"    w = ed.Normal(loc=tf.zeros(MD), scale=tf.ones(MD), name=\"w\")\n",
"    b = ed.Normal(loc=tf.zeros(1), scale=tf.ones(1), name=\"b\")\n",
"    # Sum over the feature axis only, giving one mean per sample; reducing over all\n",
"    # axes would collapse every prediction into a single scalar.\n",
"    mean = tf.reduce_sum(tf.multiply(X, w), axis=1) + b\n",
"    # Noise scale shaped like the mean, so the model does not depend on a global N.\n",
"    y = ed.Normal(loc=mean, scale=noise_scale * tf.ones_like(mean), name=\"y\")\n",
"    return y"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[17.73636 14.759049 16.072546 15.099797 14.31157 13.033747 13.001117\n",
" 19.633028 18.88944 15.825086 16.908482 14.417656 14.401501 15.917233\n",
" 18.199106 20.222195 14.134189 14.798289 20.22827 16.089785 18.989973\n",
" 15.795736 16.45015 18.50302 17.336428 14.673143 17.309721 15.50599\n",
" 19.046062 16.782604]\n"
]
}
],
"source": [
"MD = 3  # number of polynomial weights in the model\n",
"# Design matrix of shape (N, MD): column k-1 holds X_train**k for k = 1..MD, stored as float32 for TensorFlow.\n",
"XX_train = np.stack([X_train ** power for power in range(1, MD + 1)], axis=1).astype(np.float32)\n",
"\n",
"# One draw of y from the model, with w and b sampled from their priors (nothing is observed here).\n",
"y = poly_model(MD=MD, X=XX_train)\n",
"with tf.Session() as sess:\n",
"    print(sess.run(y))"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"# Bayesian inference with Hamiltonian Monte Carlo\n",
"T = 10000  # number of samples to keep\n",
"nburn = 100  # number of burn-in steps discarded before keeping samples\n",
"stride = 10  # frequency with which to plot samples (currently unused)\n",
"\n",
"# Observed targets, cast to float32 to match the model's dtype.\n",
"y_obs = y_train.astype(np.float32)\n",
"\n",
"# Initial chain state for the latent variables only (w and b), started at the prior mean.\n",
"qw = tf.zeros([MD])\n",
"qb = tf.zeros([1])\n",
"\n",
"log_joint = ed.make_log_joint_fn(poly_model)\n",
"\n",
"def target_log_prob_fn(w, b):\n",
"    \"\"\"Unnormalized log posterior of (w, b): the log joint with y clamped to the observed data.\"\"\"\n",
"    return log_joint(MD, XX_train, w=w, b=b, y=y_obs)\n",
"\n",
"hmc_kernel = tfp.mcmc.HamiltonianMonteCarlo(\n",
"    target_log_prob_fn=target_log_prob_fn,\n",
"    step_size=0.01,\n",
"    num_leapfrog_steps=5)\n",
"states, kernels_results = tfp.mcmc.sample_chain(\n",
"    num_results=T,\n",
"    current_state=[qw, qb],\n",
"    kernel=hmc_kernel,\n",
"    num_burnin_steps=nburn)\n",
"# NOTE: despite the avg_ prefix these hold one sample per kept step, not averages.\n",
"avg_w, avg_b = states\n",
"# Noise-free regression curve at every training input for each posterior sample, shape (T, N).\n",
"avg_y = tf.matmul(avg_w, XX_train, transpose_b=True) + avg_b"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Evaluate the sample tensors once. The NumPy results get new names so that avg_w/avg_b/avg_y\n",
"# remain tensors and this cell can be re-run safely.\n",
"with tf.Session() as sess:\n",
"    w_samples, b_samples, y_samples = sess.run([avg_w, avg_b, avg_y])"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.266464\n"
]
}
],
"source": [
"# Means of the MCMC draws, for comparison with the true parameters printed earlier\n",
"print(\"w={}\".format(w_samples.mean(axis=0)))\n",
"print(\"b={}\".format(b_samples.mean()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
},
"toc": {
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"toc_cell": false,
"toc_position": {},
"toc_section_display": "block",
"toc_window_display": false
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment