{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SETUP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import os\n",
    "import sys\n",
    "import findspark\n",
    "import pandas as pd\n",
    "import shap\n",
    "import lightgbm as lgb\n",
    "import requests\n",
    "from typing import Iterable\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "findspark.init()\n",
    "os.environ[\"PYSPARK_PYTHON\"] = sys.executable\n",
    "\n",
    "import pyspark.sql\n",
    "import pyspark.sql.types as T\n",
    "\n",
    "conf = pyspark.SparkConf().setAppName(\"bug\")\n",
    "# Set maxRecordsPerBatch to 1 since we are interested in a single iteration\n",
    "conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"1\")\n",
    "spark = pyspark.sql.SparkSession.builder.config(conf=conf).getOrCreate()"
   ]
  },
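  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check (not part of the original gist): read the Arrow batch size back from the active session to confirm the setting above took effect."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical check, not in the original gist: the active session\n",
    "# should report the value configured above (\"1\")\n",
    "print(spark.conf.get(\"spark.sql.execution.arrow.maxRecordsPerBatch\"))"
   ]
  },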
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### LIGHTGBM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: set `size` to train the LightGBM model with a different number of estimators:\n",
    "# s: n_estimators=1000, m: n_estimators=2500, l: n_estimators=5000\n",
    "size = \"s\"\n",
    "\n",
    "# Download the dataset if it doesn't exist\n",
    "url = \"https://raw.githubusercontent.com/saul-chirinos/Electricity-demand-forecasting-in-Panama/master/Data/continuous%20dataset.csv\"\n",
    "filename = \"panama.csv\"\n",
    "\n",
    "if not os.path.isfile(filename):\n",
    "    response = requests.get(url)\n",
    "    response.raise_for_status()\n",
    "    with open(filename, \"wb\") as file:\n",
    "        file.write(response.content)\n",
    "\n",
    "# Load data\n",
    "data = pd.read_csv(filename).drop(columns=[\"datetime\", \"QV2M_san\", \"T2M_san\", \"T2M_toc\"])\n",
    "X, y = data.drop(columns=[\"nat_demand\"]), data[\"nat_demand\"]\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
    "\n",
    "# Train model\n",
    "params = {\"n_estimators\": 1000 if size == \"s\" else 2500 if size == \"m\" else 5000, \"num_leaves\": 256}\n",
    "train, test = lgb.Dataset(X_train, label=y_train), lgb.Dataset(X_test, label=y_test)\n",
    "predictor = lgb.train(params=params, train_set=train, valid_sets=[test])\n",
    "predictor.save_model(f\"lgb-{size}.txt\")\n",
    "\n",
    "# NOTE: on repeated runs, load the saved model instead to avoid retraining:\n",
    "# predictor = lgb.Booster(model_file=f\"lgb-{size}.txt\")\n",
    "\n",
    "print(f\"lgb-{size}: {os.path.getsize(f'lgb-{size}.txt') / (1024 * 1024):.2f} MB\")"
   ]
  },
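  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optional baseline (not part of the original gist): time one KernelExplainer call on the driver, outside Spark, to compare against the `mapInPandas` path below. The background and sample sizes are assumptions chosen to mirror the cells that follow."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical local baseline: explain a single row without Spark\n",
    "local_explainer = shap.KernelExplainer(model=predictor.predict, data=X_train.iloc[:10])\n",
    "t1 = time.time()\n",
    "local_explainer.shap_values(X_test.iloc[:1])\n",
    "print(f\"Local elapsed time: {round(time.time() - t1, 2)} seconds\")"
   ]
  },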
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SHAP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def explain(df, model, background_data):\n",
    "    def compute_shap(iterable: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]:\n",
    "        # Explain only the first Arrow batch (a single row, given\n",
    "        # maxRecordsPerBatch=1) and yield a dummy frame, since only the\n",
    "        # runtime of one iteration is of interest\n",
    "        for i, batch in enumerate(iterable):\n",
    "            if i > 0:\n",
    "                break\n",
    "\n",
    "            explainer.shap_values(batch, silent=False)\n",
    "            yield pd.DataFrame(columns=[\"dummy\"])\n",
    "\n",
    "    explainer = shap.KernelExplainer(\n",
    "        model=model.predict,\n",
    "        data=background_data,\n",
    "        keep_index=True,\n",
    "        link=\"identity\",\n",
    "    )\n",
    "\n",
    "    print(\"Computing shap values\")\n",
    "    t1 = time.time()\n",
    "\n",
    "    schema = T.StructType([T.StructField(\"dummy\", T.IntegerType())])\n",
    "    shap_values = df.mapInPandas(compute_shap, schema=schema)\n",
    "    shap_values.collect()\n",
    "\n",
    "    t2 = time.time()\n",
    "    print(f\"Elapsed time: {round(t2 - t1, 2)} seconds\")"
   ]
  },
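  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, a minimal sketch (not part of the original gist) of how the same `mapInPandas` pattern could return the SHAP values themselves rather than a dummy column: declare one `DoubleType` field per feature and yield a row of SHAP values per input row. `explain_full` and its schema are illustrative assumptions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical variant: return real SHAP values with a per-feature schema\n",
    "def explain_full(df, model, background_data):\n",
    "    feature_names = list(background_data.columns)\n",
    "\n",
    "    def compute_shap(iterable: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]:\n",
    "        # Build the explainer on the executor; for a single-output model,\n",
    "        # shap_values returns an (n_rows, n_features) array per batch\n",
    "        explainer = shap.KernelExplainer(model=model.predict, data=background_data)\n",
    "        for batch in iterable:\n",
    "            values = explainer.shap_values(batch)\n",
    "            yield pd.DataFrame(values, columns=feature_names)\n",
    "\n",
    "    schema = T.StructType([T.StructField(name, T.DoubleType()) for name in feature_names])\n",
    "    return df.mapInPandas(compute_shap, schema=schema)"
   ]
  },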
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the samples used as background data and the samples to be explained\n",
    "background_data = X_train.iloc[:10]\n",
    "df = spark.createDataFrame(X_test.iloc[:100]).coalesce(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"{pyspark.__version__=}\")\n",
    "explain(df, predictor, background_data)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}