Created
May 27, 2019 14:43
How to set priors on GPy's GPRegression and estimate the hyperparameters with MCMC
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Notes on GPy (GPRegression)\n", | |
| "- Tips for using GPRegression in a Bayesian way\n", | |
| " - How to set priors on the hyperparameters\n", | |
| " - How to estimate the hyperparameters with MCMC\n", | |
| " - Various ways to plot the results" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# 0. Install\n", | |
| "- (As of May 2019) running \"pip install GPy\" on Python 3.7 can fail with an error.\n", | |
| " - It is safer to set up a separate Python 3.6 environment.\n", | |
| "- In addition, when GPy was installed via pip, model.plot() sometimes did not work.\n", | |
| " - \"conda install -c conda-forge gpy\" is recommended." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import matplotlib.pyplot as plt\n", | |
| "import numpy as np\n", | |
| "import pandas as pd\n", | |
| "import GPy" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "plt.style.use('ggplot')\n", | |
| "np.random.seed(0)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# 1. Data\n", | |
| "- Uses the same data as the previous notebook, \"GP&TPonSinWave\"." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "X = np.arange(-5,5,0.5)\n", | |
| "truY = np.sin(X)\n", | |
| "Y = truY + np.random.normal(loc=0, scale=0.1, size=X.shape[0])\n", | |
| "X_new = np.arange(-7,7,0.05)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 432x288 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "fig, ax = plt.subplots()\n", | |
| "ax.plot(X, Y, color=\"C1\", label=\"observed\", marker=\"o\", linestyle=\"\")\n", | |
| "ax.plot(X, truY, color=\"C0\", label=\"true\")\n", | |
| "ax.legend(edgecolor=\"b\")\n", | |
| "fig.tight_layout()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
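| "# 2. Model\n", | |
| "- We work with GPRegression, which is probably the model used most often.\n", | |
| " - Unlike the base GP class, it comes with Gaussian observation noise (\"Gaussian_noise\") set by default.\n", | |
| " - To define a model that is not provided (Poisson regression, etc.), you have to build it from GP. An example of Poisson regression is given at <http://statmodeling.hatenablog.com/entry/how-to-use-GPy>." | |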
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| " /anaconda3/envs/py36/lib/python3.6/site-packages/matplotlib/figure.py:2369: UserWarning:This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 432x288 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "kern = GPy.kern.Matern52(input_dim=1, ARD=True)\n", | |
| "gpmodel = GPy.models.GPRegression(X.reshape((-1,1)), Y.reshape((-1,1)), kernel=kern, normalizer=True)\n", | |
| "gpmodel.randomize()\n", | |
| "gpmodel.plot()\n", | |
| "plt.plot(X_new, np.sin(X_new), color=\"r\", linestyle=\"dotted\", alpha=0.7, linewidth=2)\n", | |
| "plt.tight_layout()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
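| "1. Set the input dimension (input_dim) to 1. ARD is meaningless in one dimension, but with multi-dimensional input a separate lengthscale is assigned to each dimension.\n", | |
| "2. Define the model. Note that the data must be passed as 2-D arrays (ndim=2).\n", | |
| "3. Initialize the hyperparameters with random values." | |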
| ] | |
| }, | |
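To make the ARD remark in item 1 concrete, here is a minimal pure-Python sketch of a Matérn 5/2 covariance with one lengthscale per input dimension. The function name `matern52` and its arguments are illustrative assumptions, not GPy's API; the formula is the standard Matérn 5/2 form, which I assume matches what `GPy.kern.Matern52` computes.

```python
import math

def matern52(x, y, variance=1.0, lengthscales=(1.0,)):
    """Matern 5/2 covariance between two points given as sequences of floats."""
    # ARD: every input dimension is divided by its own lengthscale.
    r = math.sqrt(sum(((a - b) / l) ** 2 for a, b, l in zip(x, y, lengthscales)))
    s = math.sqrt(5.0) * r
    return variance * (1.0 + s + s * s / 3.0) * math.exp(-s)

# A long lengthscale in the second dimension makes that input nearly irrelevant:
# moving one unit along it barely lowers the covariance.
k_dim1 = matern52((0.0, 0.0), (1.0, 0.0), lengthscales=(1.0, 100.0))
k_dim2 = matern52((0.0, 0.0), (0.0, 1.0), lengthscales=(1.0, 100.0))
print(k_dim1, k_dim2)
```

With a single input dimension, as in this notebook, there is only one lengthscale, so ARD=True changes nothing.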
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<style type=\"text/css\">\n", | |
| ".pd{\n", | |
| " font-family: \"Courier New\", Courier, monospace !important;\n", | |
| " width: 100%;\n", | |
| " padding: 3px;\n", | |
| "}\n", | |
| "</style>\n", | |
| "\n", | |
| "<p class=pd>\n", | |
| "<b>Model</b>: GP regression<br>\n", | |
| "<b>Objective</b>: 28.045667621230827<br>\n", | |
| "<b>Number of Parameters</b>: 3<br>\n", | |
| "<b>Number of Optimization Parameters</b>: 3<br>\n", | |
| "<b>Updates</b>: True<br>\n", | |
| "</p>\n", | |
| "<style type=\"text/css\">\n", | |
| ".tg {font-family:\"Courier New\", Courier, monospace !important;padding:2px 3px;word-break:normal;border-collapse:collapse;border-spacing:0;border-color:#DCDCDC;margin:0px auto;width:100%;}\n", | |
| ".tg td{font-family:\"Courier New\", Courier, monospace !important;font-weight:bold;color:#444;background-color:#F7FDFA;border-style:solid;border-width:1px;overflow:hidden;word-break:normal;border-color:#DCDCDC;}\n", | |
| ".tg th{font-family:\"Courier New\", Courier, monospace !important;font-weight:normal;color:#fff;background-color:#26ADE4;border-style:solid;border-width:1px;overflow:hidden;word-break:normal;border-color:#DCDCDC;}\n", | |
| ".tg .tg-left{font-family:\"Courier New\", Courier, monospace !important;font-weight:normal;text-align:left;}\n", | |
| ".tg .tg-center{font-family:\"Courier New\", Courier, monospace !important;font-weight:normal;text-align:center;}\n", | |
| ".tg .tg-right{font-family:\"Courier New\", Courier, monospace !important;font-weight:normal;text-align:right;}\n", | |
| "</style>\n", | |
| "<table class=\"tg\"><tr><th><b> GP_regression. </b></th><th><b> value</b></th><th><b>constraints</b></th><th><b>priors</b></th></tr>\n", | |
| "<tr><td class=tg-left> Mat52.variance </td><td class=tg-right>0.07496698473222704</td><td class=tg-center> +ve </td><td class=tg-center> </td></tr>\n", | |
| "<tr><td class=tg-left> Mat52.lengthscale </td><td class=tg-right> 1.0724342655855537</td><td class=tg-center> +ve </td><td class=tg-center> </td></tr>\n", | |
| "<tr><td class=tg-left> Gaussian_noise.variance</td><td class=tg-right> 1.2160004114647829</td><td class=tg-center> +ve </td><td class=tg-center> </td></tr>\n", | |
| "</table>" | |
| ], | |
| "text/plain": [ | |
| "<GPy.models.gp_regression.GPRegression at 0x1a19249550>" | |
| ] | |
| }, | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "gpmodel" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# 3. Priors\n", | |
| "- To set a prior on a hyperparameter:\n", | |
| " 1. Choose a prior distribution from GPy.priors.\n", | |
| " 2. Assign it to a specific hyperparameter with set_prior()." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "norm = GPy.priors.Gaussian(mu=0, sigma=1)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "We want to set the Gaussian prior defined here on the variance of the observation noise." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Gaussian_noise.variance\n" | |
| ] | |
| }, | |
| { | |
| "ename": "AssertionError", | |
| "evalue": "Domain of prior and constraint have to match, please unconstrain if you REALLY wish to use this prior", | |
| "output_type": "error", | |
| "traceback": [ | |
| "---------------------------------------------------------------------------", | |
| "AssertionError                            Traceback (most recent call last)", | |
| "<ipython-input-8-0f073bde4163> in <module>\n      1 param_name = gpmodel.parameter_names()\n      2 print(param_name[2])\n----> 3 gpmodel[param_name[2]].set_prior(norm)\n", | |
| "/anaconda3/envs/py36/lib/python3.6/site-packages/GPy/core/parameterization/priorizable.py in set_prior(self, prior, warning)\n     39         elif prior.domain is _REAL:\n     40             rav_i = self._raveled_index()\n---> 41             assert all(all(False if c is __fixed__ else c.domain is _REAL for c in con) for con in self.constraints.properties_for(rav_i)), 'Domain of prior and constraint have to match, please unconstrain if you REALLY wish to use this prior'\n     42 \n     43     def unset_priors(self, *priors):\n", | |
| "AssertionError: Domain of prior and constraint have to match, please unconstrain if you REALLY wish to use this prior" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "param_name = gpmodel.parameter_names()\n", | |
| "print(param_name[2])\n", | |
| "gpmodel[param_name[2]].set_prior(norm)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
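| "- An error was raised because the domain of the prior does not match the constraint.\n", | |
| " - Looking again at the constraints column of gpmodel, it shows \"+ve\", i.e. the parameter is constrained to positive real values.\n", | |
| " - Checking the domain on the prior side, on the other hand..." | |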
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "'real'" | |
| ] | |
| }, | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "norm.domain" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "All we need to do is restrict this to the positive reals." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "reconstraining parameters GP_regression.Gaussian_noise.variance\n", | |
| "reconstraining parameters GP_regression.Gaussian_noise.variance\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<style type=\"text/css\">\n", | |
| ".pd{\n", | |
| " font-family: \"Courier New\", Courier, monospace !important;\n", | |
| " width: 100%;\n", | |
| " padding: 3px;\n", | |
| "}\n", | |
| "</style>\n", | |
| "\n", | |
| "<p class=pd>\n", | |
| "<b>Model</b>: GP regression<br>\n", | |
| "<b>Objective</b>: 30.05549886738204<br>\n", | |
| "<b>Number of Parameters</b>: 3<br>\n", | |
| "<b>Number of Optimization Parameters</b>: 3<br>\n", | |
| "<b>Updates</b>: True<br>\n", | |
| "</p>\n", | |
| "<style type=\"text/css\">\n", | |
| ".tg {font-family:\"Courier New\", Courier, monospace !important;padding:2px 3px;word-break:normal;border-collapse:collapse;border-spacing:0;border-color:#DCDCDC;margin:0px auto;width:100%;}\n", | |
| ".tg td{font-family:\"Courier New\", Courier, monospace !important;font-weight:bold;color:#444;background-color:#F7FDFA;border-style:solid;border-width:1px;overflow:hidden;word-break:normal;border-color:#DCDCDC;}\n", | |
| ".tg th{font-family:\"Courier New\", Courier, monospace !important;font-weight:normal;color:#fff;background-color:#26ADE4;border-style:solid;border-width:1px;overflow:hidden;word-break:normal;border-color:#DCDCDC;}\n", | |
| ".tg .tg-left{font-family:\"Courier New\", Courier, monospace !important;font-weight:normal;text-align:left;}\n", | |
| ".tg .tg-center{font-family:\"Courier New\", Courier, monospace !important;font-weight:normal;text-align:center;}\n", | |
| ".tg .tg-right{font-family:\"Courier New\", Courier, monospace !important;font-weight:normal;text-align:right;}\n", | |
| "</style>\n", | |
| "<table class=\"tg\"><tr><th><b> GP_regression. </b></th><th><b> value</b></th><th><b>constraints</b></th><th><b>priors </b></th></tr>\n", | |
| "<tr><td class=tg-left> Mat52.variance </td><td class=tg-right>0.07496698473222704</td><td class=tg-center> +ve </td><td class=tg-center> </td></tr>\n", | |
| "<tr><td class=tg-left> Mat52.lengthscale </td><td class=tg-right> 1.0724342655855537</td><td class=tg-center> +ve </td><td class=tg-center> </td></tr>\n", | |
| "<tr><td class=tg-left> Gaussian_noise.variance</td><td class=tg-right> 1.2160004114647829</td><td class=tg-center> +ve </td><td class=tg-center>N(0, 1)</td></tr>\n", | |
| "</table>" | |
| ], | |
| "text/plain": [ | |
| "<GPy.models.gp_regression.GPRegression at 0x1a19249550>" | |
| ] | |
| }, | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "norm.domain = \"positive\"\n", | |
| "gpmodel[param_name[2]].set_prior(norm)\n", | |
| "gpmodel" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "- This should correspond to a half-normal prior (I think).\n", | |
| " - That is too strong, so we want to switch to a half-t prior." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "ename": "TypeError", | |
| "evalue": "object() takes no parameters", | |
| "output_type": "error", | |
| "traceback": [ | |
| "---------------------------------------------------------------------------", | |
| "TypeError                                 Traceback (most recent call last)", | |
| "<ipython-input-11-4841f2918b84> in <module>\n----> 1 halft = GPy.priors.HalfT(nu=4, A=1)\n", | |
| "/anaconda3/envs/py36/lib/python3.6/site-packages/GPy/core/parameterization/priors.py in __new__(cls, A, nu)\n   1238                 if instance().A == A and instance().nu == nu:\n   1239                     return instance()\n-> 1240         o = super(Prior, cls).__new__(cls, A, nu)\n   1241         cls._instances.append(weakref.ref(o))\n   1242         return cls._instances[-1]()\n", | |
| "TypeError: object() takes no parameters" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "halft = GPy.priors.HalfT(nu=4, A=1)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
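| "- I tried several times, but HalfT and Inverse_Gamma raise an error when you try to instantiate them.\n", | |
| " - A problem on the GPy side? The traceback suggests so: __new__ forwards A and nu to object.__new__(), which Python 3 rejects.\n", | |
| "- As a workaround, we substitute a StudentT prior with its domain overwritten." | |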
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "reconstraining parameters GP_regression.Gaussian_noise.variance\n", | |
| "reconstraining parameters GP_regression.Gaussian_noise.variance\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<style type=\"text/css\">\n", | |
| ".pd{\n", | |
| " font-family: \"Courier New\", Courier, monospace !important;\n", | |
| " width: 100%;\n", | |
| " padding: 3px;\n", | |
| "}\n", | |
| "</style>\n", | |
| "\n", | |
| "<p class=pd>\n", | |
| "<b>Model</b>: GP regression<br>\n", | |
| "<b>Objective</b>: 32.25080493682037<br>\n", | |
| "<b>Number of Parameters</b>: 3<br>\n", | |
| "<b>Number of Optimization Parameters</b>: 3<br>\n", | |
| "<b>Updates</b>: True<br>\n", | |
| "</p>\n", | |
| "<style type=\"text/css\">\n", | |
| ".tg {font-family:\"Courier New\", Courier, monospace !important;padding:2px 3px;word-break:normal;border-collapse:collapse;border-spacing:0;border-color:#DCDCDC;margin:0px auto;width:100%;}\n", | |
| ".tg td{font-family:\"Courier New\", Courier, monospace !important;font-weight:bold;color:#444;background-color:#F7FDFA;border-style:solid;border-width:1px;overflow:hidden;word-break:normal;border-color:#DCDCDC;}\n", | |
| ".tg th{font-family:\"Courier New\", Courier, monospace !important;font-weight:normal;color:#fff;background-color:#26ADE4;border-style:solid;border-width:1px;overflow:hidden;word-break:normal;border-color:#DCDCDC;}\n", | |
| ".tg .tg-left{font-family:\"Courier New\", Courier, monospace !important;font-weight:normal;text-align:left;}\n", | |
| ".tg .tg-center{font-family:\"Courier New\", Courier, monospace !important;font-weight:normal;text-align:center;}\n", | |
| ".tg .tg-right{font-family:\"Courier New\", Courier, monospace !important;font-weight:normal;text-align:right;}\n", | |
| "</style>\n", | |
| "<table class=\"tg\"><tr><th><b> GP_regression. </b></th><th><b> value</b></th><th><b>constraints</b></th><th><b> priors </b></th></tr>\n", | |
| "<tr><td class=tg-left> Mat52.variance </td><td class=tg-right>0.07496698473222704</td><td class=tg-center> +ve </td><td class=tg-center> </td></tr>\n", | |
| "<tr><td class=tg-left> Mat52.lengthscale </td><td class=tg-right> 1.0724342655855537</td><td class=tg-center> +ve </td><td class=tg-center> </td></tr>\n", | |
| "<tr><td class=tg-left> Gaussian_noise.variance</td><td class=tg-right> 1.2160004114647829</td><td class=tg-center> +ve </td><td class=tg-center>St(0, 0.3, 4)</td></tr>\n", | |
| "</table>" | |
| ], | |
| "text/plain": [ | |
| "<GPy.models.gp_regression.GPRegression at 0x1a19249550>" | |
| ] | |
| }, | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "halft = GPy.priors.StudentT(mu=0, sigma=0.3, nu=4)\n", | |
| "halft.domain = \"positive\"\n", | |
| "gpmodel[param_name[2]].set_prior(halft)\n", | |
| "gpmodel" | |
| ] | |
| }, | |
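The reason overriding the domain works can be checked numerically. On x > 0 a half-t log-density is the ordinary zero-centred Student-t log-density plus the constant log 2, and a constant offset does not change what an MCMC sampler or an optimizer sees. The sketch below is plain Python with hypothetical helper names (`student_t_logpdf`, `half_t_logpdf`); it assumes the positive constraint keeps the parameter above zero, so the prior is never evaluated on the negative side.

```python
import math

def student_t_logpdf(x, mu=0.0, sigma=1.0, nu=4.0):
    """Log-density of a location-scale Student-t distribution."""
    z = (x - mu) / sigma
    return (math.lgamma((nu + 1.0) / 2.0) - math.lgamma(nu / 2.0)
            - 0.5 * math.log(nu * math.pi) - math.log(sigma)
            - (nu + 1.0) / 2.0 * math.log1p(z * z / nu))

def half_t_logpdf(x, sigma=1.0, nu=4.0):
    """Log-density of the half-t: the zero-centred t folded onto x >= 0."""
    if x < 0.0:
        return -math.inf
    return math.log(2.0) + student_t_logpdf(x, 0.0, sigma, nu)

# Same scale and degrees of freedom as the StudentT prior used in the notebook.
for x in (0.05, 0.3, 1.2):
    gap = half_t_logpdf(x, sigma=0.3) - student_t_logpdf(x, sigma=0.3)
    print(x, gap)  # the gap is log(2) at every positive x
```

The same argument applies to the Gaussian prior with its domain set to "positive": up to the constant log 2 it is a half-normal.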
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# 4. MCMC" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Same as described in \"GP&TPonSinWave\"." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 1min 12s, sys: 395 ms, total: 1min 13s\n", | |
| "Wall time: 36.8 s\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%time\n", | |
| "hmc = GPy.inference.mcmc.HMC(gpmodel) # sampler\n", | |
| "sample = hmc.sample()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
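| "- By default, 1000 samples are drawn (num_samples=1000).\n", | |
| "- A Metropolis-Hastings implementation is also available. Replica exchange (parallel tempering) is not." | |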
| ] | |
| }, | |
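For intuition about what the sampler is doing, here is a minimal one-dimensional HMC sketch in plain Python. It is not GPy's implementation: the function `hmc_1d`, its step size, and the number of leapfrog steps are illustrative assumptions, and the target is a toy standard normal rather than a GP posterior.

```python
import math
import random

def hmc_1d(log_prob, grad_log_prob, x0, n_samples=2000, step=0.2, n_leapfrog=8, seed=0):
    """Sample a 1-D density with Hamiltonian Monte Carlo (unit mass)."""
    rng = random.Random(seed)
    x, samples = x0, []
    for _ in range(n_samples):
        p = rng.gauss(0.0, 1.0)  # draw a fresh momentum
        x_new, p_new = x, p
        # Leapfrog integration: half momentum step, full steps, half momentum step.
        p_new += 0.5 * step * grad_log_prob(x_new)
        for i in range(n_leapfrog):
            x_new += step * p_new
            if i < n_leapfrog - 1:
                p_new += step * grad_log_prob(x_new)
        p_new += 0.5 * step * grad_log_prob(x_new)
        # Metropolis correction based on the change in total energy.
        h_old = -log_prob(x) + 0.5 * p * p
        h_new = -log_prob(x_new) + 0.5 * p_new * p_new
        if rng.random() < math.exp(min(0.0, h_old - h_new)):
            x = x_new
        samples.append(x)
    return samples

# Toy target: standard normal, log p(x) = -x^2 / 2 up to a constant.
draws = hmc_1d(lambda x: -0.5 * x * x, lambda x: -x, x0=3.0)
mean = sum(draws) / len(draws)
var = sum((d - mean) ** 2 for d in draws) / len(draws)
print(mean, var)
```

As far as I understand, GPy's HMC applies the same idea jointly to all hyperparameters of the model, using the model's log posterior and its gradient.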
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The distribution of the resulting MCMC samples is shown below." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 720x288 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "import seaborn as sns\n", | |
| "\n", | |
| "fig, ax = plt.subplots(figsize=(10,4))\n", | |
| "ax.set_yscale('log')\n", | |
| "sns.boxenplot(data=sample, ax=ax)\n", | |
| "ax.set_xticklabels(param_name)\n", | |
| "fig.tight_layout()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Plugging in the EAP (posterior mean) estimates gives the following result." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| " /anaconda3/envs/py36/lib/python3.6/site-packages/matplotlib/figure.py:2369: UserWarning:This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 432x288 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "eap = np.mean(sample, axis=0)  # posterior mean (EAP) of each hyperparameter\n", | |
| "for ii, name in enumerate(param_name):\n", | |
| "    gpmodel[name] = eap[ii]\n", | |
| "gpmodel.plot()\n", | |
| "plt.plot(X_new, np.sin(X_new), color=\"r\", linestyle=\"dotted\", alpha=0.7, linewidth=2)\n", | |
| "plt.tight_layout()" | |
| ] | |
| }, | |
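The EAP plug-in step above is just a column-wise mean of the MCMC draws. The sketch below spells that out in plain Python and adds an optional burn-in, which the notebook does not use; `eap_estimate`, `burn_in`, and the toy numbers are hypothetical and only serve as illustration.

```python
def eap_estimate(samples, burn_in=0):
    """Column-wise mean of MCMC draws after discarding the first burn_in rows."""
    kept = samples[burn_in:]
    if not kept:
        raise ValueError("burn_in discards every sample")
    n_params = len(kept[0])
    return [sum(row[j] for row in kept) / len(kept) for j in range(n_params)]

# Hypothetical draws: rows are iterations, columns are hyperparameters.
toy_samples = [
    [5.0, 9.0, 1.00],  # early draw far from the bulk of the chain
    [1.0, 2.0, 0.02],
    [1.2, 2.2, 0.01],
    [0.8, 1.8, 0.03],
]
print(eap_estimate(toy_samples, burn_in=1))
```

Note that plugging point estimates into the model ignores the remaining uncertainty in the hyperparameters; averaging predictions over the draws would keep it.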
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# 5. plot" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
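| "## 5.1. plot()\n", | |
| "- The standard plot shows the 95% prediction interval by default. To change it, specify the lower and upper arguments (percentiles, 0 to 100).\n", | |
| "- If the input is two-dimensional, projection=\"3d\" produces a 3-D plot." | |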
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| " /anaconda3/envs/py36/lib/python3.6/site-packages/matplotlib/figure.py:2369: UserWarning:This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 432x288 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "gpmodel.plot(lower=5, upper=95)\n", | |
| "plt.plot(X_new, np.sin(X_new), color=\"r\", linestyle=\"dotted\", alpha=0.7, linewidth=2)\n", | |
| "plt.title('90% prediction interval')\n", | |
| "plt.tight_layout()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "See <https://gpy.readthedocs.io/en/deploy/GPy.plotting.gpy_plot.html#module-GPy.plotting.gpy_plot.gp_plots> for the other arguments.\n", | |
| "\n", | |
| "Functions for plotting individual elements separately are also provided.  \n", | |
| "For example, for the prediction interval:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| " /anaconda3/envs/py36/lib/python3.6/site-packages/matplotlib/figure.py:2369: UserWarning:This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 432x288 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "gpmodel.plot_confidence(lower=5, upper=95)\n", | |
| "plt.tight_layout()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
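| "## 5.2. posterior_samples_f\n", | |
| "- A function that samples functions (as realizations) from the GP posterior.\n", | |
| " - There is also a confusingly named \"posterior_samples\", which appears to call \"posterior_samples_f\" internally and, as far as I can tell, then passes the draws through the likelihood to add observation noise." | |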
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 22, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(-1.6, 1.6)" | |
| ] | |
| }, | |
| "execution_count": 22, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| }, | |
| { | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 432x288 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "X_new = np.arange(-7,7,0.1)\n", | |
| "\n", | |
| "fig, ax1 = plt.subplots(1, 1)\n", | |
| "ax1.plot(X_new, gpmodel.posterior_samples_f(X_new.reshape((-1,1)), size=10).reshape(-1,10), color=\"b\", alpha=0.1)\n", | |
| "#ax1.plot(X, Y , marker=\"o\", color=\"r\", linestyle=\"\")\n", | |
| "ax1.plot(X_new, np.sin(X_new), color=\"r\", linestyle=\"dotted\", alpha=0.7)\n", | |
| "ax1.set_title('GP posterior')\n", | |
| "ax1.set_ylim([-1.6, 1.6])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.6.8" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |