Created
November 18, 2024 02:13
-
-
Save wesslen/b2a7d0ab2a9c94d8a178b4ff5e86c644 to your computer and use it in GitHub Desktop.
conformal-uncertainty.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "provenance": [], | |
| "authorship_tag": "ABX9TyNH6sn7W6a/qqd7pBIMeE7P", | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| }, | |
| "language_info": { | |
| "name": "python" | |
| } | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/wesslen/b2a7d0ab2a9c94d8a178b4ff5e86c644/conformal-uncertainty.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "import numpy as np\n", | |
| "from typing import List, Tuple\n", | |
| "import torch\n", | |
| "from torch.nn.functional import softmax\n", | |
| "from transformers import AutoModelForCausalLM, AutoTokenizer\n", | |
| "\n", | |
| "class ConformalLLM:\n", | |
| " def __init__(self, model_name: str = \"gpt2\", alpha: float = 0.1):\n", | |
| " \"\"\"\n", | |
| " Initialize conformal predictor for LLM\n", | |
| "\n", | |
| " Args:\n", | |
| " model_name: Name of the HuggingFace model to use\n", | |
| " alpha: Desired miscoverage level (e.g., 0.1 for 90% coverage)\n", | |
| " \"\"\"\n", | |
| " self.model = AutoModelForCausalLM.from_pretrained(model_name)\n", | |
| " self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n", | |
| " self.alpha = alpha\n", | |
| " self.calibration_scores = None\n", | |
| "\n", | |
| " def get_token_probabilities(self, text: str) -> torch.Tensor:\n", | |
| " \"\"\"Get probability distribution over next tokens.\"\"\"\n", | |
| " inputs = self.tokenizer(text, return_tensors=\"pt\")\n", | |
| " with torch.no_grad():\n", | |
| " outputs = self.model(**inputs)\n", | |
| " probs = softmax(outputs.logits[:, -1, :], dim=-1)\n", | |
| " return probs.squeeze()\n", | |
| "\n", | |
| " def calibrate(self, calibration_data: List[Tuple[str, str]]):\n", | |
| " \"\"\"\n", | |
| " Calibrate the predictor using validation data\n", | |
| "\n", | |
| " Args:\n", | |
| " calibration_data: List of (context, next_token) pairs\n", | |
| " \"\"\"\n", | |
| " scores = []\n", | |
| "\n", | |
| " for context, next_token in calibration_data:\n", | |
| " # Get probability distribution over tokens\n", | |
| " probs = self.get_token_probabilities(context)\n", | |
| "\n", | |
| " # Get token ID for the actual next token\n", | |
| " true_token_id = self.tokenizer.encode(next_token)[0]\n", | |
| "\n", | |
| " # Store negative log probability as nonconformity score\n", | |
| " score = -torch.log(probs[true_token_id]).item()\n", | |
| " scores.append(score)\n", | |
| "\n", | |
| " # Store calibration scores for prediction\n", | |
| " self.calibration_scores = np.sort(scores)\n", | |
| "\n", | |
| " def predict(self, context: str, n_tokens: int = 100) -> List[str]:\n", | |
| " \"\"\"\n", | |
| " Generate text with conformal prediction sets\n", | |
| "\n", | |
| " Args:\n", | |
| " context: Input text prompt\n", | |
| " n_tokens: Number of tokens to generate\n", | |
| "\n", | |
| " Returns:\n", | |
| " List of generated tokens within the prediction set\n", | |
| " \"\"\"\n", | |
| " if self.calibration_scores is None:\n", | |
| " raise ValueError(\"Must call calibrate() before predict()\")\n", | |
| "\n", | |
| " # Get quantile threshold for conformal prediction\n", | |
| " n = len(self.calibration_scores)\n", | |
| " # Fix: Ensure quantile is in [0, 1]\n", | |
| " q = min(max((n + 1) * (1 - self.alpha) / n, 0), 1)\n", | |
| " threshold = np.quantile(self.calibration_scores, q)\n", | |
| "\n", | |
| " # Get probabilities for next token\n", | |
| " probs = self.get_token_probabilities(context)\n", | |
| "\n", | |
| " # Find tokens within confidence set\n", | |
| " log_probs = torch.log(probs)\n", | |
| " valid_tokens = (-log_probs <= threshold).nonzero().squeeze()\n", | |
| "\n", | |
| " # Handle case where valid_tokens is 0-dimensional (single token)\n", | |
| " if valid_tokens.dim() == 0:\n", | |
| " valid_tokens = valid_tokens.unsqueeze(0)\n", | |
| "\n", | |
| " # Convert token IDs to text\n", | |
| " prediction_set = [self.tokenizer.decode([idx.item()])\n", | |
| " for idx in valid_tokens]\n", | |
| "\n", | |
| " return prediction_set\n", | |
| "\n", | |
| "# Example usage\n", | |
| "def main():\n", | |
| " # Initialize predictor\n", | |
| " predictor = ConformalLLM(alpha=0.1) # 90% coverage\n", | |
| "\n", | |
| " # Create simple calibration dataset\n", | |
| " calibration_data = [\n", | |
| " (\"The capital of France is\", \" Paris\"),\n", | |
| " (\"The capital of Italy is\", \" Rome\"),\n", | |
| " (\"The capital of Spain is\", \" Madrid\"),\n", | |
| " (\"The capital of Germany is\", \" Berlin\"),\n", | |
| " (\"The capital of England is\", \" London\")\n", | |
| " ]\n", | |
| "\n", | |
| " # Calibrate the predictor\n", | |
| " predictor.calibrate(calibration_data)\n", | |
| "\n", | |
| " # Make prediction with conformal set\n", | |
| " context = \"The capital of Portugal is\"\n", | |
| " prediction_set = predictor.predict(context)\n", | |
| "\n", | |
| " print(f\"Input: {context}\")\n", | |
| " print(f\"Conformal prediction set (90% coverage):\")\n", | |
| " print(prediction_set)\n", | |
| "\n", | |
| "if __name__ == \"__main__\":\n", | |
| " main()" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "SYhCrVxgCc1v", | |
| "outputId": "a8540de0-0f48-4fd1-c0e2-a3cb064a45b0" | |
| }, | |
| "execution_count": 2, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Input: The capital of Portugal is\n", | |
| "Conformal prediction set (90% coverage):\n", | |
| "[' a', ' the', ' now', ' Lisbon']\n" | |
| ] | |
| } | |
| ] | |
| } | |
| ] | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment