@wesslen
Created November 18, 2024 02:13
conformal-uncertainty.ipynb
[Open In Colab](https://colab.research.google.com/gist/wesslen/b2a7d0ab2a9c94d8a178b4ff5e86c644/conformal-uncertainty.ipynb)

The notebook implements split conformal prediction for next-token uncertainty with a HuggingFace causal LM: negative log probabilities of held-out next tokens serve as nonconformity scores, and a calibrated quantile of those scores defines the prediction set.
```python
import numpy as np
from typing import List, Tuple
import torch
from torch.nn.functional import softmax
from transformers import AutoModelForCausalLM, AutoTokenizer

class ConformalLLM:
    def __init__(self, model_name: str = "gpt2", alpha: float = 0.1):
        """
        Initialize a conformal predictor for an LLM.

        Args:
            model_name: Name of the HuggingFace model to use
            alpha: Desired miscoverage level (e.g., 0.1 for 90% coverage)
        """
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.alpha = alpha
        self.calibration_scores = None

    def get_token_probabilities(self, text: str) -> torch.Tensor:
        """Get the probability distribution over next tokens."""
        inputs = self.tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
            probs = softmax(outputs.logits[:, -1, :], dim=-1)
        return probs.squeeze()

    def calibrate(self, calibration_data: List[Tuple[str, str]]):
        """
        Calibrate the predictor using validation data.

        Args:
            calibration_data: List of (context, next_token) pairs
        """
        scores = []

        for context, next_token in calibration_data:
            # Get the probability distribution over tokens
            probs = self.get_token_probabilities(context)

            # Get the ID of the first token of the actual continuation
            # (multi-token continuations are scored by their first token only)
            true_token_id = self.tokenizer.encode(next_token)[0]

            # Store negative log probability as the nonconformity score
            score = -torch.log(probs[true_token_id]).item()
            scores.append(score)

        # Store sorted calibration scores for prediction
        self.calibration_scores = np.sort(scores)

    def predict(self, context: str) -> List[str]:
        """
        Build a conformal prediction set for the next token.

        Args:
            context: Input text prompt

        Returns:
            List of tokens within the prediction set
        """
        if self.calibration_scores is None:
            raise ValueError("Must call calibrate() before predict()")

        # Conformal quantile level, clamped to [0, 1]; for small n,
        # (n + 1)(1 - alpha) / n exceeds 1 and the threshold falls back
        # to the maximum calibration score
        n = len(self.calibration_scores)
        q = min(max((n + 1) * (1 - self.alpha) / n, 0), 1)
        threshold = np.quantile(self.calibration_scores, q)

        # Get probabilities for the next token
        probs = self.get_token_probabilities(context)

        # Keep every token whose nonconformity score is within the threshold
        log_probs = torch.log(probs)
        valid_tokens = (-log_probs <= threshold).nonzero().squeeze()

        # Handle the case where valid_tokens is 0-dimensional (single token)
        if valid_tokens.dim() == 0:
            valid_tokens = valid_tokens.unsqueeze(0)

        # Convert token IDs to text
        prediction_set = [self.tokenizer.decode([idx.item()])
                          for idx in valid_tokens]

        return prediction_set

# Example usage
def main():
    # Initialize the predictor
    predictor = ConformalLLM(alpha=0.1)  # 90% coverage

    # Create a simple calibration dataset
    calibration_data = [
        ("The capital of France is", " Paris"),
        ("The capital of Italy is", " Rome"),
        ("The capital of Spain is", " Madrid"),
        ("The capital of Germany is", " Berlin"),
        ("The capital of England is", " London")
    ]

    # Calibrate the predictor
    predictor.calibrate(calibration_data)

    # Make a prediction with a conformal set
    context = "The capital of Portugal is"
    prediction_set = predictor.predict(context)

    print(f"Input: {context}")
    print("Conformal prediction set (90% coverage):")
    print(prediction_set)

if __name__ == "__main__":
    main()
```
Output:

```
Input: The capital of Portugal is
Conformal prediction set (90% coverage):
[' a', ' the', ' now', ' Lisbon']
```
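A quick note on the threshold in `predict()`: the standard split-conformal threshold is the $\lceil (n+1)(1-\alpha) \rceil$-th smallest calibration score,

$$
\hat{q} = \mathrm{Quantile}\!\left( \frac{\lceil (n+1)(1-\alpha) \rceil}{n};\; s_1, \dots, s_n \right),
$$

which the code approximates with an interpolated `np.quantile` at level $(n+1)(1-\alpha)/n$, clamped to $[0, 1]$. With the five calibration pairs above and $\alpha = 0.1$, that level is $6 \times 0.9 / 5 = 1.08$, so it clamps to 1 and the threshold is simply the largest calibration score.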
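As a sanity check, here is a minimal sketch (not in the gist, assuming the `ConformalLLM` class above and a hypothetical list of held-out pairs) of how one might estimate empirical coverage; with exchangeable calibration and test data, roughly a $1 - \alpha$ fraction of test pairs should land inside their prediction sets:

```python
def empirical_coverage(predictor, test_data):
    """Fraction of (context, next_token) pairs whose true token
    falls inside the conformal prediction set."""
    hits = 0
    for context, next_token in test_data:
        prediction_set = predictor.predict(context)
        # Tokens decode with a leading space, matching the calibration format
        hits += next_token in prediction_set
    return hits / len(test_data)

# Hypothetical held-out pairs, in the same format as calibration_data:
# test_data = [("The capital of Japan is", " Tokyo"),
#              ("The capital of Canada is", " Ottawa")]
# print(f"Empirical coverage: {empirical_coverage(predictor, test_data):.2f}")
```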