Created
May 28, 2025 21:00
-
-
Save sgbaird/aa140b187e72491a278d98039253f94e to your computer and use it in GitHub Desktop.
ax-mwe-for-sergio-saving.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/sgbaird/aa140b187e72491a278d98039253f94e/ax-mwe-for-sergio-saving.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "57xih43KFdOY", | |
| "outputId": "d3538b65-581d-4abd-c15f-1dbe0aad8ba4" | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Collecting ax-platform==0.4.3\n", | |
| " Downloading ax_platform-0.4.3-py3-none-any.whl.metadata (11 kB)\n", | |
| "Collecting botorch==0.12.0 (from ax-platform==0.4.3)\n", | |
| " Downloading botorch-0.12.0-py3-none-any.whl.metadata (11 kB)\n", | |
| "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from ax-platform==0.4.3) (3.1.6)\n", | |
| "Requirement already satisfied: pandas in /usr/local/lib/python3.11/dist-packages (from ax-platform==0.4.3) (2.2.2)\n", | |
| "Requirement already satisfied: scipy in /usr/local/lib/python3.11/dist-packages (from ax-platform==0.4.3) (1.15.3)\n", | |
| "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.11/dist-packages (from ax-platform==0.4.3) (1.6.1)\n", | |
| "Requirement already satisfied: ipywidgets in /usr/local/lib/python3.11/dist-packages (from ax-platform==0.4.3) (7.7.1)\n", | |
| "Requirement already satisfied: plotly>=5.12.0 in /usr/local/lib/python3.11/dist-packages (from ax-platform==0.4.3) (5.24.1)\n", | |
| "Collecting pyre-extensions (from ax-platform==0.4.3)\n", | |
| " Downloading pyre_extensions-0.0.32-py3-none-any.whl.metadata (4.0 kB)\n", | |
| "Requirement already satisfied: multipledispatch in /usr/local/lib/python3.11/dist-packages (from botorch==0.12.0->ax-platform==0.4.3) (1.0.0)\n", | |
| "Requirement already satisfied: mpmath<=1.3,>=0.19 in /usr/local/lib/python3.11/dist-packages (from botorch==0.12.0->ax-platform==0.4.3) (1.3.0)\n", | |
| "Requirement already satisfied: torch>=2.0.1 in /usr/local/lib/python3.11/dist-packages (from botorch==0.12.0->ax-platform==0.4.3) (2.6.0+cu124)\n", | |
| "Collecting pyro-ppl>=1.8.4 (from botorch==0.12.0->ax-platform==0.4.3)\n", | |
| " Downloading pyro_ppl-1.9.1-py3-none-any.whl.metadata (7.8 kB)\n", | |
| "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.11/dist-packages (from botorch==0.12.0->ax-platform==0.4.3) (4.13.2)\n", | |
| "Collecting gpytorch==1.13 (from botorch==0.12.0->ax-platform==0.4.3)\n", | |
| " Downloading gpytorch-1.13-py3-none-any.whl.metadata (8.0 kB)\n", | |
| "Collecting linear-operator==0.5.3 (from botorch==0.12.0->ax-platform==0.4.3)\n", | |
| " Downloading linear_operator-0.5.3-py3-none-any.whl.metadata (15 kB)\n", | |
| "Collecting jaxtyping==0.2.19 (from gpytorch==1.13->botorch==0.12.0->ax-platform==0.4.3)\n", | |
| " Downloading jaxtyping-0.2.19-py3-none-any.whl.metadata (5.7 kB)\n", | |
| "Requirement already satisfied: numpy>=1.20.0 in /usr/local/lib/python3.11/dist-packages (from jaxtyping==0.2.19->gpytorch==1.13->botorch==0.12.0->ax-platform==0.4.3) (2.0.2)\n", | |
| "Requirement already satisfied: typeguard>=2.13.3 in /usr/local/lib/python3.11/dist-packages (from jaxtyping==0.2.19->gpytorch==1.13->botorch==0.12.0->ax-platform==0.4.3) (4.4.2)\n", | |
| "Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.11/dist-packages (from plotly>=5.12.0->ax-platform==0.4.3) (9.1.2)\n", | |
| "Requirement already satisfied: packaging in /usr/local/lib/python3.11/dist-packages (from plotly>=5.12.0->ax-platform==0.4.3) (24.2)\n", | |
| "Requirement already satisfied: ipykernel>=4.5.1 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->ax-platform==0.4.3) (6.17.1)\n", | |
| "Requirement already satisfied: ipython-genutils~=0.2.0 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->ax-platform==0.4.3) (0.2.0)\n", | |
| "Requirement already satisfied: traitlets>=4.3.1 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->ax-platform==0.4.3) (5.7.1)\n", | |
| "Requirement already satisfied: widgetsnbextension~=3.6.0 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->ax-platform==0.4.3) (3.6.10)\n", | |
| "Requirement already satisfied: ipython>=4.0.0 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->ax-platform==0.4.3) (7.34.0)\n", | |
| "Requirement already satisfied: jupyterlab-widgets>=1.0.0 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->ax-platform==0.4.3) (3.0.15)\n", | |
| "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->ax-platform==0.4.3) (3.0.2)\n", | |
| "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas->ax-platform==0.4.3) (2.9.0.post0)\n", | |
| "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas->ax-platform==0.4.3) (2025.2)\n", | |
| "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas->ax-platform==0.4.3) (2025.2)\n", | |
| "Collecting typing-inspect (from pyre-extensions->ax-platform==0.4.3)\n", | |
| " Downloading typing_inspect-0.9.0-py3-none-any.whl.metadata (1.5 kB)\n", | |
| "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn->ax-platform==0.4.3) (1.5.0)\n", | |
| "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn->ax-platform==0.4.3) (3.6.0)\n", | |
| "Requirement already satisfied: debugpy>=1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel>=4.5.1->ipywidgets->ax-platform==0.4.3) (1.8.0)\n", | |
| "Requirement already satisfied: jupyter-client>=6.1.12 in /usr/local/lib/python3.11/dist-packages (from ipykernel>=4.5.1->ipywidgets->ax-platform==0.4.3) (6.1.12)\n", | |
| "Requirement already satisfied: matplotlib-inline>=0.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel>=4.5.1->ipywidgets->ax-platform==0.4.3) (0.1.7)\n", | |
| "Requirement already satisfied: nest-asyncio in /usr/local/lib/python3.11/dist-packages (from ipykernel>=4.5.1->ipywidgets->ax-platform==0.4.3) (1.6.0)\n", | |
| "Requirement already satisfied: psutil in /usr/local/lib/python3.11/dist-packages (from ipykernel>=4.5.1->ipywidgets->ax-platform==0.4.3) (5.9.5)\n", | |
| "Requirement already satisfied: pyzmq>=17 in /usr/local/lib/python3.11/dist-packages (from ipykernel>=4.5.1->ipywidgets->ax-platform==0.4.3) (24.0.1)\n", | |
| "Requirement already satisfied: tornado>=6.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel>=4.5.1->ipywidgets->ax-platform==0.4.3) (6.4.2)\n", | |
| "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.11/dist-packages (from ipython>=4.0.0->ipywidgets->ax-platform==0.4.3) (75.2.0)\n", | |
| "Collecting jedi>=0.16 (from ipython>=4.0.0->ipywidgets->ax-platform==0.4.3)\n", | |
| " Downloading jedi-0.19.2-py2.py3-none-any.whl.metadata (22 kB)\n", | |
| "Requirement already satisfied: decorator in /usr/local/lib/python3.11/dist-packages (from ipython>=4.0.0->ipywidgets->ax-platform==0.4.3) (4.4.2)\n", | |
| "Requirement already satisfied: pickleshare in /usr/local/lib/python3.11/dist-packages (from ipython>=4.0.0->ipywidgets->ax-platform==0.4.3) (0.7.5)\n", | |
| "Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from ipython>=4.0.0->ipywidgets->ax-platform==0.4.3) (3.0.51)\n", | |
| "Requirement already satisfied: pygments in /usr/local/lib/python3.11/dist-packages (from ipython>=4.0.0->ipywidgets->ax-platform==0.4.3) (2.19.1)\n", | |
| "Requirement already satisfied: backcall in /usr/local/lib/python3.11/dist-packages (from ipython>=4.0.0->ipywidgets->ax-platform==0.4.3) (0.2.0)\n", | |
| "Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.11/dist-packages (from ipython>=4.0.0->ipywidgets->ax-platform==0.4.3) (4.9.0)\n", | |
| "Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.11/dist-packages (from pyro-ppl>=1.8.4->botorch==0.12.0->ax-platform==0.4.3) (3.4.0)\n", | |
| "Collecting pyro-api>=0.1.1 (from pyro-ppl>=1.8.4->botorch==0.12.0->ax-platform==0.4.3)\n", | |
| " Downloading pyro_api-0.1.2-py3-none-any.whl.metadata (2.5 kB)\n", | |
| "Requirement already satisfied: tqdm>=4.36 in /usr/local/lib/python3.11/dist-packages (from pyro-ppl>=1.8.4->botorch==0.12.0->ax-platform==0.4.3) (4.67.1)\n", | |
| "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas->ax-platform==0.4.3) (1.17.0)\n", | |
| "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.1->botorch==0.12.0->ax-platform==0.4.3) (3.18.0)\n", | |
| "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.1->botorch==0.12.0->ax-platform==0.4.3) (3.4.2)\n", | |
| "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.1->botorch==0.12.0->ax-platform==0.4.3) (2025.3.2)\n", | |
| "Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0.1->botorch==0.12.0->ax-platform==0.4.3)\n", | |
| " Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", | |
| "Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0.1->botorch==0.12.0->ax-platform==0.4.3)\n", | |
| " Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", | |
| "Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.0.1->botorch==0.12.0->ax-platform==0.4.3)\n", | |
| " Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", | |
| "Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.0.1->botorch==0.12.0->ax-platform==0.4.3)\n", | |
| " Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", | |
| "Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=2.0.1->botorch==0.12.0->ax-platform==0.4.3)\n", | |
| " Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", | |
| "Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=2.0.1->botorch==0.12.0->ax-platform==0.4.3)\n", | |
| " Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", | |
| "Collecting nvidia-curand-cu12==10.3.5.147 (from torch>=2.0.1->botorch==0.12.0->ax-platform==0.4.3)\n", | |
| " Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", | |
| "Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch>=2.0.1->botorch==0.12.0->ax-platform==0.4.3)\n", | |
| " Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", | |
| "Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch>=2.0.1->botorch==0.12.0->ax-platform==0.4.3)\n", | |
| " Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", | |
| "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.1->botorch==0.12.0->ax-platform==0.4.3) (0.6.2)\n", | |
| "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.1->botorch==0.12.0->ax-platform==0.4.3) (2.21.5)\n", | |
| "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.1->botorch==0.12.0->ax-platform==0.4.3) (12.4.127)\n", | |
| "Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch>=2.0.1->botorch==0.12.0->ax-platform==0.4.3)\n", | |
| " Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", | |
| "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.1->botorch==0.12.0->ax-platform==0.4.3) (3.2.0)\n", | |
| "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.1->botorch==0.12.0->ax-platform==0.4.3) (1.13.1)\n", | |
| "Requirement already satisfied: notebook>=4.4.1 in /usr/local/lib/python3.11/dist-packages (from widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (6.5.7)\n", | |
| "Collecting mypy-extensions>=0.3.0 (from typing-inspect->pyre-extensions->ax-platform==0.4.3)\n", | |
| " Downloading mypy_extensions-1.1.0-py3-none-any.whl.metadata (1.1 kB)\n", | |
| "Requirement already satisfied: parso<0.9.0,>=0.8.4 in /usr/local/lib/python3.11/dist-packages (from jedi>=0.16->ipython>=4.0.0->ipywidgets->ax-platform==0.4.3) (0.8.4)\n", | |
| "Requirement already satisfied: jupyter-core>=4.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets->ax-platform==0.4.3) (5.7.2)\n", | |
| "Requirement already satisfied: argon2-cffi in /usr/local/lib/python3.11/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (23.1.0)\n", | |
| "Requirement already satisfied: nbformat in /usr/local/lib/python3.11/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (5.10.4)\n", | |
| "Requirement already satisfied: nbconvert>=5 in /usr/local/lib/python3.11/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (7.16.6)\n", | |
| "Requirement already satisfied: Send2Trash>=1.8.0 in /usr/local/lib/python3.11/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (1.8.3)\n", | |
| "Requirement already satisfied: terminado>=0.8.3 in /usr/local/lib/python3.11/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (0.18.1)\n", | |
| "Requirement already satisfied: prometheus-client in /usr/local/lib/python3.11/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (0.22.0)\n", | |
| "Requirement already satisfied: nbclassic>=0.4.7 in /usr/local/lib/python3.11/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (1.3.1)\n", | |
| "Requirement already satisfied: ptyprocess>=0.5 in /usr/local/lib/python3.11/dist-packages (from pexpect>4.3->ipython>=4.0.0->ipywidgets->ax-platform==0.4.3) (0.7.0)\n", | |
| "Requirement already satisfied: wcwidth in /usr/local/lib/python3.11/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython>=4.0.0->ipywidgets->ax-platform==0.4.3) (0.2.13)\n", | |
| "Requirement already satisfied: platformdirs>=2.5 in /usr/local/lib/python3.11/dist-packages (from jupyter-core>=4.6.0->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets->ax-platform==0.4.3) (4.3.8)\n", | |
| "Requirement already satisfied: notebook-shim>=0.2.3 in /usr/local/lib/python3.11/dist-packages (from nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (0.2.4)\n", | |
| "Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.11/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (4.13.4)\n", | |
| "Requirement already satisfied: bleach!=5.0.0 in /usr/local/lib/python3.11/dist-packages (from bleach[css]!=5.0.0->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (6.2.0)\n", | |
| "Requirement already satisfied: defusedxml in /usr/local/lib/python3.11/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (0.7.1)\n", | |
| "Requirement already satisfied: jupyterlab-pygments in /usr/local/lib/python3.11/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (0.3.0)\n", | |
| "Requirement already satisfied: mistune<4,>=2.0.3 in /usr/local/lib/python3.11/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (3.1.3)\n", | |
| "Requirement already satisfied: nbclient>=0.5.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (0.10.2)\n", | |
| "Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (1.5.1)\n", | |
| "Requirement already satisfied: fastjsonschema>=2.15 in /usr/local/lib/python3.11/dist-packages (from nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (2.21.1)\n", | |
| "Requirement already satisfied: jsonschema>=2.6 in /usr/local/lib/python3.11/dist-packages (from nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (4.23.0)\n", | |
| "Requirement already satisfied: argon2-cffi-bindings in /usr/local/lib/python3.11/dist-packages (from argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (21.2.0)\n", | |
| "Requirement already satisfied: webencodings in /usr/local/lib/python3.11/dist-packages (from bleach!=5.0.0->bleach[css]!=5.0.0->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (0.5.1)\n", | |
| "Requirement already satisfied: tinycss2<1.5,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from bleach[css]!=5.0.0->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (1.4.0)\n", | |
| "Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=2.6->nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (25.3.0)\n", | |
| "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=2.6->nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (2025.4.1)\n", | |
| "Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=2.6->nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (0.36.2)\n", | |
| "Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=2.6->nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (0.25.1)\n", | |
| "Requirement already satisfied: jupyter-server<3,>=1.8 in /usr/local/lib/python3.11/dist-packages (from notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (1.16.0)\n", | |
| "Requirement already satisfied: cffi>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from argon2-cffi-bindings->argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (1.17.1)\n", | |
| "Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.11/dist-packages (from beautifulsoup4->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (2.7)\n", | |
| "Requirement already satisfied: pycparser in /usr/local/lib/python3.11/dist-packages (from cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (2.22)\n", | |
| "Requirement already satisfied: anyio>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (4.9.0)\n", | |
| "Requirement already satisfied: websocket-client in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (1.8.0)\n", | |
| "Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.11/dist-packages (from anyio>=3.1.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (3.10)\n", | |
| "Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.11/dist-packages (from anyio>=3.1.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ax-platform==0.4.3) (1.3.1)\n", | |
| "Downloading ax_platform-0.4.3-py3-none-any.whl (1.3 MB)\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[?25hDownloading botorch-0.12.0-py3-none-any.whl (644 kB)\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m644.8/644.8 kB\u001b[0m \u001b[31m35.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[?25hDownloading gpytorch-1.13-py3-none-any.whl (277 kB)\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m277.8/277.8 kB\u001b[0m \u001b[31m17.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[?25hDownloading linear_operator-0.5.3-py3-none-any.whl (176 kB)\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m176.4/176.4 kB\u001b[0m \u001b[31m10.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[?25hDownloading jaxtyping-0.2.19-py3-none-any.whl (24 kB)\n", | |
| "Downloading pyre_extensions-0.0.32-py3-none-any.whl (12 kB)\n", | |
| "Downloading pyro_ppl-1.9.1-py3-none-any.whl (755 kB)\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m756.0/756.0 kB\u001b[0m \u001b[31m26.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[?25hDownloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m1.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[?25hDownloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB)\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m43.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[?25hDownloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m20.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[?25hDownloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m17.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[?25hDownloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m891.4 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[?25hDownloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[?25hDownloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m10.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[?25hDownloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m7.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[?25hDownloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m4.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[?25hDownloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m56.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[?25hDownloading typing_inspect-0.9.0-py3-none-any.whl (8.8 kB)\n", | |
| "Downloading jedi-0.19.2-py2.py3-none-any.whl (1.6 MB)\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m48.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[?25hDownloading mypy_extensions-1.1.0-py3-none-any.whl (5.0 kB)\n", | |
| "Downloading pyro_api-0.1.2-py3-none-any.whl (11 kB)\n", | |
| "Installing collected packages: pyro-api, nvidia-nvjitlink-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, mypy-extensions, jedi, typing-inspect, nvidia-cusparse-cu12, nvidia-cudnn-cu12, jaxtyping, pyre-extensions, nvidia-cusolver-cu12, pyro-ppl, linear-operator, gpytorch, botorch, ax-platform\n", | |
| " Attempting uninstall: nvidia-nvjitlink-cu12\n", | |
| " Found existing installation: nvidia-nvjitlink-cu12 12.5.82\n", | |
| " Uninstalling nvidia-nvjitlink-cu12-12.5.82:\n", | |
| " Successfully uninstalled nvidia-nvjitlink-cu12-12.5.82\n", | |
| " Attempting uninstall: nvidia-curand-cu12\n", | |
| " Found existing installation: nvidia-curand-cu12 10.3.6.82\n", | |
| " Uninstalling nvidia-curand-cu12-10.3.6.82:\n", | |
| " Successfully uninstalled nvidia-curand-cu12-10.3.6.82\n", | |
| " Attempting uninstall: nvidia-cufft-cu12\n", | |
| " Found existing installation: nvidia-cufft-cu12 11.2.3.61\n", | |
| " Uninstalling nvidia-cufft-cu12-11.2.3.61:\n", | |
| " Successfully uninstalled nvidia-cufft-cu12-11.2.3.61\n", | |
| " Attempting uninstall: nvidia-cuda-runtime-cu12\n", | |
| " Found existing installation: nvidia-cuda-runtime-cu12 12.5.82\n", | |
| " Uninstalling nvidia-cuda-runtime-cu12-12.5.82:\n", | |
| " Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82\n", | |
| " Attempting uninstall: nvidia-cuda-nvrtc-cu12\n", | |
| " Found existing installation: nvidia-cuda-nvrtc-cu12 12.5.82\n", | |
| " Uninstalling nvidia-cuda-nvrtc-cu12-12.5.82:\n", | |
| " Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.5.82\n", | |
| " Attempting uninstall: nvidia-cuda-cupti-cu12\n", | |
| " Found existing installation: nvidia-cuda-cupti-cu12 12.5.82\n", | |
| " Uninstalling nvidia-cuda-cupti-cu12-12.5.82:\n", | |
| " Successfully uninstalled nvidia-cuda-cupti-cu12-12.5.82\n", | |
| " Attempting uninstall: nvidia-cublas-cu12\n", | |
| " Found existing installation: nvidia-cublas-cu12 12.5.3.2\n", | |
| " Uninstalling nvidia-cublas-cu12-12.5.3.2:\n", | |
| " Successfully uninstalled nvidia-cublas-cu12-12.5.3.2\n", | |
| " Attempting uninstall: nvidia-cusparse-cu12\n", | |
| " Found existing installation: nvidia-cusparse-cu12 12.5.1.3\n", | |
| " Uninstalling nvidia-cusparse-cu12-12.5.1.3:\n", | |
| " Successfully uninstalled nvidia-cusparse-cu12-12.5.1.3\n", | |
| " Attempting uninstall: nvidia-cudnn-cu12\n", | |
| " Found existing installation: nvidia-cudnn-cu12 9.3.0.75\n", | |
| " Uninstalling nvidia-cudnn-cu12-9.3.0.75:\n", | |
| " Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%pip install ax-platform==0.4.3" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "zA74rOiWoDro" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def branin3(x1, x2, x3, c1):\n", | |
| " y = float(\n", | |
| " (x2 - 5.1 / (4 * np.pi**2) * x1**2 + 5.0 / np.pi * x1 - 6.0) ** 2\n", | |
| " + 10 * (1 - 1.0 / (8 * np.pi)) * np.cos(x1)\n", | |
| " + 10\n", | |
| " )\n", | |
| "\n", | |
| " # Contrived way to incorporate x3 into the objective\n", | |
| " y = y * (1 + 0.1 * x1 * x2 * x3)\n", | |
| "\n", | |
| " # add a made-up penalty based on category\n", | |
| " penalty_lookup = {\"A\": 1.0, \"B\": 0.0, \"C\": 2.0}\n", | |
| " y += penalty_lookup[c1]\n", | |
| "\n", | |
| " return y" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": { | |
| "background_save": true | |
| }, | |
| "id": "FJtWMajrhG32", | |
| "outputId": "40d861a7-8be2-4746-e533-942811e23cc7" | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "[WARNING 05-27 21:37:57] ax.service.utils.with_db_settings_base: Ax currently requires a sqlalchemy version below 2.0. This will be addressed in a future release. Disabling SQL storage in Ax for now, if you would like to use SQL storage please install Ax with mysql extras via `pip install ax-platform[mysql]`.\n", | |
| "[INFO 05-27 21:37:57] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.\n", | |
| "[INFO 05-27 21:37:57] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x1. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.\n", | |
| "[INFO 05-27 21:37:57] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x2. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.\n", | |
| "[INFO 05-27 21:37:57] ax.service.utils.instantiation: Inferred value type of ParameterType.STRING for parameter c1. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.\n", | |
| "/usr/local/lib/python3.11/dist-packages/ax/service/utils/instantiation.py:248: AxParameterWarning: `sort_values` is not specified for `ChoiceParameter` \"c1\". Defaulting to `False` for parameters of `ParameterType` STRING. To override this behavior (or avoid this warning), specify `sort_values` during `ChoiceParameter` construction.\n", | |
| " return ChoiceParameter(\n", | |
| "[INFO 05-27 21:37:57] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[0.0, 10.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 10.0]), ChoiceParameter(name='c1', parameter_type=STRING, values=['A', 'B', 'C'], is_ordered=False, sort_values=False)], parameter_constraints=[ParameterConstraint(1.0*x1 + 1.0*x2 <= 10.0)]).\n", | |
| "/usr/local/lib/python3.11/dist-packages/ax/modelbridge/cross_validation.py:463: UserWarning: Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.\n", | |
| " warn(\"Encountered exception in computing model fit quality: \" + str(e))\n", | |
| "[INFO 05-27 21:37:57] ax.service.ax_client: Generated new trial 0 with parameters {'x1': 3.670143, 'x2': 2.112298, 'c1': 'C'} using model Sobol.\n", | |
| "[INFO 05-27 21:37:57] ax.service.ax_client: Saved JSON-serialized state of optimization to `ax_client_snapshot.json`.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Suggested next experiment (trial #0) -- x1: 3.6701425071805716, x2: 2.1122976299375296, x3: 4.217559862881899, c1: C\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Generated by Honegumi (https://arxiv.org/abs/2502.06815)\n", | |
| "# %pip install ax-platform==0.4.3 matplotlib\n", | |
| "import numpy as np\n", | |
| "import pandas as pd\n", | |
| "from ax.service.ax_client import AxClient, ObjectiveProperties\n", | |
| "import matplotlib.pyplot as plt\n", | |
| "\n", | |
| "\n", | |
| "from ax.modelbridge.factory import Models\n", | |
| "from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy\n", | |
| "\n", | |
| "\n", | |
| "obj1_name = \"branin\"\n", | |
| "\n", | |
| "\n", | |
| "# Define total for compositional constraint, where x1 + x2 + x3 == total\n", | |
| "total = 10.0\n", | |
| "\n", | |
| "gs = GenerationStrategy(\n", | |
| " steps=[\n", | |
| " GenerationStep(\n", | |
| " model=Models.SOBOL,\n", | |
| " num_trials=5, # how many sobol trials to perform (rule of thumb: 2 * number of params)\n", | |
| " min_trials_observed=3,\n", | |
| " max_parallelism=5,\n", | |
| " model_kwargs={\"seed\": 999},\n", | |
| " ),\n", | |
| " GenerationStep(\n", | |
| " model=Models.SAASBO,\n", | |
| " num_trials=-1,\n", | |
| " max_parallelism=3,\n", | |
| " model_kwargs={},\n", | |
| " ),\n", | |
| " ]\n", | |
| ")\n", | |
| "\n", | |
| "ax_client = AxClient(generation_strategy=gs)\n", | |
| "# if using SAASBO is too slow, remove `gs` entirely and just use `ax_client = AxClient()`, which will give you reasonable defaults\n", | |
| "\n", | |
| "ax_client.create_experiment(\n", | |
| " parameters=[\n", | |
| " {\"name\": \"x1\", \"type\": \"range\", \"bounds\": [0.0, total]},\n", | |
| " {\"name\": \"x2\", \"type\": \"range\", \"bounds\": [0.0, total]},\n", | |
| " {\n", | |
| " \"name\": \"c1\",\n", | |
| " \"type\": \"choice\",\n", | |
| " \"is_ordered\": False,\n", | |
| " \"values\": [\"A\", \"B\", \"C\"],\n", | |
| " },\n", | |
| " ],\n", | |
| " objectives={\n", | |
| " obj1_name: ObjectiveProperties(minimize=True),\n", | |
| " },\n", | |
| " parameter_constraints=[\n", | |
| " f\"x1 + x2 <= {total}\", # reparameterized compositional constraint, which is a type of sum constraint\n", | |
| " ],\n", | |
| ")\n", | |
| "\n", | |
| "\n", | |
| "parameterization, trial_index = ax_client.get_next_trial()\n", | |
| "\n", | |
| "# extract parameters\n", | |
| "x1 = parameterization[\"x1\"]\n", | |
| "x2 = parameterization[\"x2\"]\n", | |
| "x3 = total - (x1 + x2) # composition constraint: x1 + x2 + x3 == total\n", | |
| "\n", | |
| "c1 = parameterization[\"c1\"]\n", | |
| "\n", | |
| "print(f\"Suggested next experiment (trial #{trial_index}) -- x1: {x1}, x2: {x2}, x3: {x3}, c1: {c1}\")\n", | |
| "\n", | |
| "snapshot_fpath = \"ax_client_snapshot.json\"\n", | |
| "ax_client.save_to_json_file(snapshot_fpath)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": { | |
| "background_save": true | |
| }, | |
| "id": "ZKnLwJmdii6R", | |
| "outputId": "6fdfa44f-d368-4ae4-e28f-95c8459472df" | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "9.487827279043518\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "del ax_client # just to emphasize that it's being loaded in a later cell (i.e., for illustration purposes)\n", | |
| "results = branin3(x1, x2, x3, c1) # this is you \"running the experiment\"\n", | |
| "print(results)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": { | |
| "background_save": true | |
| }, | |
| "id": "vzd_8070inpr", | |
| "outputId": "8f16bfe9-17e6-4da8-efe3-337720567a8f" | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "/usr/local/lib/python3.11/dist-packages/ax/storage/json_store/decoder.py:303: AxParameterWarning: `sort_values` is not specified for `ChoiceParameter` \"c1\". Defaulting to `False` for parameters of `ParameterType` STRING. To override this behavior (or avoid this warning), specify `sort_values` during `ChoiceParameter` construction.\n", | |
| " return _class(\n", | |
| "[INFO 05-27 21:37:57] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "ax_client = AxClient.load_from_json_file(snapshot_fpath)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": { | |
| "background_save": true | |
| }, | |
| "id": "gytHOHxAiqCv", | |
| "outputId": "dc96b827-7e7a-4898-d18e-d415998e9442" | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "[INFO 05-27 21:37:57] ax.service.ax_client: Completed trial 0 with data: {'branin': (9.487827, None)}.\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# NOTE: we need to pass trial_index to tell AxClient which parameterization these results correspond to\n", | |
| "ax_client.complete_trial(trial_index=trial_index, raw_data={\"branin\": results})" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": { | |
| "background_save": true | |
| }, | |
| "id": "iOh-9ZedjhfJ", | |
| "outputId": "a88cac84-c31c-4565-a0ed-6ef70f58aeb6" | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "/usr/local/lib/python3.11/dist-packages/ax/modelbridge/cross_validation.py:463: UserWarning: Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.\n", | |
| " warn(\"Encountered exception in computing model fit quality: \" + str(e))\n", | |
| "[INFO 05-27 21:37:57] ax.service.ax_client: Generated new trial 1 with parameters {'x1': 1.085235, 'x2': 4.961436, 'c1': 'C'} using model Sobol.\n", | |
| "[INFO 05-27 21:37:57] ax.service.ax_client: Saved JSON-serialized state of optimization to `ax_client_snapshot.json`.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "{'x1': 1.0852348804473877, 'x2': 4.961435543373227, 'c1': 'C'} 1\n", | |
| "Suggested next experiment -- x1: 1.0852348804473877, x2: 4.961435543373227, x3: 3.953329576179385, c1: C\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "parameterization, trial_index = ax_client.get_next_trial() # run this cell only once (because we are doing sequential optimization)\n", | |
| "\n", | |
| "print(parameterization, trial_index)\n", | |
| "\n", | |
| "# extract parameters\n", | |
| "x1 = parameterization[\"x1\"]\n", | |
| "x2 = parameterization[\"x2\"]\n", | |
| "x3 = total - (x1 + x2) # composition constraint: x1 + x2 + x3 == total\n", | |
| "\n", | |
| "c1 = parameterization[\"c1\"]\n", | |
| "\n", | |
| "print(f\"Suggested next experiment -- x1: {x1}, x2: {x2}, x3: {x3}, c1: {c1}\")\n", | |
| "\n", | |
| "ax_client.save_to_json_file(snapshot_fpath)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": { | |
| "background_save": true | |
| }, | |
| "id": "c0y-Grs3kHV4", | |
| "outputId": "946ab777-7950-4583-ab96-5175c9244a3f" | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "48.206864586996545\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "del ax_client\n", | |
| "results = branin3(x1, x2, x3, c1) # this is you \"running the experiment\"\n", | |
| "print(results)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": { | |
| "background_save": true | |
| }, | |
| "id": "ITNW1oyNkbUi", | |
| "outputId": "d3b8bc3a-eca8-4b1a-c2fa-9e410c67cdf8" | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "/usr/local/lib/python3.11/dist-packages/ax/storage/json_store/decoder.py:303: AxParameterWarning: `sort_values` is not specified for `ChoiceParameter` \"c1\". Defaulting to `False` for parameters of `ParameterType` STRING. To override this behavior (or avoid this warning), specify `sort_values` during `ChoiceParameter` construction.\n", | |
| " return _class(\n", | |
| "[INFO 05-27 21:37:57] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "ax_client = AxClient.load_from_json_file(snapshot_fpath)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": { | |
| "background_save": true | |
| }, | |
| "id": "MkkjhzOMk7AF", | |
| "outputId": "af6a4d94-c80c-4342-a1f1-0280a1c88bec" | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "[INFO 05-27 21:37:57] ax.service.ax_client: Completed trial 1 with data: {'branin': (48.206865, None)}.\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "ax_client.complete_trial(trial_index=trial_index, raw_data={\"branin\": results})" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": { | |
| "background_save": true, | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "BpQxf7-yl0W7", | |
| "outputId": "f8db80ab-aa8c-4f36-82b3-96c6a498eaf7" | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "/usr/local/lib/python3.11/dist-packages/ax/modelbridge/cross_validation.py:463: UserWarning: Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.\n", | |
| " warn(\"Encountered exception in computing model fit quality: \" + str(e))\n", | |
| "[INFO 05-27 21:37:57] ax.service.ax_client: Generated new trial 2 with parameters {'x1': 5.255021, 'x2': 1.09099, 'c1': 'A'} using model Sobol.\n", | |
| "[INFO 05-27 21:37:57] ax.service.ax_client: Completed trial 2 with data: {'branin': (47.334254, None)}.\n", | |
| "/usr/local/lib/python3.11/dist-packages/ax/modelbridge/cross_validation.py:463: UserWarning: Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.\n", | |
| " warn(\"Encountered exception in computing model fit quality: \" + str(e))\n", | |
| "[INFO 05-27 21:37:57] ax.service.ax_client: Generated new trial 3 with parameters {'x1': 4.206534, 'x2': 2.918926, 'c1': 'A'} using model Sobol.\n", | |
| "[INFO 05-27 21:37:57] ax.service.ax_client: Completed trial 3 with data: {'branin': (33.207714, None)}.\n", | |
| "/usr/local/lib/python3.11/dist-packages/ax/modelbridge/cross_validation.py:463: UserWarning: Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.\n", | |
| " warn(\"Encountered exception in computing model fit quality: \" + str(e))\n", | |
| "[INFO 05-27 21:37:57] ax.service.ax_client: Generated new trial 4 with parameters {'x1': 0.46391, 'x2': 6.076421, 'c1': 'B'} using model Sobol.\n", | |
| "[INFO 05-27 21:37:57] ax.service.ax_client: Completed trial 4 with data: {'branin': (37.937763, None)}.\n", | |
| "[INFO 05-27 21:39:48] ax.service.ax_client: Generated new trial 5 with parameters {'x1': 8.525951, 'x2': 0.0, 'c1': 'C'} using model SAASBO.\n", | |
| "[INFO 05-27 21:39:48] ax.service.ax_client: Completed trial 5 with data: {'branin': (9.339133, None)}.\n", | |
| "[INFO 05-27 21:41:16] ax.service.ax_client: Generated new trial 6 with parameters {'x1': 10.0, 'x2': 0.0, 'c1': 'B'} using model SAASBO.\n", | |
| "[INFO 05-27 21:41:16] ax.service.ax_client: Completed trial 6 with data: {'branin': (10.960889, None)}.\n", | |
| "/usr/local/lib/python3.11/dist-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-08 to the diagonal\n", | |
| " warnings.warn(\n", | |
| "[INFO 05-27 21:42:53] ax.service.ax_client: Generated new trial 7 with parameters {'x1': 6.282715, 'x2': 0.0, 'c1': 'B'} using model SAASBO.\n", | |
| "[INFO 05-27 21:42:53] ax.service.ax_client: Completed trial 7 with data: {'branin': (20.812079, None)}.\n", | |
| "/usr/local/lib/python3.11/dist-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-08 to the diagonal\n", | |
| " warnings.warn(\n", | |
| "[INFO 05-27 21:44:38] ax.service.ax_client: Generated new trial 8 with parameters {'x1': 0.0, 'x2': 10.0, 'c1': 'A'} using model SAASBO.\n", | |
| "[INFO 05-27 21:44:38] ax.service.ax_client: Completed trial 8 with data: {'branin': (36.602113, None)}.\n", | |
| "/usr/local/lib/python3.11/dist-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-08 to the diagonal\n", | |
| " warnings.warn(\n", | |
| "[INFO 05-27 21:46:07] ax.service.ax_client: Generated new trial 9 with parameters {'x1': 10.0, 'x2': 0.0, 'c1': 'C'} using model SAASBO.\n", | |
| "[INFO 05-27 21:46:07] ax.service.ax_client: Completed trial 9 with data: {'branin': (12.960889, None)}.\n", | |
| "/usr/local/lib/python3.11/dist-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-08 to the diagonal\n", | |
| " warnings.warn(\n", | |
| "[INFO 05-27 21:47:48] ax.service.ax_client: Generated new trial 10 with parameters {'x1': 10.0, 'x2': 0.0, 'c1': 'A'} using model SAASBO.\n", | |
| "[INFO 05-27 21:47:48] ax.service.ax_client: Completed trial 10 with data: {'branin': (11.960889, None)}.\n", | |
| "/usr/local/lib/python3.11/dist-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-08 to the diagonal\n", | |
| " warnings.warn(\n", | |
| "[INFO 05-27 21:49:39] ax.service.ax_client: Generated new trial 11 with parameters {'x1': 3.232036, 'x2': 0.0, 'c1': 'C'} using model SAASBO.\n", | |
| "[INFO 05-27 21:49:39] ax.service.ax_client: Completed trial 11 with data: {'branin': (7.301467, None)}.\n", | |
| "/usr/local/lib/python3.11/dist-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-08 to the diagonal\n", | |
| " warnings.warn(\n", | |
| "[INFO 05-27 21:51:30] ax.service.ax_client: Generated new trial 12 with parameters {'x1': 7.784494, 'x2': 0.0, 'c1': 'C'} using model SAASBO.\n", | |
| "[INFO 05-27 21:51:30] ax.service.ax_client: Completed trial 12 with data: {'branin': (14.737298, None)}.\n", | |
| "/usr/local/lib/python3.11/dist-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-08 to the diagonal\n", | |
| " warnings.warn(\n", | |
| "[INFO 05-27 21:53:28] ax.service.ax_client: Generated new trial 13 with parameters {'x1': 4.273055, 'x2': 0.0, 'c1': 'C'} using model SAASBO.\n", | |
| "[INFO 05-27 21:53:28] ax.service.ax_client: Completed trial 13 with data: {'branin': (10.343237, None)}.\n", | |
| "/usr/local/lib/python3.11/dist-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-08 to the diagonal\n", | |
| " warnings.warn(\n", | |
| "[INFO 05-27 21:55:23] ax.service.ax_client: Generated new trial 14 with parameters {'x1': 0.0, 'x2': 0.0, 'c1': 'C'} using model SAASBO.\n", | |
| "[INFO 05-27 21:55:23] ax.service.ax_client: Completed trial 14 with data: {'branin': (57.602113, None)}.\n", | |
| "[INFO 05-27 21:57:07] ax.service.ax_client: Generated new trial 15 with parameters {'x1': 3.387063, 'x2': 0.0, 'c1': 'B'} using model SAASBO.\n", | |
| "[INFO 05-27 21:57:07] ax.service.ax_client: Completed trial 15 with data: {'branin': (5.059481, None)}.\n", | |
| "/usr/local/lib/python3.11/dist-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-08 to the diagonal\n", | |
| " warnings.warn(\n", | |
| "[INFO 05-27 21:58:36] ax.service.ax_client: Generated new trial 16 with parameters {'x1': 9.123101, 'x2': 0.0, 'c1': 'B'} using model SAASBO.\n", | |
| "[INFO 05-27 21:58:36] ax.service.ax_client: Completed trial 16 with data: {'branin': (5.814624, None)}.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 20min 12s, sys: 5.97 s, total: 20min 18s\n", | |
| "Wall time: 20min 38s\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%time\n", | |
| "# run a bunch of experiments in a loop without saving/loading to give it some data so the end of campaign analysis is meaningful\n", | |
| "\n", | |
| "for i in range(15): # i.e., 15 more trials\n", | |
| " parameterization, trial_index = ax_client.get_next_trial()\n", | |
| "\n", | |
| " # extract parameters\n", | |
| " x1 = parameterization[\"x1\"]\n", | |
| " x2 = parameterization[\"x2\"]\n", | |
| " x3 = total - (x1 + x2) # composition constraint: x1 + x2 + x3 == total\n", | |
| "\n", | |
| " c1 = parameterization[\"c1\"]\n", | |
| "\n", | |
| " results = branin3(x1, x2, x3, c1)\n", | |
| " ax_client.complete_trial(trial_index=trial_index, raw_data={\"branin\": results})" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "QBafjdSsjafe" | |
| }, | |
| "source": [ | |
| "## End of Campaign / Analysis" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": { | |
| "background_save": true | |
| }, | |
| "id": "OfhEK-ARhfMy", | |
| "outputId": "fd56de59-a94b-4367-96c4-1a1ae96e4ad3" | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "[WARNING 05-27 21:59:52] ax.service.utils.report_utils: Column reason missing for all trials. Not appending column.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Best parameters: {'x1': 3.387062556889936, 'x2': 0.0, 'c1': 'B'}\n", | |
| "Best metrics: ({'branin': np.float64(5.78071107054917)}, {'branin': {'branin': np.float64(11.477393275837601)}})\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 900x600 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "best_parameters, metrics = ax_client.get_best_parameters()\n", | |
| "\n", | |
| "print(f\"Best parameters: {best_parameters}\")\n", | |
| "print(f\"Best metrics: {metrics}\")\n", | |
| "\n", | |
| "\n", | |
| "# Plot results\n", | |
| "objectives = ax_client.objective_names\n", | |
| "df = ax_client.get_trials_data_frame()\n", | |
| "\n", | |
| "fig, ax = plt.subplots(figsize=(6, 4), dpi=150)\n", | |
| "# ax.scatter(df.index, df[objectives], ec=\"k\", fc=\"none\", label=\"Observed\")\n", | |
| "ax.plot(\n", | |
| " df.index,\n", | |
| " np.minimum.accumulate(df[objectives]),\n", | |
| " color=\"#0033FF\",\n", | |
| " lw=2,\n", | |
| " label=\"Best to Trial\",\n", | |
| ")\n", | |
| "ax.set_xlabel(\"Trial Number\")\n", | |
| "ax.set_ylabel(objectives[0])\n", | |
| "\n", | |
| "ax.legend()\n", | |
| "plt.show()" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "provenance": [], | |
| "authorship_tag": "ABX9TyOTZohWpK3IF+V/aV5qVukH", | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "name": "python" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 0 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment