Last active
July 4, 2025 03:25
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/kn1kn1/cfbc71eedb1ba16a92a9eab88379b9f3/basic_synthetic_continuous_basic_ja.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "Rkn_2ZmhPLvJ" | |
| }, | |
| "source": [ | |
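| "# Quickstart: Offline RL and Off-Policy Evaluation (Continuous Action Space)\n", | |
| "\n", | |
| "This notebook demonstrates off-policy evaluation (OPE) and offline reinforcement learning (offline RL) on a simple synthetic dataset.\n", | |
| "\n", | |
| "Specifically, we proceed in the following five steps:\n", | |
| "\n", | |
| "0. (Setting up the simulation environment and online RL)\n", | |
| "1. Synthetic data generation\n", | |
| "2. Offline RL\n", | |
| "3. Off-policy evaluation and evaluation of offline RL methods\n", | |
| "4. Evaluation of OPE methods\n", | |
| "\n", | |
| "SCOPE-RL uses algorithms from [d3rlpy](https://github.com/takuseno/d3rlpy) for online and offline policy learning and for some of the model-based OPE estimators.\n", | |
| "Its implementation workflow is also modeled after [Open Bandit Pipeline](https://github.com/st-tech/zr-obp)." | |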
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "!pip install numpy==1.26.4" | |
| ], | |
| "metadata": { | |
| "id": "5vI3yK45PNrV" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "!pip install scope-rl==0.2.1" | |
| ], | |
| "metadata": { | |
| "id": "FoMPmR3PPQsP" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "CBY7cMq_PLvM" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# import SCOPE-RL modules\n", | |
| "import scope_rl\n", | |
| "from basicgym import BasicEnv\n", | |
| "from scope_rl.dataset import SyntheticDataset\n", | |
| "from scope_rl.policy import OnlineHead, TruncatedGaussianHead\n", | |
| "from scope_rl.ope.online import (\n", | |
| " calc_on_policy_policy_value,\n", | |
| " visualize_on_policy_policy_value,\n", | |
| ")\n", | |
| "\n", | |
| "# import d3rlpy algorithms\n", | |
| "from d3rlpy.algos import DiscreteRandomPolicyConfig\n", | |
| "from d3rlpy.algos import RandomPolicyConfig as ContinuousRandomPolicyConfig\n", | |
| "from d3rlpy.preprocessing import MinMaxObservationScaler, MinMaxActionScaler\n", | |
| "\n", | |
| "# import other libraries\n", | |
| "import gym\n", | |
| "import torch" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "RiAUjzXNPLvN" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "import pickle\n", | |
| "from glob import glob\n", | |
| "from tqdm import tqdm\n", | |
| "\n", | |
| "import numpy as np\n", | |
| "import pandas as pd\n", | |
| "\n", | |
| "import matplotlib.pyplot as plt\n", | |
| "import seaborn as sns\n", | |
| "\n", | |
| "%matplotlib inline" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "PA9APYr7PLvO", | |
| "outputId": "57cceb2e-91a0-46f1-a7da-4a0f6c5c36af" | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "0.1.2\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# check the installed version\n", | |
| "print(scope_rl.__version__)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "fGyLXwjXPLvO" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# set the random state\n", | |
| "random_state = 12345" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "03tvlTYiPLvP" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "_WgP0YRMPLvP" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# create the log directory\n", | |
| "from pathlib import Path\n", | |
| "Path(\"logs/\").mkdir(exist_ok=True)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "3M1g11HOPLvP" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# suppress warnings\n", | |
| "import warnings\n", | |
| "warnings.simplefilter('ignore')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "1Ku02_W0PLvQ" | |
| }, | |
| "source": [ | |
| "\n", | |
| "## 0. (Setting up the simulation environment and online RL)\n", | |
| "First, we introduce the simple environment used in this notebook.\n", | |
| "\n", | |
| "#### Setting up the RL environment\n", | |
| "We consider the problem of maximizing the cumulative reward that a policy obtains in a simple simulation environment.\n", | |
| "\n", | |
| "We formulate this RL problem as a (partially observable) Markov decision process ((PO)MDP):\n", | |
| "- `state`: the state observation (noisy in the POMDP case).\n", | |
| "- `action`: the action chosen by the RL agent (policy).\n", | |
| "- `reward`: the reward observed for the given state and action.\n", | |
| "\n", | |
| "For more details on the environment's arguments, see [examples/quickstart_ja/basic/basic_synthetic_customize_env_ja.ipynb](https://github.com/hakuhodo-technologies/scope-rl/blob/main/examples/quickstart_ja/basic/basic_synthetic_customize_env_ja.ipynb).\n" | |
| ] | |
| }, | |
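| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The three (PO)MDP components above can be illustrated with a minimal sketch. The `ToyPOMDP` class below is hypothetical: it uses made-up linear-tanh dynamics, a linear reward, and Gaussian observation noise, and it does **not** reproduce `BasicEnv`'s actual dynamics. It only mirrors the `reset` / `step` interface used later in this notebook, and its variables are prefixed with `toy_` so they do not collide with the notebook's own." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "\n", | |
| "# Hypothetical toy (PO)MDP for illustration; NOT BasicEnv's actual dynamics.\n", | |
| "class ToyPOMDP:\n", | |
| "    def __init__(self, state_dim=5, action_dim=3, step_per_episode=10, noise_level=0.1, random_state=0):\n", | |
| "        assert state_dim >= action_dim\n", | |
| "        self.rng = np.random.default_rng(random_state)\n", | |
| "        self.state_dim, self.action_dim = state_dim, action_dim\n", | |
| "        self.step_per_episode = step_per_episode\n", | |
| "        self.noise_level = noise_level\n", | |
| "        # random parameters of the transition and reward functions\n", | |
| "        self.A = self.rng.normal(size=(state_dim, state_dim)) / state_dim\n", | |
| "        self.B = self.rng.normal(size=(state_dim, action_dim)) / state_dim\n", | |
| "        self.w = self.rng.normal(size=action_dim)\n", | |
| "\n", | |
| "    def _observe(self):\n", | |
| "        # state observation: the true state plus Gaussian noise (partial observability)\n", | |
| "        return self.state + self.noise_level * self.rng.normal(size=self.state_dim)\n", | |
| "\n", | |
| "    def reset(self):\n", | |
| "        self.t = 0\n", | |
| "        self.state = self.rng.uniform(-1.0, 1.0, size=self.state_dim)\n", | |
| "        return self._observe(), {}\n", | |
| "\n", | |
| "    def step(self, action):\n", | |
| "        # reward: depends on the current (true) state and the chosen action\n", | |
| "        reward = float(self.state[: self.action_dim] @ (self.w * action))\n", | |
| "        self.state = np.tanh(self.A @ self.state + self.B @ action)\n", | |
| "        self.t += 1\n", | |
| "        terminated = self.t >= self.step_per_episode\n", | |
| "        return self._observe(), reward, terminated, False, {}\n", | |
| "\n", | |
| "\n", | |
| "# roll out one episode with a uniformly random policy\n", | |
| "toy_env = ToyPOMDP()\n", | |
| "toy_obs, toy_info = toy_env.reset()\n", | |
| "toy_done, toy_rewards = False, []\n", | |
| "while not toy_done:\n", | |
| "    toy_action = toy_env.rng.uniform(-1.0, 1.0, size=toy_env.action_dim)\n", | |
| "    toy_obs, toy_reward, toy_done, toy_truncated, toy_info = toy_env.step(toy_action)\n", | |
| "    toy_rewards.append(toy_reward)\n", | |
| "print(len(toy_rewards), toy_obs.shape)  # -> 10 (5,)" | |
| ] | |
| }, | |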
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "N7B7uH6pPLvQ" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# set up the environment\n", | |
| "env = BasicEnv(random_state=random_state)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "R3-7NqhdPLvQ" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# define an agent that chooses actions at random\n", | |
| "agent = OnlineHead(\n", | |
| " ContinuousRandomPolicyConfig().create(device=device),\n", | |
| " name=\"random\",\n", | |
| ")\n", | |
| "agent.build_with_env(env)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "8tjnY9UsPLvR" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# interaction between the environment and the agent\n", | |
| "# the RL interaction loop takes only six lines\n", | |
| "for episode in range(10):\n", | |
| "    obs, info = env.reset()\n", | |
| "    done = False\n", | |
| "\n", | |
| "    while not done:\n", | |
| "        action = agent.predict_online(obs)\n", | |
| "        obs, reward, done, truncated, info = env.step(action)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "e7atAJYpPLvR", | |
| "outputId": "af1a2f55-f213-4e20-b2ba-a20c07f967e4" | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[-0.37164978 -0.49943402 0.36963097 -0.28399277 -0.62862005]\n", | |
| "(5,)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# state observation\n", | |
| "print(obs)\n", | |
| "print(obs.shape)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "5USFoow3PLvR", | |
| "outputId": "cda276f6-c96d-464a-e683-f5ea54c2d3d3" | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "image/png": "", | |
| "text/plain": [ | |
| "<Figure size 640x480 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "# visualize how the random agent's reward evolves over one episode\n", | |
| "obs, info = env.reset()\n", | |
| "done = False\n", | |
| "reward_list = []\n", | |
| "\n", | |
| "while not done:\n", | |
| " action = agent.sample_action_online(obs)\n", | |
| " obs, reward, done, truncated, info = env.step(action)\n", | |
| " reward_list.append(reward)\n", | |
| "\n", | |
| "# plot\n", | |
| "fig = plt.figure()\n", | |
| "ax1 = fig.add_subplot(111)\n", | |
| "ax1.plot(reward_list[:-1], label='reward', color='tab:orange')\n", | |
| "ax1.set_xlabel('timestep')\n", | |
| "ax1.set_ylabel('reward')\n", | |
| "ax1.legend(loc='upper left')\n", | |
| "plt.show()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "ctNCWd2aPLvR" | |
| }, | |
| "source": [ | |
| "#### Online RL and policy comparison\n", | |
| "First, we compare a policy learned by online RL with a random policy, based on online estimates of their policy values." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "x1InWkecPLvR", | |
| "outputId": "ed70b570-6e96-4e70-98c2-166962bf8750" | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Box(-1.0, 1.0, (3,), float64)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# standard continuous-action environment implemented in SCOPE-RL\n", | |
| "env = gym.make(\"BasicEnv-continuous-v0\")\n", | |
| "print(env.action_space)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "2pC95N_XPLvS" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from d3rlpy.algos import SACConfig\n", | |
| "from d3rlpy.models.encoders import VectorEncoderFactory\n", | |
| "from d3rlpy.models.q_functions import MeanQFunctionFactory\n", | |
| "from d3rlpy.dataset import create_fifo_replay_buffer\n", | |
| "\n", | |
| "# model (SAC)\n", | |
| "sac = SACConfig(\n", | |
| " actor_encoder_factory=VectorEncoderFactory(hidden_units=[30, 30]),\n", | |
| " critic_encoder_factory=VectorEncoderFactory(hidden_units=[30, 30]),\n", | |
| " q_func_factory=MeanQFunctionFactory(),\n", | |
| " action_scaler=MinMaxActionScaler(\n", | |
| " minimum=env.action_space.low,\n", | |
| " maximum=env.action_space.high,\n", | |
| " ),\n", | |
| ").create(device=device)\n", | |
| "\n", | |
| "# set up the replay buffer\n", | |
| "buffer = create_fifo_replay_buffer(\n", | |
| " limit=10000,\n", | |
| " env=env,\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "Rzn8TYwUPLvS" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# start training\n", | |
| "# skip this cell if you use a pretrained model\n", | |
| "sac.fit_online(\n", | |
| " env,\n", | |
| " buffer,\n", | |
| " eval_env=env,\n", | |
| " n_steps=100000,\n", | |
| " n_steps_per_epoch=1000,\n", | |
| " update_start_step=1000,\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "2vR9qQggPLvS" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# save the model\n", | |
| "sac.save_model(\"d3rlpy_logs/sac.pt\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "4x7OSi4APLvS" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# reload the model\n", | |
| "sac.build_with_env(env)\n", | |
| "sac.load_model(\"d3rlpy_logs/sac.pt\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "syKSgeDYPLvS", | |
| "outputId": "adfe405a-152d-4420-8822-d86ba298c0b2" | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "Text(0, 0.5, 'episode reward')" | |
| ] | |
| }, | |
| "execution_count": 17, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| }, | |
| { | |
| "data": { | |
| "image/png": "", | |
| "text/plain": [ | |
| "<Figure size 640x480 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "# learning curve of the model\n", | |
| "log_path = sorted(glob(\"d3rlpy_logs/SAC_online_*/evaluation.csv\"))[-1]\n", | |
| "df = pd.read_csv(\n", | |
| " log_path,\n", | |
| " usecols=[1,2],\n", | |
| " names=[\"train_step\", \"episodic_rewards\"]\n", | |
| ")\n", | |
| "plt.plot(df[\"train_step\"], df[\"episodic_rewards\"])\n", | |
| "plt.title(\"Learning curve of SAC\")\n", | |
| "plt.xlabel(\"training step\")\n", | |
| "plt.ylabel(\"episode reward\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "cRC3_WlAPLvS" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "random = ContinuousRandomPolicyConfig().create(device=device)\n", | |
| "random.build_with_env(env)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": { | |
| "referenced_widgets": [ | |
| "b380f9e17ff042e0af27ac842430000e", | |
| "9e88116c8ce942a988c8dd206e72ccf9" | |
| ] | |
| }, | |
| "id": "DYXBIVkPPLvS", | |
| "outputId": "ca85f4b5-52bc-4f46-d1dc-2145340c074f" | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "b380f9e17ff042e0af27ac842430000e", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| "[calculate on-policy policy value]: 0%| | 0/100 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "9e88116c8ce942a988c8dd206e72ccf9", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| "[calculate on-policy policy value]: 0%| | 0/100 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "image/png": "", | |
| "text/plain": [ | |
| "<Figure size 400x400 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "# オンラインでの性能比較\n", | |
| "visualize_on_policy_policy_value(\n", | |
| " env=env,\n", | |
| " policies=[sac, random],\n", | |
| " policy_names=[\"sac\", \"random\"],\n", | |
| " n_trajectories=100,\n", | |
| " random_state=random_state,\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": { | |
| "referenced_widgets": [ | |
| "2206752b1e194a04ab609bd3dda78f50", | |
| "f8fc4a2b5a5740ba8a72e4a7b44234a6" | |
| ] | |
| }, | |
| "id": "izvOXfu4PLvS", | |
| "outputId": "baf15c84-e914-4617-8e5b-08e2281f10a6" | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "2206752b1e194a04ab609bd3dda78f50", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| "[calculate on-policy policy value]: 0%| | 0/100 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "f8fc4a2b5a5740ba8a72e4a7b44234a6", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| "[calculate on-policy policy value]: 0%| | 0/100 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "improvement: 6.220823797617203\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# policy value of the SAC policy under online evaluation\n", | |
| "# substantially improves over the random policy\n", | |
| "sac_performance = calc_on_policy_policy_value(env, sac, n_trajectories=100, random_state=random_state)\n", | |
| "random_performance = calc_on_policy_policy_value(env, random, n_trajectories=100, random_state=random_state)\n", | |
| "print(\"improvement:\", sac_performance - random_performance)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "yIFMKrK3PLvS" | |
| }, | |
| "source": [ | |
| "`scope_rl.ope.online` provides the following functions for evaluating policy performance online.\n", | |
| "\n", | |
| "(Statistics)\n", | |
| "\n", | |
| "* `calc_on_policy_policy_value`\n", | |
| "* `calc_on_policy_policy_value_interval`\n", | |
| "* `calc_on_policy_variance`\n", | |
| "* `calc_on_policy_conditional_value_at_risk`\n", | |
| "* `calc_on_policy_policy_interquartile_range`\n", | |
| "* `calc_on_policy_cumulative_distribution_function`\n", | |
| "\n", | |
| "(Visualization)\n", | |
| "* `visualize_on_policy_policy_value`\n", | |
| "* `visualize_on_policy_policy_value_with_variance`\n", | |
| "* `visualize_on_policy_cumulative_distribution_function`\n", | |
| "* `visualize_on_policy_conditional_value_at_risk`\n", | |
| "* `visualize_on_policy_interquartile_range`" | |
| ] | |
| }, | |
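| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "As a rough guide to what these statistics measure, the sketch below computes standard textbook versions of them with NumPy from a hypothetical sample of trajectory returns (one return per on-policy rollout). The sample, the bootstrap-based 95% interval, and the lower-tail convention for CVaR are assumptions; SCOPE-RL's own functions may use different conventions." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "\n", | |
| "# hypothetical sample of trajectory returns from 100 on-policy rollouts\n", | |
| "rng = np.random.default_rng(12345)\n", | |
| "sample_returns = rng.normal(loc=5.0, scale=1.0, size=100)\n", | |
| "\n", | |
| "# policy value (mean return) and variance\n", | |
| "value = sample_returns.mean()\n", | |
| "variance = sample_returns.var(ddof=1)\n", | |
| "\n", | |
| "# 95% bootstrap confidence interval of the policy value\n", | |
| "boot_means = np.array([\n", | |
| "    rng.choice(sample_returns, size=sample_returns.size, replace=True).mean()\n", | |
| "    for _ in range(1000)\n", | |
| "])\n", | |
| "ci_lower, ci_upper = np.percentile(boot_means, [2.5, 97.5])\n", | |
| "\n", | |
| "# conditional value at risk: mean of the worst alpha-fraction of returns\n", | |
| "alpha = 0.1\n", | |
| "cvar = sample_returns[sample_returns <= np.quantile(sample_returns, alpha)].mean()\n", | |
| "\n", | |
| "# interquartile range\n", | |
| "q25, q75 = np.quantile(sample_returns, [0.25, 0.75])\n", | |
| "\n", | |
| "# empirical cumulative distribution function on a grid of thresholds\n", | |
| "thresholds = np.linspace(sample_returns.min(), sample_returns.max(), 20)\n", | |
| "cdf = (sample_returns[None, :] <= thresholds[:, None]).mean(axis=1)\n", | |
| "\n", | |
| "print(f'value={value:.3f}, variance={variance:.3f}, CI=({ci_lower:.3f}, {ci_upper:.3f})')\n", | |
| "print(f'CVaR({alpha})={cvar:.3f}, IQR=({q25:.3f}, {q75:.3f}), CDF at max={cdf[-1]:.1f}')" | |
| ] | |
| }, | |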
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "q_LnVAUwPLvS" | |
| }, | |
| "source": [ | |
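| "## 1. Synthetic data generation\n", | |
| "In the previous section, we treated the simulation environment as an online environment and saw how to run online RL and online performance evaluation. Agents are easy to train in a simulation environment like this one, but collecting data and making sequential decisions online in a real environment carries risk. Offline RL therefore aims to learn a new policy using only the following offline logged data, collected by a behavior policy $\\pi_0$:\n", | |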
| "\n", | |
| "$$\\mathcal{D}_0 := \\{ \\{ s_t, a_t, s_{t+1}, r_t \\}_{t=1}^T \\}_{i=1}^n \\sim \\prod_{i=1}^n d_{s_0}(s_0) \\prod_{t=1}^T \\pi_0(a_t | s_t) p(s_{t+1} | s_t, a_t) p(r_t | s_t, a_t),$$\n", | |
| "\n", | |
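| "where $s_t \\in \\mathcal{S}$ is the state observed at timestep $t$, $a_t$ is the action chosen by the behavior policy $\\pi_0$, $r_t$ is the reward observed for that state and action, $n$ is the number of trajectories, and $T$ is the number of steps per trajectory.\n", | |
| "\n", | |
| "SCOPE-RL implements `SyntheticDataset`, a dataset module class for easily generating synthetic data. `SyntheticDataset` takes the following arguments:\n", | |
| "- `env`: the RL (simulation) environment.\n", | |
| "- `max_episode_steps`: the (maximum) number of sequential decisions in one episode." | |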
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "cm2ZbWPXPLvS" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "behavior_policy = TruncatedGaussianHead(\n", | |
| " sac,\n", | |
| " minimum=env.action_space.low,\n", | |
| " maximum=env.action_space.high,\n", | |
| " sigma=np.array([0.5]),\n", | |
| " name=\"sac_sigma_0.5\",\n", | |
| " random_state=random_state,\n", | |
| ")" | |
| ] | |
| }, | |
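| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "To make the role of `sigma` concrete, here is a minimal sketch of sampling from a truncated Gaussian around a base action and computing its density (the kind of quantity stored as `pscore`) with `scipy.stats.truncnorm`. It assumes each action dimension is perturbed independently and that the joint density is the product over dimensions; the base action is made up, and `TruncatedGaussianHead`'s internals (including how `pscore` is stored per dimension) may differ." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "from scipy.stats import truncnorm\n", | |
| "\n", | |
| "# assumptions: 3-dim actions bounded in [-1, 1], independent noise per dimension\n", | |
| "low, high = np.full(3, -1.0), np.full(3, 1.0)\n", | |
| "sigma = 0.5\n", | |
| "rng = np.random.default_rng(12345)\n", | |
| "\n", | |
| "base_action = np.array([0.8, -0.6, 0.2])  # hypothetical greedy action of the base (SAC) policy\n", | |
| "lower_z = (low - base_action) / sigma   # truncation bounds in standardized units\n", | |
| "upper_z = (high - base_action) / sigma\n", | |
| "dist = truncnorm(lower_z, upper_z, loc=base_action, scale=sigma)\n", | |
| "\n", | |
| "sampled_action = dist.rvs(random_state=rng)  # always inside [low, high]\n", | |
| "density_per_dim = dist.pdf(sampled_action)\n", | |
| "joint_density = density_per_dim.prod()       # product over dimensions under independence\n", | |
| "\n", | |
| "assert np.all((low <= sampled_action) & (sampled_action <= high))\n", | |
| "print(sampled_action, joint_density)" | |
| ] | |
| }, | |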
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "k6Dss8ikPLvT" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# initialize the dataset class\n", | |
| "dataset = SyntheticDataset(\n", | |
| " env=env,\n", | |
| " max_episode_steps=env.step_per_episode,\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "icfNsRjePLvT" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# skip this cell if the logged dataset already exists\n", | |
| "# collect logged data with the behavior policy\n", | |
| "train_logged_dataset = dataset.obtain_episodes(\n", | |
| " behavior_policies=behavior_policy,\n", | |
| " n_trajectories=10000,\n", | |
| " obtain_info=False,\n", | |
| " random_state=random_state,\n", | |
| ")\n", | |
| "test_logged_dataset = dataset.obtain_episodes(\n", | |
| " behavior_policies=behavior_policy,\n", | |
| " n_trajectories=10000,\n", | |
| " obtain_info=False,\n", | |
| " random_state=random_state + 1,\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "DQ-IinwCPLvT", | |
| "outputId": "e0996c6a-190e-4837-a549-3efacc3fea59" | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "{'size': 100000,\n", | |
| " 'n_trajectories': 10000,\n", | |
| " 'step_per_trajectory': 10,\n", | |
| " 'action_type': 'continuous',\n", | |
| " 'n_actions': None,\n", | |
| " 'action_dim': 3,\n", | |
| " 'action_meaning': None,\n", | |
| " 'action_keys': None,\n", | |
| " 'state_dim': 5,\n", | |
| " 'state_keys': None,\n", | |
| " 'state': array([[ 0.46692103, -0.60091272, 0.12748286, 0.10612129, 0.62719618],\n", | |
| " [ 0.10666174, -0.43413801, 0.69911049, -0.49153976, 0.26416026],\n", | |
| " [-0.09165767, 0.53838334, 0.57967616, 0.25195753, 0.54975922],\n", | |
| " ...,\n", | |
| " [-0.58171771, 0.6782399 , 0.07907731, 0.44068758, 0.03371026],\n", | |
| " [-0.03506825, -0.34642982, 0.70207422, -0.36781008, 0.50056381],\n", | |
| " [ 0.09089026, 0.70256162, 0.47892005, 0.20888019, 0.47450056]]),\n", | |
| " 'action': array([[ 0.63094648, -0.93446503, 0.56403605],\n", | |
| " [ 0.48463834, 0.22951301, -0.71212477],\n", | |
| " [ 0.90848962, -0.58848582, -0.62280191],\n", | |
| " ...,\n", | |
| " [ 0.633153 , -0.9542258 , -0.43497972],\n", | |
| " [ 0.42549025, 0.32503334, -0.6607296 ],\n", | |
| " [ 0.80997434, -0.30832653, -0.29731225]]),\n", | |
| " 'reward': array([0.51629582, 0.27206326, 0.4348537 , ..., 0.21085268, 0.30432125,\n", | |
| " 0.44587291]),\n", | |
| " 'done': array([0., 0., 0., ..., 0., 0., 1.]),\n", | |
| " 'terminal': array([0., 0., 0., ..., 0., 0., 1.]),\n", | |
| " 'info': {},\n", | |
| " 'pscore': array([[0.94034815, 0.94034815, 0.94034815],\n", | |
| " [0.89860991, 0.89860991, 0.89860991],\n", | |
| " [1.15576638, 1.15576638, 1.15576638],\n", | |
| " ...,\n", | |
| " [0.90518726, 0.90518726, 0.90518726],\n", | |
| " [0.69265362, 0.69265362, 0.69265362],\n", | |
| " [1.02172236, 1.02172236, 1.02172236]]),\n", | |
| " 'behavior_policy': 'sac_sigma_0.5',\n", | |
| " 'dataset_id': 0}" | |
| ] | |
| }, | |
| "execution_count": 24, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "train_logged_dataset" | |
| ] | |
| }, | |
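| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The logged dataset stores all transitions flattened along the first axis (`size` equals `n_trajectories * step_per_trajectory`). Assuming the transitions are stored trajectory by trajectory, which the `done` flags at the end of each block suggest, they can be reshaped into per-trajectory arrays as sketched below. The dictionary here is a small hypothetical stand-in with the same keys; the same reshaping should apply to `train_logged_dataset`." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "\n", | |
| "# hypothetical stand-in with the same flat layout as train_logged_dataset\n", | |
| "n_traj, step_per_traj, state_dim = 4, 10, 5\n", | |
| "rng = np.random.default_rng(0)\n", | |
| "logged = {\n", | |
| "    'n_trajectories': n_traj,\n", | |
| "    'step_per_trajectory': step_per_traj,\n", | |
| "    'state': rng.uniform(-1.0, 1.0, size=(n_traj * step_per_traj, state_dim)),\n", | |
| "    'reward': rng.uniform(0.0, 1.0, size=n_traj * step_per_traj),\n", | |
| "    'done': np.tile(np.eye(step_per_traj)[-1], n_traj),  # 1 only at each trajectory's last step\n", | |
| "}\n", | |
| "\n", | |
| "# assumption: transitions are stored trajectory by trajectory\n", | |
| "shape = (logged['n_trajectories'], logged['step_per_trajectory'])\n", | |
| "reward_2d = logged['reward'].reshape(shape)        # (n_trajectories, step_per_trajectory)\n", | |
| "state_3d = logged['state'].reshape(shape + (-1,))  # (n_trajectories, step_per_trajectory, state_dim)\n", | |
| "done_2d = logged['done'].reshape(shape)\n", | |
| "\n", | |
| "# sanity check: every trajectory ends exactly at its last step\n", | |
| "assert (done_2d[:, -1] == 1).all() and (done_2d[:, :-1] == 0).all()\n", | |
| "print(reward_2d.shape, state_3d.shape)  # -> (4, 10) (4, 10, 5)" | |
| ] | |
| }, | |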
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "xvnfg1pNPLvT" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "with open(\"logs/train_dataset_continuous_sac.pkl\", \"wb\") as f:\n", | |
| " pickle.dump(train_logged_dataset, f)\n", | |
| "with open(\"logs/test_dataset_continuous_sac.pkl\", \"wb\") as f:\n", | |
| " pickle.dump(test_logged_dataset, f)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "UZNITAetPLvT" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "with open(\"logs/train_dataset_continuous_sac.pkl\", \"rb\") as f:\n", | |
| " train_logged_dataset = pickle.load(f)\n", | |
| "with open(\"logs/test_dataset_continuous_sac.pkl\", \"rb\") as f:\n", | |
| " test_logged_dataset = pickle.load(f)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": { | |
| "referenced_widgets": [ | |
| "42f9c085cbec44518c3c3d5808db24e2", | |
| "c4c19c7a6f444638b2d5a9fc90ab7da3", | |
| "4011ac24ea444bbdb2a46e7a288fa9af" | |
| ] | |
| }, | |
| "id": "HvggqcuAPLvT", | |
| "outputId": "8f3b0775-be29-4478-cd32-f658645eaea1" | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "42f9c085cbec44518c3c3d5808db24e2", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| "[calculate on-policy policy value]: 0%| | 0/100 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "c4c19c7a6f444638b2d5a9fc90ab7da3", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| "[calculate on-policy policy value]: 0%| | 0/100 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "4011ac24ea444bbdb2a46e7a288fa9af", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| "[calculate on-policy policy value]: 0%| | 0/100 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "image/png": "", | |
| "text/plain": [ | |
| "<Figure size 600x400 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "# compare policy performance online\n", | |
| "visualize_on_policy_policy_value(\n", | |
| " env=env,\n", | |
| " policies=[sac, behavior_policy, random],\n", | |
| " policy_names=[\"sac\", \"sac (sigma=0.5)\", \"random\"],\n", | |
| " n_trajectories=100,\n", | |
| " random_state=random_state,\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "UFA2O2suPLvT" | |
| }, | |
| "source": [ | |
| "For more details on data collection and visualization, see [examples/quickstart_ja/basic/basic_synthetic_data_collection_ja.ipynb](https://github.com/hakuhodo-technologies/scope-rl/blob/main/examples/quickstart_ja/basic/basic_synthetic_data_collection_ja.ipynb)." | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "mchWPyzEPLvT" | |
| }, | |
| "source": [ | |
| "## 2. Offline RL\n", | |
| "This section shows how to learn a new policy from logged data alone, without interacting with the real environment.\n", | |
| "We use the offline RL algorithms implemented in [d3rlpy](https://github.com/takuseno/d3rlpy)." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "8FJNgAzgPLvT" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# import d3rlpy modules\n", | |
| "from d3rlpy.dataset import MDPDataset\n", | |
| "from d3rlpy.algos import CQLConfig" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "5lx-4g4kPLvT" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "offlinerl_dataset = MDPDataset(\n", | |
| " observations=train_logged_dataset[\"state\"],\n", | |
| " actions=train_logged_dataset[\"action\"],\n", | |
| " rewards=train_logged_dataset[\"reward\"],\n", | |
| " terminals=train_logged_dataset[\"done\"],\n", | |
| ")" | |
| ] | |
| }, | |
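| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Conceptually, an `MDPDataset` groups the flat transition arrays into episodes using the terminal flags. The NumPy sketch below illustrates that grouping on toy arrays; it is not d3rlpy's implementation, and it assumes the final transition is terminal." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "\n", | |
| "# toy flat arrays: two episodes of lengths 3 and 4\n", | |
| "terminals = np.array([0, 0, 1, 0, 0, 0, 1], dtype=float)\n", | |
| "flat_rewards = np.arange(7, dtype=float)\n", | |
| "\n", | |
| "# assumption: the last transition is terminal, so every transition belongs to an episode\n", | |
| "episode_ends = np.flatnonzero(terminals) + 1          # exclusive end index of each episode\n", | |
| "episodes = np.split(flat_rewards, episode_ends[:-1])  # split points between episodes\n", | |
| "print([ep.tolist() for ep in episodes])  # -> [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0, 6.0]]" | |
| ] | |
| }, | |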
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "3Bpn8hMUPLvT" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# use a Conservative Q-Learning (CQL) policy\n", | |
| "cql = CQLConfig(\n", | |
| " actor_encoder_factory=VectorEncoderFactory(hidden_units=[30, 30]),\n", | |
| " critic_encoder_factory=VectorEncoderFactory(hidden_units=[30, 30]),\n", | |
| " q_func_factory=MeanQFunctionFactory(),\n", | |
| " action_scaler=MinMaxActionScaler(\n", | |
| " minimum=env.action_space.low, # minimum action value the policy can take\n", | |
| " maximum=env.action_space.high, # maximum action value the policy can take\n", | |
| " )\n", | |
| ").create(device=device)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "M3LRuz9JPLvT" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "cql.fit(\n", | |
| " offlinerl_dataset,\n", | |
| " n_steps=10000,\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "ZlqTivQNPLvZ" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# save the model\n", | |
| "cql.save_model(\"d3rlpy_logs/cql_continuous.pt\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "GRW5cOVcPLvZ", | |
| "outputId": "f09d7fb5-a8a1-4d41-88fe-43d5cdb0735f" | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "2023-07-22 16:29:44 [warning ] Parameters will be reinitialized.\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# reload the model\n", | |
| "cql.build_with_env(env)\n", | |
| "cql.load_model(\"d3rlpy_logs/cql_continuous.pt\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": { | |
| "referenced_widgets": [ | |
| "a0c3e1e5016d424a9089bce7c22b8d95", | |
| "92ba2e1f069847fc88ebb06371864b51", | |
| "cb9095474942427caee28dd9ecec72ac", | |
| "433d86a8d318460e95b1a06ba56487c0" | |
| ] | |
| }, | |
| "id": "eJXZaU1nPLvZ", | |
| "outputId": "0302ed03-80e9-4b65-a2b9-a36e55bd94ab" | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "a0c3e1e5016d424a9089bce7c22b8d95", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| "[calculate on-policy policy value]: 0%| | 0/100 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "92ba2e1f069847fc88ebb06371864b51", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| "[calculate on-policy policy value]: 0%| | 0/100 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "cb9095474942427caee28dd9ecec72ac", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| "[calculate on-policy policy value]: 0%| | 0/100 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "433d86a8d318460e95b1a06ba56487c0", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| "[calculate on-policy policy value]: 0%| | 0/100 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "image/png": "", | |
| "text/plain": [ | |
| "<Figure size 800x400 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "# compare policy performance online\n", | |
| "# the policy learned offline performs comparably to the policy learned online\n", | |
| "visualize_on_policy_policy_value(\n", | |
| " env=env,\n", | |
| " policies=[sac, behavior_policy, cql, random],\n", | |
| " policy_names=[\"sac\", \"sac (sigma=0.5)\", \"cql\", \"random\"],\n", | |
| " n_trajectories=100,\n", | |
| " random_state=random_state,\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "nF3_BzLnPLvZ" | |
| }, | |
| "source": [ | |
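| "## 3. Off-policy evaluation and evaluation of offline RL methods\n", | |
| "In the previous section, we checked the performance of the learned policies through online evaluation (A/B tests). In practice, however, A/B tests are also known to carry substantial risk; for example, deploying a poorly performing policy can hurt customer satisfaction. We therefore evaluate policy performance offline, to estimate how a learned policy would behave in the real environment before deploying it.\n", | |
| "\n", | |
| "#### Estimation target\n", | |
| "Here, we estimate the expected performance of a policy, i.e., its *policy value*, where $\\gamma \\in (0, 1]$ is the discount factor:\n", | |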
| "\n", | |
| "$$ V(\\pi) := \\mathbb{E}\\left[\\sum_{t=1}^T \\gamma^{t-1} r_t \\mid \\pi \\right] $$" | |
| ] | |
| }, | |
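| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "When a policy can be deployed, $V(\\pi)$ can be estimated by Monte Carlo: average the discounted return over trajectories collected by $\\pi$ itself, which is roughly what the online evaluation earlier in this notebook does. A minimal sketch with hypothetical rewards and an assumed discount factor:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "\n", | |
| "# hypothetical rewards from 100 on-policy rollouts of pi, each with T = 10 steps\n", | |
| "rng = np.random.default_rng(12345)\n", | |
| "rollout_rewards = rng.uniform(0.0, 1.0, size=(100, 10))\n", | |
| "gamma = 0.95  # assumed discount factor\n", | |
| "\n", | |
| "discounts = gamma ** np.arange(rollout_rewards.shape[1])  # gamma^(t-1) for t = 1, ..., T\n", | |
| "trajectory_returns = rollout_rewards @ discounts          # discounted return of each trajectory\n", | |
| "v_hat = trajectory_returns.mean()                         # Monte Carlo estimate of V(pi)\n", | |
| "print(f'estimated V(pi) = {v_hat:.3f}')" | |
| ] | |
| }, | |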
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "r8IGfvzBPLvZ" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# import OPE modules from SCOPE-RL\n", | |
| "from scope_rl.ope import CreateOPEInput\n", | |
| "from scope_rl.ope import OffPolicyEvaluation as OPE\n", | |
| "from scope_rl.ope.continuous import DirectMethod as DM\n", | |
| "from scope_rl.ope.continuous import TrajectoryWiseImportanceSampling as TIS\n", | |
| "from scope_rl.ope.continuous import PerDecisionImportanceSampling as PDIS\n", | |
| "from scope_rl.ope.continuous import DoublyRobust as DR\n", | |
| "from scope_rl.ope.continuous import SelfNormalizedTIS as SNTIS\n", | |
| "from scope_rl.ope.continuous import SelfNormalizedPDIS as SNPDIS\n", | |
| "from scope_rl.ope.continuous import SelfNormalizedDR as SNDR\n", | |
| "from scope_rl.policy import ContinuousEvalHead" | |
| ] | |
| }, | |
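| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "For intuition about the importance-sampling estimators imported above, the sketch below implements the standard TIS, PDIS, and their self-normalized variants in NumPy, given per-step importance ratios $w_{i,t} = \\pi(a_t \\mid s_t) / \\pi_0(a_t \\mid s_t)$. The ratios and rewards are synthetic, and for continuous actions with deterministic evaluation policies the ratios themselves need a kernel-based relaxation, so treat this as an illustration of the weighting formulas only, not SCOPE-RL's implementation." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "\n", | |
| "# hypothetical per-step importance ratios and rewards, both of shape (n, T)\n", | |
| "rng = np.random.default_rng(12345)\n", | |
| "n, T, gamma = 1000, 10, 1.0\n", | |
| "is_rewards = rng.uniform(0.0, 1.0, size=(n, T))\n", | |
| "ratios = rng.lognormal(mean=0.0, sigma=0.3, size=(n, T))\n", | |
| "discounts = gamma ** np.arange(T)\n", | |
| "\n", | |
| "# TIS: weight every reward by the product of ratios over the whole trajectory\n", | |
| "traj_weight = ratios.prod(axis=1, keepdims=True)  # (n, 1)\n", | |
| "tis = (traj_weight * is_rewards * discounts).sum(axis=1).mean()\n", | |
| "\n", | |
| "# PDIS: weight r_t only by the product of ratios up to step t\n", | |
| "step_weight = np.cumprod(ratios, axis=1)  # (n, T)\n", | |
| "pdis = (step_weight * is_rewards * discounts).sum(axis=1).mean()\n", | |
| "\n", | |
| "# self-normalized variants: divide by the empirical mean weight instead of assuming it is 1\n", | |
| "sntis = (traj_weight / traj_weight.mean() * is_rewards * discounts).sum(axis=1).mean()\n", | |
| "snpdis = (step_weight / step_weight.mean(axis=0) * is_rewards * discounts).sum(axis=1).mean()\n", | |
| "print(tis, pdis, sntis, snpdis)" | |
| ] | |
| }, | |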
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "5ZN5g3N-PLvZ" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# define the evaluation policies\n", | |
| "cql_ = ContinuousEvalHead(\n", | |
| " base_policy=cql,\n", | |
| " name=\"cql\",\n", | |
| ")\n", | |
| "sac_ = ContinuousEvalHead(\n", | |
| " base_policy=sac,\n", | |
| " name=\"sac\",\n", | |
| ")\n", | |
| "evaluation_policies = [cql_, sac_]" | |
| ] | |
| }, | |
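| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "`ContinuousEvalHead` yields a deterministic evaluation policy, so $\\pi(a \\mid s)$ is a point mass and the usual importance ratio is undefined for continuous actions. A common workaround is to replace the point mass with a kernel centered at the evaluation action. The sketch below uses a Gaussian kernel with bandwidth $h$; the helper names, inputs, and bandwidth are hypothetical, and `pscore` is assumed to be the joint behavior density of each logged action." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "\n", | |
| "def gaussian_kernel(u):\n", | |
| "    # hypothetical helper: product of standard Gaussian densities over action dimensions\n", | |
| "    return np.exp(-0.5 * (u ** 2).sum(axis=-1)) / (2 * np.pi) ** (u.shape[-1] / 2)\n", | |
| "\n", | |
| "def kernel_importance_ratio(eval_action, logged_action, pscore, bandwidth):\n", | |
| "    # hypothetical helper: smooth the deterministic evaluation policy with a kernel\n", | |
| "    # centered at its action, then divide by the behavior density of the logged action\n", | |
| "    action_dim = logged_action.shape[-1]\n", | |
| "    u = (eval_action - logged_action) / bandwidth\n", | |
| "    return gaussian_kernel(u) / bandwidth ** action_dim / pscore\n", | |
| "\n", | |
| "# hypothetical inputs: 3 logged transitions with 3-dimensional actions\n", | |
| "rng = np.random.default_rng(12345)\n", | |
| "logged_action = rng.uniform(-1.0, 1.0, size=(3, 3))\n", | |
| "eval_action = rng.uniform(-1.0, 1.0, size=(3, 3))\n", | |
| "pscore = np.array([0.9, 1.1, 0.7])  # behavior-policy densities of the logged actions\n", | |
| "print(kernel_importance_ratio(eval_action, logged_action, pscore, bandwidth=0.5))" | |
| ] | |
| }, | |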
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "jldHZYaKPLvZ" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# first, prepare the inputs for OPE\n", | |
| "prep = CreateOPEInput(\n", | |
| " env=env,\n", | |
| " model_args={\n", | |
| " \"fqe\": {\n", | |
| " \"encoder_factory\": VectorEncoderFactory(hidden_units=[30, 30]),\n", | |
| " \"q_func_factory\": MeanQFunctionFactory(),\n", | |
| " \"learning_rate\": 1e-4,\n", | |
| " },\n", | |
| " },\n", | |
| " state_scaler=MinMaxObservationScaler(\n", | |
| " minimum=test_logged_dataset[\"state\"].min(axis=0),\n", | |
| " maximum=test_logged_dataset[\"state\"].max(axis=0),\n", | |
| " ),\n", | |
| " action_scaler=MinMaxActionScaler(\n", | |
| " minimum=env.action_space.low,\n", | |
| " maximum=env.action_space.high,\n", | |
| " ),\n", | |
| ")" | |
| ] | |
| }, | |
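`CreateOPEInput` above normalizes states with `MinMaxObservationScaler` (bounds taken from the logged states) and actions with `MinMaxActionScaler` (bounds taken from `env.action_space`). A minimal sketch of such a min-max transform, assuming each dimension is mapped linearly from `[minimum, maximum]` onto `[-1, 1]`; that target range is our understanding of d3rlpy's min-max scalers and should be checked against the d3rlpy documentation. `minmax_scale` and `minmax_unscale` are hypothetical names.

```python
def minmax_scale(x, minimum, maximum):
    """Map each coordinate of x from [minimum, maximum] linearly onto [-1, 1]."""
    return [2.0 * (xi - lo) / (hi - lo) - 1.0 for xi, lo, hi in zip(x, minimum, maximum)]


def minmax_unscale(y, minimum, maximum):
    """Inverse transform: map each coordinate of y from [-1, 1] back to [minimum, maximum]."""
    return [lo + (yi + 1.0) * (hi - lo) / 2.0 for yi, lo, hi in zip(y, minimum, maximum)]


low, high = [0.0, -2.0], [10.0, 2.0]
print(minmax_scale([5.0, 2.0], low, high))    # [0.0, 1.0]
print(minmax_unscale([0.0, 1.0], low, high))  # [5.0, 2.0]
```

Taking the action bounds from the environment rather than from the data keeps the scaled logged actions and the evaluation policies' actions on one fixed range.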
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "KYhPmsH9PLvZ" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "input_dict = prep.obtain_whole_inputs(\n", | |
| " logged_dataset=test_logged_dataset,\n", | |
| " evaluation_policies=evaluation_policies,\n", | |
| " require_value_prediction=True,\n", | |
| " n_trajectories_on_policy_evaluation=100,\n", | |
| " random_state=random_state,\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "C1Tp2kuhPLvZ" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "with open(\"logs/ope_input_dict_continuous.pkl\", \"wb\") as f:\n", | |
| " pickle.dump(input_dict, f)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "6_y_CQsuPLvZ" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "with open(\"logs/ope_input_dict_continuous.pkl\", \"rb\") as f:\n", | |
| " input_dict = pickle.load(f)" | |
| ] | |
| }, | |
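The two cells above cache the output of the expensive `obtain_whole_inputs` call in a pickle file and read it back. A minimal sketch that folds both steps into one hypothetical helper, `load_or_compute` (not part of SCOPE-RL), which also creates the `logs/` directory if it is missing:

```python
import os
import pickle


def load_or_compute(path, compute_fn):
    """Load a pickled result from `path` if it exists; otherwise compute, save, and return it."""
    if os.path.exists(path):
        with open(path, "rb") as f:
            return pickle.load(f)
    result = compute_fn()
    # create the parent directory (e.g. "logs/") first, since open(..., "wb") does not
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(result, f)
    return result
```

For example, `input_dict = load_or_compute("logs/ope_input_dict_continuous.pkl", lambda: prep.obtain_whole_inputs(...))`, with the arguments from the cell above in place of `...`. The cache is not invalidated when those arguments change, so delete the file after changing them.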
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "fHxQwHi-PLva", | |
| "outputId": "31f81c30-7f85-44c3-ae26-b91b3ca9b039" | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "{'cql': {'evaluation_policy_action_dist': None,\n", | |
| " 'evaluation_policy_action': array([[ 0.7895372 , -0.79363215, -0.52138674],\n", | |
| " [ 0.8106339 , -0.33552885, -0.45466143],\n", | |
| " [ 0.78984964, -0.4607501 , 0.70200634],\n", | |
| " ...,\n", | |
| " [ 0.78719187, 0.18121517, -0.56885785],\n", | |
| " [ 0.78742146, 0.06306946, -0.48978412],\n", | |
| " [ 0.76593447, -0.43075836, 0.6505569 ]], dtype=float32),\n", | |
| " 'state_action_value_prediction': array([[4.30566072, 4.58573961],\n", | |
| " [4.97807884, 4.73458385],\n", | |
| " [4.75874472, 4.98025274],\n", | |
| " ...,\n", | |
| " [4.28577471, 4.6361475 ],\n", | |
| " [4.45088673, 4.58396673],\n", | |
| " [4.59209585, 4.74256086]]),\n", | |
| " 'initial_state_value_prediction': array([4.58573961, 4.65792894, 4.77657366, ..., 4.72337389, 4.84374475,\n", | |
| " 4.34953928]),\n", | |
| " 'state_action_marginal_importance_weight': None,\n", | |
| " 'state_marginal_importance_weight': None,\n", | |
| " 'on_policy_policy_value': array([5.74958172, 5.31987359, 5.48685406, 5.67917462, 5.116804 ,\n", | |
| " 5.67244663, 5.33136124, 5.08846073, 5.33374774, 5.52337794,\n", | |
| " 5.3665136 , 5.51723173, 5.21675983, 5.56649301, 5.66331608,\n", | |
| " 5.5020216 , 5.42224366, 5.36061777, 5.27333682, 5.34842719,\n", | |
| " 5.21663029, 5.50297997, 5.21280586, 5.29410713, 5.27231692,\n", | |
| " 5.32935076, 5.4187381 , 5.63229718, 5.39490127, 5.55270347,\n", | |
| " 5.35659077, 5.53594237, 5.09255133, 5.5723891 , 5.51750553,\n", | |
| " 5.5278357 , 5.4170849 , 5.40114623, 5.36637647, 5.53502923,\n", | |
| " 5.42176323, 5.4504927 , 5.48108056, 5.30647794, 5.11547507,\n", | |
| " 5.42704371, 5.48291705, 5.32882222, 5.55444547, 5.54054854,\n", | |
| " 5.59120766, 5.57179807, 5.37788577, 5.45202666, 5.652961 ,\n", | |
| " 5.16112705, 5.99347832, 5.21848124, 5.62500358, 5.7027091 ,\n", | |
| " 5.49142645, 5.41976854, 6.08446954, 5.48365786, 5.60331058,\n", | |
| " 5.75577116, 5.49147728, 5.38499021, 5.58039468, 5.44336472,\n", | |
| " 5.23734497, 5.45965116, 5.56531265, 4.95377563, 5.55036139,\n", | |
| " 5.50900376, 5.14605189, 5.20541641, 5.68976302, 5.66397365,\n", | |
| " 5.39175972, 5.3144299 , 5.49217746, 5.73677416, 5.53788389,\n", | |
| " 5.61985157, 5.49950872, 5.98476523, 5.97510753, 5.47228223,\n", | |
| " 5.21004315, 5.82224613, 5.43995411, 5.32581354, 5.20931119,\n", | |
| " 5.35452023, 5.76586113, 5.74633475, 5.20548451, 5.64406232]),\n", | |
| " 'gamma': 1.0,\n", | |
| " 'behavior_policy': 'sac_sigma_0.5',\n", | |
| " 'evaluation_policy': 'cql',\n", | |
| " 'dataset_id': 0},\n", | |
| " 'sac': {'evaluation_policy_action_dist': None,\n", | |
| " 'evaluation_policy_action': array([[ 0.9611497 , -0.97684574, -0.22405702],\n", | |
| " [ 0.96218014, -0.4036672 , -0.26055682],\n", | |
| " [ 0.9764993 , -0.67134833, 0.7849159 ],\n", | |
| " ...,\n", | |
| " [ 0.9841423 , 0.20341933, -0.631484 ],\n", | |
| " [ 0.9831486 , 0.02586651, -0.55481225],\n", | |
| " [ 0.9478135 , -0.28132623, 0.5369407 ]], dtype=float32),\n", | |
| " 'state_action_value_prediction': array([[5.31686211, 6.07998896],\n", | |
| " [5.91870594, 5.75539255],\n", | |
| " [5.59816265, 5.87946272],\n", | |
| " ...,\n", | |
| " [5.02844667, 5.45477772],\n", | |
| " [5.09789944, 5.41336918],\n", | |
| " [5.25617313, 5.42402124]]),\n", | |
| " 'initial_state_value_prediction': array([6.07998896, 5.42326832, 5.73191738, ..., 5.74739695, 5.92206669,\n", | |
| " 5.41095066]),\n", | |
| " 'state_action_marginal_importance_weight': None,\n", | |
| " 'state_marginal_importance_weight': None,\n", | |
| " 'on_policy_policy_value': array([6.09022037, 6.32040044, 6.63314716, 6.32587884, 5.94799035,\n", | |
| " 6.18751298, 6.70182178, 6.18631965, 5.94538534, 6.52917627,\n", | |
| " 6.51161632, 6.37259652, 5.62279331, 6.20060001, 6.72941676,\n", | |
| " 5.75506929, 6.18472014, 6.19029891, 6.18716371, 5.91147484,\n", | |
| " 6.05846759, 6.45905514, 5.6540668 , 5.83800332, 5.59527464,\n", | |
| " 5.63868089, 6.22684424, 6.46364278, 5.95472293, 6.54427809,\n", | |
| " 5.97328679, 6.46434761, 6.02378658, 6.3989872 , 6.30630791,\n", | |
| " 6.77377617, 6.42833196, 6.1274932 , 6.13344936, 6.28930782,\n", | |
| " 6.04073885, 6.25466379, 6.18562348, 6.64294738, 5.48437976,\n", | |
| " 6.20051981, 6.30738976, 6.03213259, 6.26023494, 6.4932229 ,\n", | |
| " 6.46970749, 6.41663935, 6.03184991, 5.73305688, 6.10182476,\n", | |
| " 5.86358585, 6.61050078, 6.04027827, 6.65348242, 6.37870681,\n", | |
| " 6.6875429 , 5.97847487, 6.63592013, 6.32216829, 6.56415138,\n", | |
| " 6.40538109, 6.46079024, 5.98343649, 5.73867663, 6.28519443,\n", | |
| " 6.10751007, 6.35279383, 6.50316469, 6.00166357, 6.78168563,\n", | |
| " 6.52767799, 5.8270543 , 5.64433012, 7.06787604, 6.32129163,\n", | |
| " 5.99241529, 6.47414793, 6.64510297, 6.45757072, 5.79394901,\n", | |
| " 6.69239703, 6.13964648, 6.502342 , 6.59521697, 6.25035521,\n", | |
| " 5.62877118, 6.63336862, 6.47385173, 5.94986879, 5.69937017,\n", | |
| " 5.94733121, 6.54736512, 6.19049303, 6.08320053, 6.65535394]),\n", | |
| " 'gamma': 1.0,\n", | |
| " 'behavior_policy': 'sac_sigma_0.5',\n", | |
| " 'evaluation_policy': 'sac',\n", | |
| " 'dataset_id': 0}}" | |
| ] | |
| }, | |
| "execution_count": 43, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "input_dict" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "0_5C0irJPLva" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "ope = OPE(\n", | |
| " logged_dataset=test_logged_dataset,\n", | |
| " ope_estimators=[DM(), TIS(), PDIS(), DR(), SNTIS(), SNPDIS(), SNDR()],\n", | |
| " action_scaler=MinMaxActionScaler(\n", | |
| " minimum=env.action_space.low,\n", | |
| " maximum=env.action_space.high,\n", | |
| " ),\n", | |
| " bandwidth=5.0,\n", | |
| ")" | |
| ] | |
| }, | |
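The evaluation policies wrapped by `ContinuousEvalHead` appear to be deterministic (`'evaluation_policy_action_dist'` is `None` in `input_dict`), so the exact ratio $\pi(a \mid s) / \pi_0(a \mid s)$ is not defined for continuous actions. Kernel-based importance sampling (Kallus and Zhou) instead smooths the evaluation policy's action with a kernel of width $h$, and the `bandwidth=5.0` argument presumably plays the role of $h$. The sketch below assumes a 1-D action, a Gaussian kernel, and a known behavior density; we have not checked which kernel SCOPE-RL uses or how it scales `bandwidth` relative to the scaled actions, and all names are hypothetical.

```python
import math


def gaussian_pdf(x, mean, std):
    """Density of N(mean, std^2) at x."""
    return math.exp(-0.5 * ((x - mean) / std) ** 2) / (std * math.sqrt(2.0 * math.pi))


def kernel_weight(eval_action, logged_action, behavior_density, bandwidth):
    """Kernel-smoothed importance weight for a deterministic evaluation policy.

    The Dirac delta of the deterministic policy is replaced by K(u) / h with
    u = (eval_action - logged_action) / h and K the standard Gaussian density,
    then divided by the behavior policy's density of the logged action.
    """
    u = (eval_action - logged_action) / bandwidth
    kernel = gaussian_pdf(u, 0.0, 1.0) / bandwidth
    return kernel / behavior_density


# toy setting: behavior policy N(0, 0.5^2), logged action equal to the evaluation action
behavior_density = gaussian_pdf(0.0, 0.0, 0.5)
for h in (0.1, 1.0, 5.0):
    # the weight shrinks as the bandwidth grows: about 5.0, 0.5, 0.1
    print(h, kernel_weight(0.0, 0.0, behavior_density, h))
```

The bandwidth trades bias for variance: a small `h` keeps the estimate close to the deterministic policy but lets weights blow up, while a large `h` gives smoother, smaller weights at the cost of bias.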
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "UKKdncJXPLva" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# estimate the policy values with OPE\n", | |
| "# this computation takes a while (about 15 minutes)\n", | |
| "policy_value_df_dict, policy_value_interval_df_dict = ope.summarize_off_policy_estimates(input_dict, random_state=random_state)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "0dnwxrEsPLva", | |
| "outputId": "f4b60c8e-af09-43c8-8421-773a9188a1df" | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "{'cql': policy_value relative_policy_value\n", | |
| " on_policy 5.466111e+00 1.000000e+00\n", | |
| " dm 4.717382e+00 8.630235e-01\n", | |
| " tis 1.348137e+20 2.466355e+19\n", | |
| " pdis -2.820616e+20 -5.160188e+19\n", | |
| " dr -1.780842e+20 -3.257968e+19\n", | |
| " sntis 1.322637e+00 2.419705e-01\n", | |
| " snpdis -3.629344e-01 -6.639719e-02\n", | |
| " sndr 3.815478e+00 6.980243e-01,\n", | |
| " 'sac': policy_value relative_policy_value\n", | |
| " on_policy 6.236361e+00 1.000000e+00\n", | |
| " dm 5.716103e+00 9.165767e-01\n", | |
| " tis 1.235306e+20 1.980812e+19\n", | |
| " pdis -2.605517e+20 -4.177945e+19\n", | |
| " dr -3.265419e+20 -5.236096e+19\n", | |
| " sntis 1.322643e+00 2.120857e-01\n", | |
| " snpdis -3.612269e-01 -5.792271e-02\n", | |
| " sndr 4.322576e+00 6.931247e-01}" | |
| ] | |
| }, | |
| "execution_count": 46, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# estimated policy values\n", | |
| "policy_value_df_dict" | |
| ] | |
| }, | |
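The `relative_policy_value` column appears to be each estimate divided by the on-policy value; for example, 4.717382 / 5.466111 ≈ 0.863 for DM on `cql`. A minimal sketch of that normalization, using a hypothetical helper and point estimates copied from the table above:

```python
def relative_policy_values(estimates, baseline_key="on_policy"):
    """Divide each estimate by the on-policy value; 1.0 means a perfect match."""
    baseline = estimates[baseline_key]
    return {name: value / baseline for name, value in estimates.items()}


# point estimates for the `cql` policy, copied from the table above
cql = {"on_policy": 5.466111, "dm": 4.717382, "sntis": 1.322637}
for name, rel in relative_policy_values(cql).items():
    print(f"{name}: {rel:.4f}")  # on_policy: 1.0000, dm: 0.8630, sntis: 0.2420
```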
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "EEFDeqMsPLva", | |
| "outputId": "a7cf95b1-1c7e-43ac-8546-56c798459f88" | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "{'cql': mean 95.0% CI (lower) 95.0% CI (upper)\n", | |
| " on_policy 5.466634e+00 5.425790e+00 5.505745e+00\n", | |
| " dm 4.717618e+00 4.713465e+00 4.722355e+00\n", | |
| " tis 1.455944e+20 1.767848e+14 4.042672e+20\n", | |
| " pdis -3.046265e+20 -8.461850e+20 1.695091e+14\n", | |
| " dr -1.923214e+20 -5.338968e+20 -3.702428e+14\n", | |
| " sntis 1.428405e+00 1.734410e-06 3.966206e+00\n", | |
| " snpdis -2.624129e-01 -1.593248e+00 7.838554e-01\n", | |
| " sndr 3.190053e+00 -6.431012e+00 1.000763e+01,\n", | |
| " 'sac': mean 95.0% CI (lower) 95.0% CI (upper)\n", | |
| " on_policy 6.238561e+00 6.165749e+00 6.299698e+00\n", | |
| " dm 5.716434e+00 5.710103e+00 5.723485e+00\n", | |
| " tis 1.334089e+20 1.605203e+14 3.704296e+20\n", | |
| " pdis -2.813959e+20 -7.816554e+20 1.602881e+14\n", | |
| " dr -3.526548e+20 -9.792319e+20 -4.095352e+14\n", | |
| " sntis 1.428410e+00 1.718692e-06 3.966192e+00\n", | |
| " snpdis -2.605759e-01 -1.591601e+00 7.864232e-01\n", | |
| " sndr 3.744827e+00 -7.261544e+00 1.096623e+01}" | |
| ] | |
| }, | |
| "execution_count": 47, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# confidence intervals of the estimated policy values\n", | |
| "policy_value_interval_df_dict" | |
| ] | |
| }, | |
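The `mean` and `95.0% CI` columns summarize a distribution of estimates rather than a single number. As far as we know, SCOPE-RL builds these intervals with a bootstrap by default; the following is a generic percentile-bootstrap sketch for the mean of a sample, with hypothetical names and defaults, not SCOPE-RL's implementation.

```python
import random


def bootstrap_ci(samples, alpha=0.05, n_bootstrap=1000, seed=12345):
    """Percentile bootstrap for the mean: returns (mean of resampled means, lower, upper)."""
    rng = random.Random(seed)
    n = len(samples)
    # resample with replacement and record the mean of each resample
    means = sorted(sum(rng.choices(samples, k=n)) / n for _ in range(n_bootstrap))
    lower_idx = int(round(alpha / 2 * n_bootstrap))
    upper_idx = min(int(round((1 - alpha / 2) * n_bootstrap)) - 1, n_bootstrap - 1)
    return sum(means) / n_bootstrap, means[lower_idx], means[upper_idx]


# toy per-trajectory values
print(bootstrap_ci([1.0, 2.0, 3.0, 4.0, 5.0]))
```

A wide interval signals a high-variance estimator: in the table above, SNDR's interval for `cql` spans roughly -6.4 to 10.0 even though its mean is about 3.2.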
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "_yU4WsuBPLva" | |
| }, | |
| "source": [ | |
| "## 4. Evaluating OPE Estimators\n", | |
| "We evaluate how accurately each OPE estimator estimates the policy value." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "tferlZyxPLva", | |
| "outputId": "ff3b1391-d0a0-4e8e-f88b-2e8a536438ff" | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "image/png": "", | |
| "text/plain": [ | |
| "<Figure size 800x800 with 2 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
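| "# compare and visualize the OPE results\n", | |
| "# many of the estimators underestimate the policy value,\n", | |
| "# because the logged data does not sufficiently cover the actions the evaluation policies would take\n", | |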
| "\n", | |
| "ope.visualize_off_policy_estimates(\n", | |
| " input_dict,\n", | |
| " compared_estimators=[\"dm\", \"sntis\", \"snpdis\", \"sndr\"],\n", | |
| " random_state=random_state,\n", | |
| " sharey=True,\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "leeDs-2XPLva", | |
| "outputId": "b23ce996-f978-4254-bd61-d192a0fcdc52" | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>cql</th>\n", | |
| " <th>sac</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>dm</th>\n", | |
| " <td>5.605949e-01</td>\n", | |
| " <td>2.706679e-01</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>tis</th>\n", | |
| " <td>1.817473e+40</td>\n", | |
| " <td>1.525981e+40</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>pdis</th>\n", | |
| " <td>7.955874e+40</td>\n", | |
| " <td>6.788720e+40</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>dr</th>\n", | |
| " <td>3.171398e+40</td>\n", | |
| " <td>1.066296e+41</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>sntis</th>\n", | |
| " <td>1.716838e+01</td>\n", | |
| " <td>2.414463e+01</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>snpdis</th>\n", | |
| " <td>3.397777e+01</td>\n", | |
| " <td>4.352817e+01</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>sndr</th>\n", | |
| " <td>2.724589e+00</td>\n", | |
| " <td>3.662573e+00</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " cql sac\n", | |
| "dm 5.605949e-01 2.706679e-01\n", | |
| "tis 1.817473e+40 1.525981e+40\n", | |
| "pdis 7.955874e+40 6.788720e+40\n", | |
| "dr 3.171398e+40 1.066296e+41\n", | |
| "sntis 1.716838e+01 2.414463e+01\n", | |
| "snpdis 3.397777e+01 4.352817e+01\n", | |
| "sndr 2.724589e+00 3.662573e+00" | |
| ] | |
| }, | |
| "execution_count": 49, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# use the squared error to measure the accuracy of the OPE estimates\n", | |
| "eval_metric_ope_df = ope.evaluate_performance_of_ope_estimators(\n", | |
| " input_dict,\n", | |
| " metric=\"se\",\n", | |
| " return_by_dataframe=True,\n", | |
| ")\n", | |
| "eval_metric_ope_df" | |
| ] | |
| }, | |
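With `metric="se"`, each entry is the squared error $(\hat{V}(\pi) - V(\pi))^2$ between an estimate and the on-policy value, so lower is better. For example, DM on `cql` gives $(4.717382 - 5.466111)^2 \approx 0.5606$, which matches the table. A minimal sketch with a hypothetical helper, using point estimates copied from the earlier `policy_value_df_dict` output:

```python
def squared_errors(estimates, true_value):
    """Squared error (V_hat - V)^2 of each estimator against the on-policy value V."""
    return {name: (v_hat - true_value) ** 2 for name, v_hat in estimates.items()}


# `cql` point estimates and on-policy value from the earlier output
errors = squared_errors({"dm": 4.717382, "sndr": 3.815478}, true_value=5.466111)
print(errors)
print(min(errors, key=errors.get))  # dm
```

Squared error penalizes both bias and variance; the TIS, PDIS, and DR rows are around 1e40 because their estimates are off by about 1e20.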
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "mQhh3C4APLva" | |
| }, | |
| "source": [ | |
| "For more advanced OPE topics (such as estimating the cumulative distribution function) and off-policy selection (OPS), see [examples/quickstart_ja/basic/basic_synthetic_continuous_advanced_ja.ipynb](https://github.com/hakuhodo-technologies/scope-rl/blob/main/examples/quickstart_ja/basic/basic_synthetic_continuous_advanced_ja.ipynb).\n", | |
| "\n", | |
| "For more advanced OPE estimators (such as state(-action) marginal estimators and double reinforcement learning), see [examples/quickstart_ja/basic/basic_synthetic_continuous_zoo_ja.ipynb](https://github.com/hakuhodo-technologies/scope-rl/blob/main/examples/quickstart_ja/basic/basic_synthetic_continuous_zoo_ja.ipynb).\n", | |
| "\n", | |
| "For an example with a discrete action space, see [examples/quickstart_ja/basic/basic_synthetic_discrete_basic_ja.ipynb](https://github.com/hakuhodo-technologies/scope-rl/blob/main/examples/quickstart_ja/basic/basic_synthetic_discrete_basic_ja.ipynb)." | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "aCcU2xn0PLva" | |
| }, | |
| "source": [ | |
| "## References\n", | |
| "\n", | |
| "- Haanvid Lee, Jongmin Lee, Yunseon Choi, Wonseok Jeon, Byung-Jun Lee, Yung-Kyun Noh, and Kee-Eung Kim. \\\n", | |
| "\"Local Metric Learning for Off-Policy Evaluation in Contextual Bandits with Continuous Actions.\", 2022.\n", | |
| "\n", | |
| "- Yuta Saito, Shunsuke Aihara, Megumi Matsutani, and Yusuke Narita. \\\n", | |
| "\"Open Bandit Dataset and Pipeline: Towards Realistic and Reproducible Off-Policy Evaluation.\", 2021.\n", | |
| "\n", | |
| "- Takuma Seno and Michita Imai. \\\n", | |
| "\"d3rlpy: An Offline Deep Reinforcement Learning Library.\", 2021.\n", | |
| "\n", | |
| "- Sergey Levine, Aviral Kumar, George Tucker, and Justin Fu. \\\n", | |
| "\"Offline Reinforcement Learning: Tutorial, Review, and Perspectives on Open Problems.\", 2020.\n", | |
| "\n", | |
| "- Aviral Kumar, Aurick Zhou, George Tucker, and Sergey Levine. \\\n", | |
| "\"Conservative Q-Learning for Offline Reinforcement Learning.\", 2020.\n", | |
| "\n", | |
| "- Nathan Kallus and Angela Zhou. \\\n", | |
| "\"Policy Evaluation and Optimization with Continuous Treatments.\", 2019.\n", | |
| "\n", | |
| "- Nathan Kallus and Masatoshi Uehara. \\\n", | |
| "\"Intrinsically Efficient, Stable, and Bounded Off-Policy Evaluation for Reinforcement Learning.\", 2019.\n", | |
| "\n", | |
| "- Hoang Le, Cameron Voloshin, and Yisong Yue. \\\n", | |
| "\"Batch Policy Learning under Constraints.\", 2019.\n", | |
| "\n", | |
| "- Tuomas Haarnoja, Aurick Zhou, Pieter Abbeel, and Sergey Levine. \\\n", | |
| "\"Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor.\", 2018.\n", | |
| "\n", | |
| "- Di Wu, Xiujun Chen, Xun Yang, Hao Wang, Qing Tan, Xiaoxun Zhang, Jian Xu, and Kun Gai. \\\n", | |
| "\"Budget Constrained Bidding by Model-free Reinforcement Learning in Display Advertising.\", 2018.\n", | |
| "\n", | |
| "- Jun Zhao, Guang Qiu, Ziyu Guan, Wei Zhao, and Xiaofei He. \\\n", | |
| "\"Deep Reinforcement Learning for Sponsored Search Real-time Bidding.\", 2018.\n", | |
| "\n", | |
| "- Nan Jiang and Lihong Li. \\\n", | |
| "\"Doubly Robust Off-policy Value Evaluation for Reinforcement Learning.\", 2016.\n", | |
| "\n", | |
| "- Philip S. Thomas and Emma Brunskill. \\\n", | |
| "\"Data-Efficient Off-Policy Policy Evaluation for Reinforcement Learning.\", 2016.\n", | |
| "\n", | |
| "- Greg Brockman, Vicki Cheung, Ludwig Pettersson, Jonas Schneider, John Schulman, Jie Tang, and Wojciech Zaremba. \\\n", | |
| "\"OpenAI Gym.\", 2016.\n", | |
| "\n", | |
| "- Adith Swaminathan and Thorsten Joachims. \\\n", | |
| "\"The Self-Normalized Estimator for Counterfactual Learning.\", 2015.\n", | |
| "\n", | |
| "- Miroslav Dudík, Dumitru Erhan, John Langford, and Lihong Li. \\\n", | |
| "\"Doubly Robust Policy Evaluation and Optimization.\", 2014.\n", | |
| "\n", | |
| "- Alex Strehl, John Langford, Sham Kakade, and Lihong Li. \\\n", | |
| "\"Learning from Logged Implicit Exploration Data.\", 2010.\n", | |
| "\n", | |
| "- Alina Beygelzimer and John Langford. \\\n", | |
| "\"The Offset Tree for Learning with Partial Labels.\", 2009.\n", | |
| "\n", | |
| "- Doina Precup, Richard S. Sutton, and Satinder P. Singh. \\\n", | |
| "\"Eligibility Traces for Off-Policy Policy Evaluation.\", 2000." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "UvckNFq5PLva" | |
| }, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3 (ipykernel)", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.9.6" | |
| }, | |
| "vscode": { | |
| "interpreter": { | |
| "hash": "70404ee114725fce8ed9e697d67827f8546c678889944e6d695790702cbfe1f5" | |
| } | |
| }, | |
| "colab": { | |
| "provenance": [], | |
| "include_colab_link": true | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 0 | |
| } |