Last active
October 10, 2025 22:04
-
-
Save jc4p/334d52de25624e6419aa21f4ccd35f10 to your computer and use it in GitHub Desktop.
Making and running a BERT classifier on Farcaster casts
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| """ | |
| Classify Farcaster casts using Google Gemini Batch API. | |
| Classifies each cast into one of 5 categories: | |
| - Personal Life | |
| - Spam | |
| - App/Bot Interactions | |
| - Crypto Content | |
| - Not Enough Context To Tell | |
| """ | |
import json
import random
import time
from collections import Counter, defaultdict

from google import genai
from google.genai import types
def create_classification_prompt(cast_text):
    """Build the single-cast classification prompt sent to Gemini.

    Args:
        cast_text: Raw text of the cast; it is embedded verbatim between
            double quotes after the ``Cast:`` label.

    Returns:
        A prompt listing the five categories and asking for a reply in the
        two-line ``Category: ...`` / ``Confidence: ...`` format (confidence on
        a 0-10 scale).
    """
    return f"""Classify this Farcaster cast into exactly ONE of these categories:

1. Personal Life
2. Spam
3. App/Bot Interactions
4. Crypto Content
5. Not Enough Context To Tell

Cast: "{cast_text}"

Respond in this exact format:
Category: [category name]
Confidence: [number from 0-10]"""
# Terminal batch-job states: polling stops once the job reaches any of these.
_COMPLETED_STATES = frozenset({
    'JOB_STATE_SUCCEEDED',
    'JOB_STATE_FAILED',
    'JOB_STATE_CANCELLED',
    'JOB_STATE_EXPIRED',
})


def _load_casts(path, sample_size=10000, seed=42):
    """Load casts from a one-JSON-object-per-line file and down-sample them.

    Args:
        path: File containing one JSON cast object per line.
        sample_size: Maximum number of casts to keep.
        seed: Sampler seed, so repeated runs select the same casts.

    Returns:
        A list of cast dicts, at most ``sample_size`` long.
    """
    print(f"Loading {path}...")
    with open(path, 'r', encoding='utf-8') as f:
        # Skip blank lines (e.g. a trailing newline) rather than crash json.loads.
        casts = [json.loads(line) for line in f if line.strip()]
    if len(casts) > sample_size:
        print(f"Loaded {len(casts)} casts, sampling {sample_size:,}...")
        # A private Random(seed) draws the same sample as random.seed(seed) +
        # random.sample(), without reseeding the process-wide RNG.
        casts = random.Random(seed).sample(casts, sample_size)
    print(f"Processing {len(casts)} casts")
    return casts


def _request_key(index, cast):
    """Return the batch request key that ties a result line back to its cast."""
    return f"cast-{index}-{cast['Hash']}"


def _write_batch_requests(casts, path):
    """Write one Gemini batch request per cast to the JSONL file at *path*."""
    print(f"Creating {path}...")
    with open(path, 'w', encoding='utf-8') as f:
        for i, cast in enumerate(casts):
            request = {
                "key": _request_key(i, cast),
                "request": {
                    "contents": [
                        {
                            "parts": [{"text": create_classification_prompt(cast['Text'])}],
                            "role": "user",
                        }
                    ],
                    # Very low temperature for consistent classification.
                    "generation_config": {"temperature": 0.1},
                },
            }
            f.write(json.dumps(request) + '\n')
    print(f"Created {path} with {len(casts)} requests")


def _wait_for_job(client, job_name, batch_job, poll_seconds=30):
    """Poll the batch job every *poll_seconds* until it reaches a terminal state.

    Returns:
        The last job object fetched, whose state is in ``_COMPLETED_STATES``.
    """
    print("\nPolling job status...")
    while batch_job.state.name not in _COMPLETED_STATES:
        print(f"Current state: {batch_job.state.name}")
        time.sleep(poll_seconds)
        batch_job = client.batches.get(name=job_name)
    print(f"\nJob finished with state: {batch_job.state.name}")
    return batch_job


def _index_results(file_content):
    """Parse downloaded result JSONL bytes into a dict keyed by request key.

    Indexing once makes matching each cast O(1) instead of rescanning every
    result line per cast.
    """
    results_by_key = {}
    for line in file_content.decode('utf-8').splitlines():
        if not line.strip():
            continue
        result = json.loads(line)
        # Keep the first occurrence of a key, as a linear first-match scan would.
        results_by_key.setdefault(result.get('key'), result)
    return results_by_key


def _extract_response_text(response):
    """Return the first candidate's first text part, stripped, or None if absent.

    A response without candidates or text parts (possibly a blocked or empty
    generation) yields None instead of raising.
    """
    try:
        return response['candidates'][0]['content']['parts'][0]['text'].strip()
    except (KeyError, IndexError, TypeError):
        return None


def _field_value(line, prefix):
    """Return the text after *prefix*, without surrounding spaces or template brackets."""
    return line[len(prefix):].strip().strip('[]').strip()


def _parse_classification(response_text):
    """Extract ``(category, confidence)`` from a reply in the prompt's format.

    Markdown emphasis (``**Category:**``) and echoed template brackets
    (``[Crypto Content]``) are tolerated. A confidence like ``8/10`` is read
    as 8; a missing or non-integer field comes back as None.
    """
    category = None
    confidence = None
    for raw_line in response_text.splitlines():
        line = raw_line.replace('*', '').strip()
        if line.startswith('Category:'):
            category = _field_value(line, 'Category:') or None
        elif line.startswith('Confidence:'):
            value = _field_value(line, 'Confidence:').partition('/')[0].strip()
            try:
                confidence = int(value)
            except ValueError:
                confidence = None
    return category, confidence


def _build_structured_results(casts, results_by_key):
    """Join each cast with its batch result into one output record per cast.

    Every record has ``Classification``/``Confidence`` (None when not parsed)
    and ``error`` (the batch ``status`` for failed requests, or a message when
    the result is missing or carries no text).
    """
    structured = []
    for i, cast in enumerate(casts):
        output = {
            'Fid': cast['Fid'],
            'Hash': cast['Hash'],
            'Text': cast['Text'],
            'Timestamp': cast['Timestamp'],
            'Classification': None,
            'Confidence': None,
            'error': None,
        }
        result = results_by_key.get(_request_key(i, cast))
        if result is None:
            output['error'] = 'no result returned for this request'
        elif 'response' in result:
            text = _extract_response_text(result['response'])
            if text is None:
                output['error'] = 'response contained no candidate text (possibly blocked)'
            else:
                output['Classification'], output['Confidence'] = _parse_classification(text)
        elif 'status' in result:
            output['error'] = result['status']
        structured.append(output)
    return structured


def _print_summary(structured_results):
    """Print success/failure totals and per-category counts with mean confidence."""
    classified = [r for r in structured_results if r['Classification']]
    failed = sum(1 for r in structured_results if r['error'])
    unparsed = len(structured_results) - len(classified) - failed

    category_counts = Counter(r['Classification'] for r in classified)
    confidences = defaultdict(list)
    for r in classified:
        if r['Confidence'] is not None:
            confidences[r['Classification']].append(r['Confidence'])

    print("\nSummary:")
    print(f"  Total casts: {len(structured_results)}")
    print(f"  Successful: {len(classified)}")
    print(f"  Failed: {failed}")
    print(f"  Unparsed replies: {unparsed}")
    print("\nClassification breakdown:")
    for cat, count in category_counts.most_common():
        scores = confidences.get(cat)
        avg_conf = sum(scores) / len(scores) if scores else 0
        print(f"  {cat}: {count} (avg confidence: {avg_conf:.1f})")


def _save_and_summarize_results(client, batch_job, casts):
    """Download a succeeded job's results, save raw and structured copies, print a summary."""
    dest = batch_job.dest
    if not (dest and dest.file_name):
        print("Job succeeded but reported no result file; nothing to parse.")
        return
    result_file_name = dest.file_name
    print(f"Results are in file: {result_file_name}")

    print("Downloading results...")
    file_content = client.files.download(file=result_file_name)

    # Save raw results (bulk file).
    raw_results_filename = 'cast_classification_raw.jsonl'
    with open(raw_results_filename, 'wb') as f:
        f.write(file_content)
    print(f"Saved raw results to {raw_results_filename}")

    print("Parsing results...")
    structured_results = _build_structured_results(casts, _index_results(file_content))

    # Save structured results (filtered file for analysis).
    output_filename = 'cast_classification_results.json'
    with open(output_filename, 'w', encoding='utf-8') as f:
        json.dump(structured_results, f, indent=2)
    print(f"Saved structured results to {output_filename}")

    _print_summary(structured_results)


def main():
    """Classify a sample of Farcaster casts with the Gemini Batch API.

    Loads and samples ``data/casts_filtered.json``, writes and uploads a JSONL
    batch request file, submits a batch job, and polls until the job reaches a
    terminal state. On success it downloads the results, writes
    ``cast_classification_raw.jsonl`` and ``cast_classification_results.json``,
    and prints a per-category summary.
    """
    client = genai.Client()

    casts = _load_casts('data/casts_filtered.json')

    jsonl_filename = 'classification_batch_requests.jsonl'
    _write_batch_requests(casts, jsonl_filename)

    print("Uploading JSONL file to Google File API...")
    uploaded_file = client.files.upload(
        file=jsonl_filename,
        config=types.UploadFileConfig(
            display_name='farcaster-cast-classification',
            mime_type='application/jsonl',
        ),
    )
    print(f"Uploaded file: {uploaded_file.name}")

    print("Creating batch job...")
    batch_job = client.batches.create(
        model="gemini-flash-latest",
        src=uploaded_file.name,
        config={
            'display_name': "farcaster-cast-classification",
        },
    )
    job_name = batch_job.name
    print(f"Created batch job: {job_name}")

    batch_job = _wait_for_job(client, job_name, batch_job)

    state = batch_job.state.name
    if state == 'JOB_STATE_SUCCEEDED':
        _save_and_summarize_results(client, batch_job, casts)
    elif state == 'JOB_STATE_FAILED':
        print(f"Error: {batch_job.error}")

    print("\nDone!")
# Run the batch classification pipeline only when executed as a script.
if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Farcaster Cast Classification Training\n", | |
| "\n", | |
| "Training a BERT base model to classify casts into 5 categories:\n", | |
| "- Personal Life - Salient\n", | |
| "- Personal Life - Not Salient\n", | |
| "- Crypto Content\n", | |
| "- App/Bot Interactions\n", | |
| "- Not Enough Context To Tell\n", | |
| "\n", | |
| "(Spam category removed - low confidence predictions will be treated as spam)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "/home/ubuntu/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", | |
| " from .autonotebook import tqdm as notebook_tqdm\n", | |
| "/home/ubuntu/.venv/lib/python3.10/site-packages/pydantic/_internal/_generate_schema.py:2249: UnsupportedFieldAttributeWarning: The 'repr' attribute with value False was provided to the `Field()` function, which has no effect in the context it was used. 'repr' is field-specific metadata, and can only be attached to a model field using `Annotated` metadata or by assignment. This may have happened because an `Annotated` type alias using the `type` statement was used, or if the `Field()` function was attached to a single member of a union type.\n", | |
| " warnings.warn(\n", | |
| "/home/ubuntu/.venv/lib/python3.10/site-packages/pydantic/_internal/_generate_schema.py:2249: UnsupportedFieldAttributeWarning: The 'frozen' attribute with value True was provided to the `Field()` function, which has no effect in the context it was used. 'frozen' is field-specific metadata, and can only be attached to a model field using `Annotated` metadata or by assignment. This may have happened because an `Annotated` type alias using the `type` statement was used, or if the `Field()` function was attached to a single member of a union type.\n", | |
| " warnings.warn(\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Imports and environment setup\n", | |
| "import os\n", | |
| "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", | |
| "\n", | |
| "import torch\n", | |
| "from torch.utils.data import Dataset, DataLoader\n", | |
| "from transformers import (\n", | |
| " AutoTokenizer, \n", | |
| " AutoModelForSequenceClassification, \n", | |
| " Trainer, \n", | |
| " TrainingArguments\n", | |
| ")\n", | |
| "from sklearn.model_selection import train_test_split\n", | |
| "from sklearn.metrics import f1_score, classification_report, precision_recall_fscore_support, accuracy_score\n", | |
| "import numpy as np\n", | |
| "from collections import Counter\n", | |
| "import wandb\n", | |
| "import json\n", | |
| "import random" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[32m\u001b[41mERROR\u001b[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", | |
| "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mjc4p\u001b[0m to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "True" | |
| ] | |
| }, | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Login to Weights & Biases\n", | |
| "import wandb\n", | |
| "wandb.login()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Loaded 8,030 samples\n", | |
| "Number of classes: 5\n", | |
| "Labels: ['Personal Life - Salient', 'Personal Life - Not Salient', 'Crypto Content', 'App/Bot Interactions', 'Not Enough Context To Tell']\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Load and prepare balanced training data\n", | |
| "def load_and_prepare_data(data_file):\n", | |
| " \"\"\"\n", | |
| " Load merged and balanced training data.\n", | |
| " \"\"\"\n", | |
| " # Load the training data\n", | |
| " with open(data_file, 'r') as f:\n", | |
| " data = json.load(f)\n", | |
| " \n", | |
| " # Define label mapping for 5 categories (no Spam)\n", | |
| " label_map = {\n", | |
| " 'Personal Life - Salient': 0,\n", | |
| " 'Personal Life - Not Salient': 1,\n", | |
| " 'Crypto Content': 2,\n", | |
| " 'App/Bot Interactions': 3,\n", | |
| " 'Not Enough Context To Tell': 4\n", | |
| " }\n", | |
| " \n", | |
| " # Prepare data\n", | |
| " texts = []\n", | |
| " labels = []\n", | |
| " \n", | |
| " for item in data:\n", | |
| " classification = item.get('Classification')\n", | |
| " if classification and classification in label_map:\n", | |
| " texts.append(item['Text'])\n", | |
| " labels.append(label_map[classification])\n", | |
| " \n", | |
| " return texts, labels, label_map\n", | |
| "\n", | |
| "# Load the balanced dataset\n", | |
| "texts, labels, label_map = load_and_prepare_data('training_data_balanced.json')\n", | |
| "\n", | |
| "print(f\"Loaded {len(texts):,} samples\")\n", | |
| "print(f\"Number of classes: {len(label_map)}\")\n", | |
| "print(f\"Labels: {list(label_map.keys())}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Dataset statistics:\n", | |
| "Total samples: 8030\n", | |
| "\n", | |
| "Class distribution:\n", | |
| " Personal Life - Salient: 486 (6.1%)\n", | |
| " Personal Life - Not Salient: 2582 (32.2%)\n", | |
| " Crypto Content: 2582 (32.2%)\n", | |
| " App/Bot Interactions: 1089 (13.6%)\n", | |
| " Not Enough Context To Tell: 1291 (16.1%)\n", | |
| "\n", | |
| "Sample texts per category:\n", | |
| "\n", | |
| "Personal Life - Salient:\n", | |
| " Agreed, Syracuse's terrain makes for pleasant walks. Flat surfaces are great for leisurely strolls and easy biking too....\n", | |
| "\n", | |
| "Personal Life - Not Salient:\n", | |
| " You for talk since nau 😂😭...\n", | |
| "\n", | |
| "Crypto Content:\n", | |
| " Nike’s .SWOOSH platform hits 5M users – Sneaker NFTs minted. 👟 #NFT #Fashion...\n", | |
| "\n", | |
| "App/Bot Interactions:\n", | |
| " I currently rank #421 on The Leaderboard. Where do you rank?...\n", | |
| "\n", | |
| "Not Enough Context To Tell:\n", | |
| " I couldn't have said it better myself. Well done!...\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Display dataset statistics\n", | |
| "from collections import Counter\n", | |
| "\n", | |
| "label_counts = Counter(labels)\n", | |
| "label_names = sorted(label_map.keys(), key=lambda x: label_map[x])\n", | |
| "\n", | |
| "print(\"Dataset statistics:\")\n", | |
| "print(f\"Total samples: {len(texts)}\")\n", | |
| "print(\"\\nClass distribution:\")\n", | |
| "for name in label_names:\n", | |
| " idx = label_map[name]\n", | |
| " count = label_counts[idx]\n", | |
| " percentage = (count / len(texts)) * 100\n", | |
| " print(f\" {name}: {count} ({percentage:.1f}%)\")\n", | |
| "\n", | |
| "# Show sample texts\n", | |
| "print(\"\\nSample texts per category:\")\n", | |
| "for name in label_names:\n", | |
| " idx = label_map[name]\n", | |
| " sample_indices = [i for i, l in enumerate(labels) if l == idx]\n", | |
| " if sample_indices:\n", | |
| " sample_idx = sample_indices[0]\n", | |
| " print(f\"\\n{name}:\")\n", | |
| " print(f\" {texts[sample_idx][:150]}...\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Define dataset class for single-label classification\n", | |
| "class TextClassificationDataset(Dataset):\n", | |
| " def __init__(self, texts, labels, tokenizer, max_length=512):\n", | |
| " self.encodings = tokenizer(\n", | |
| " texts,\n", | |
| " padding=True,\n", | |
| " truncation=True,\n", | |
| " max_length=max_length,\n", | |
| " return_tensors=None\n", | |
| " )\n", | |
| " self.labels = labels\n", | |
| "\n", | |
| " def __getitem__(self, idx):\n", | |
| " item = {\n", | |
| " key: torch.tensor(val[idx]) \n", | |
| " for key, val in self.encodings.items()\n", | |
| " }\n", | |
| " item['labels'] = torch.tensor(self.labels[idx])\n", | |
| " return item\n", | |
| "\n", | |
| " def __len__(self):\n", | |
| " return len(self.labels)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Define weighted trainer for handling remaining class imbalance\n", | |
| "class WeightedTrainer(Trainer):\n", | |
| " def __init__(self, class_weights, *args, **kwargs):\n", | |
| " super().__init__(*args, **kwargs)\n", | |
| " self.class_weights = torch.FloatTensor(class_weights)\n", | |
| "\n", | |
| " def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):\n", | |
| " labels = inputs.pop(\"labels\")\n", | |
| " outputs = model(**inputs)\n", | |
| " logits = outputs.logits\n", | |
| " \n", | |
| " # Move class weights to same device as model\n", | |
| " weights = self.class_weights.to(logits.device)\n", | |
| " \n", | |
| " # Apply class weights to CrossEntropyLoss\n", | |
| " loss_fct = torch.nn.CrossEntropyLoss(weight=weights)\n", | |
| " loss = loss_fct(logits.view(-1, model.config.num_labels), labels.view(-1))\n", | |
| " \n", | |
| " return (loss, outputs) if return_outputs else loss" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Training function\n", | |
| "def train_classifier(\n", | |
| " texts,\n", | |
| " labels,\n", | |
| " label_map,\n", | |
| " model_name='bert-base-uncased',\n", | |
| " batch_size=128,\n", | |
| " num_epochs=10,\n", | |
| " learning_rate=2e-5,\n", | |
| " max_length=512,\n", | |
| " weight_decay=0.01,\n", | |
| " warmup_ratio=0.1,\n", | |
| "):\n", | |
| " \"\"\"\n", | |
| " Training function for single-label classification with class weighting\n", | |
| " \"\"\"\n", | |
| " wandb.init(project=\"farcaster-cast-classification\")\n", | |
| " \n", | |
| " # Calculate class weights (inverse frequency) for remaining imbalance\n", | |
| " label_counts = Counter(labels)\n", | |
| " num_samples = len(labels)\n", | |
| " num_classes = len(label_map)\n", | |
| " \n", | |
| " class_weights = []\n", | |
| " for i in range(num_classes):\n", | |
| " count = label_counts.get(i, 1)\n", | |
| " weight = num_samples / (num_classes * count)\n", | |
| " class_weights.append(weight)\n", | |
| " \n", | |
| " print(\"Class weights:\")\n", | |
| " label_names = sorted(label_map.keys(), key=lambda x: label_map[x])\n", | |
| " for i, name in enumerate(label_names):\n", | |
| " print(f\" {name}: {class_weights[i]:.2f}\")\n", | |
| " \n", | |
| " # Initialize tokenizer and model\n", | |
| " tokenizer = AutoTokenizer.from_pretrained(model_name)\n", | |
| " model = AutoModelForSequenceClassification.from_pretrained(\n", | |
| " model_name,\n", | |
| " num_labels=len(label_map),\n", | |
| " problem_type=\"single_label_classification\"\n", | |
| " )\n", | |
| " \n", | |
| " model.gradient_checkpointing_enable()\n", | |
| " \n", | |
| " # Split data\n", | |
| " train_texts, val_texts, train_labels, val_labels = train_test_split(\n", | |
| " texts, labels, test_size=0.15, stratify=labels, random_state=42\n", | |
| " )\n", | |
| " \n", | |
| " # Create datasets\n", | |
| " train_dataset = TextClassificationDataset(train_texts, train_labels, tokenizer, max_length)\n", | |
| " val_dataset = TextClassificationDataset(val_texts, val_labels, tokenizer, max_length)\n", | |
| " \n", | |
| " # Calculate training steps\n", | |
| " num_training_steps = (len(train_dataset) // batch_size) * num_epochs\n", | |
| " num_warmup_steps = int(num_training_steps * warmup_ratio)\n", | |
| " \n", | |
| " # Define training arguments\n", | |
| " training_args = TrainingArguments(\n", | |
| " output_dir=\"./farcaster-classifier\",\n", | |
| " num_train_epochs=num_epochs,\n", | |
| " per_device_train_batch_size=batch_size,\n", | |
| " per_device_eval_batch_size=batch_size * 2,\n", | |
| " warmup_steps=num_warmup_steps,\n", | |
| " weight_decay=weight_decay,\n", | |
| " logging_dir=\"./logs\",\n", | |
| " logging_steps=10,\n", | |
| " eval_strategy=\"steps\",\n", | |
| " eval_steps=50,\n", | |
| " save_strategy=\"steps\",\n", | |
| " save_steps=50,\n", | |
| " load_best_model_at_end=True,\n", | |
| " metric_for_best_model=\"eval_f1_macro\",\n", | |
| " learning_rate=learning_rate,\n", | |
| " bf16=True,\n", | |
| " dataloader_num_workers=4,\n", | |
| " group_by_length=True,\n", | |
| " optim=\"adamw_torch_fused\",\n", | |
| " lr_scheduler_type=\"cosine\",\n", | |
| " gradient_checkpointing=True\n", | |
| " )\n", | |
| " \n", | |
| " def compute_metrics(eval_pred):\n", | |
| " predictions, labels = eval_pred\n", | |
| " predictions = np.argmax(predictions, axis=1)\n", | |
| " \n", | |
| " # Calculate per-class metrics\n", | |
| " metrics = {}\n", | |
| " label_names = sorted(label_map.keys(), key=lambda x: label_map[x])\n", | |
| " \n", | |
| " f1_scores = []\n", | |
| " for i, name in enumerate(label_names):\n", | |
| " # Get binary labels for this class\n", | |
| " binary_labels = (labels == i).astype(int)\n", | |
| " binary_preds = (predictions == i).astype(int)\n", | |
| " \n", | |
| " precision, recall, f1, _ = precision_recall_fscore_support(\n", | |
| " binary_labels, binary_preds, average='binary', zero_division=0\n", | |
| " )\n", | |
| " metrics[f\"f1_{name.lower().replace(' ', '_').replace('-', '')}\"[:20]] = f1\n", | |
| " metrics[f\"precision_{name.lower().replace(' ', '_').replace('-', '')}\"[:20]] = precision\n", | |
| " metrics[f\"recall_{name.lower().replace(' ', '_').replace('-', '')}\"[:20]] = recall\n", | |
| " f1_scores.append(f1)\n", | |
| " \n", | |
| " # Overall metrics\n", | |
| " metrics[\"f1_macro\"] = np.mean(f1_scores)\n", | |
| " metrics[\"accuracy\"] = accuracy_score(labels, predictions)\n", | |
| " \n", | |
| " return metrics\n", | |
| " \n", | |
| " # Initialize trainer with class weights\n", | |
| " trainer = WeightedTrainer(\n", | |
| " class_weights=class_weights,\n", | |
| " model=model,\n", | |
| " args=training_args,\n", | |
| " train_dataset=train_dataset,\n", | |
| " eval_dataset=val_dataset,\n", | |
| " compute_metrics=compute_metrics,\n", | |
| " )\n", | |
| " \n", | |
| " # Train the model\n", | |
| " trainer.train()\n", | |
| " \n", | |
| " # Final evaluation\n", | |
| " final_metrics = trainer.evaluate()\n", | |
| " print(\"\\nFinal Evaluation Metrics:\")\n", | |
| " for key, value in final_metrics.items():\n", | |
| " print(f\"{key}: {value}\")\n", | |
| " \n", | |
| " # Save the model and tokenizer\n", | |
| " model.save_pretrained(\"./farcaster-classifier-final\")\n", | |
| " tokenizer.save_pretrained(\"./farcaster-classifier-final\")\n", | |
| " \n", | |
| " # Save label map\n", | |
| " with open(\"./farcaster-classifier-final/label_map.json\", 'w') as f:\n", | |
| " json.dump(label_map, f)\n", | |
| " \n", | |
| " wandb.finish()\n", | |
| " \n", | |
| " return model, tokenizer" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [], | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "text/html": [ | |
| "Tracking run with wandb version 0.22.2" | |
| ], | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "text/html": [ | |
| "Run data is saved locally in <code>/home/ubuntu/wandb/run-20251009_204124-j5y5edzx</code>" | |
| ], | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "text/html": [ | |
| "Syncing run <strong><a href='https://wandb.ai/jc4p/farcaster-cast-classification/runs/j5y5edzx' target=\"_blank\">vital-glitter-3</a></strong> to <a href='https://wandb.ai/jc4p/farcaster-cast-classification' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>" | |
| ], | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "text/html": [ | |
| " View project at <a href='https://wandb.ai/jc4p/farcaster-cast-classification' target=\"_blank\">https://wandb.ai/jc4p/farcaster-cast-classification</a>" | |
| ], | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "text/html": [ | |
| " View run at <a href='https://wandb.ai/jc4p/farcaster-cast-classification/runs/j5y5edzx' target=\"_blank\">https://wandb.ai/jc4p/farcaster-cast-classification/runs/j5y5edzx</a>" | |
| ], | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Class weights:\n", | |
| " Personal Life - Salient: 3.30\n", | |
| " Personal Life - Not Salient: 0.62\n", | |
| " Crypto Content: 0.62\n", | |
| " App/Bot Interactions: 1.47\n", | |
| " Not Enough Context To Tell: 1.24\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n", | |
| "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/html": [ | |
| "\n", | |
| " <div>\n", | |
| " \n", | |
| " <progress value='1668' max='1712' style='width:300px; height:20px; vertical-align: middle;'></progress>\n", | |
| " [1668/1712 02:33 < 00:04, 10.88 it/s, Epoch 7.79/8]\n", | |
| " </div>\n", | |
| " <table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: left;\">\n", | |
| " <th>Step</th>\n", | |
| " <th>Training Loss</th>\n", | |
| " <th>Validation Loss</th>\n", | |
| " <th>F1 Personal Life Sa</th>\n", | |
| " <th>Precision Personal L</th>\n", | |
| " <th>Recall Personal Life</th>\n", | |
| " <th>F1 Personal Life No</th>\n", | |
| " <th>F1 Crypto Content</th>\n", | |
| " <th>Precision Crypto Con</th>\n", | |
| " <th>Recall Crypto Conten</th>\n", | |
| " <th>F1 App/bot Interacti</th>\n", | |
| " <th>Precision App/bot In</th>\n", | |
| " <th>Recall App/bot Inter</th>\n", | |
| " <th>F1 Not Enough Contex</th>\n", | |
| " <th>Precision Not Enough</th>\n", | |
| " <th>Recall Not Enough Co</th>\n", | |
| " <th>F1 Macro</th>\n", | |
| " <th>Accuracy</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <td>50</td>\n", | |
| " <td>1.595800</td>\n", | |
| " <td>1.591155</td>\n", | |
| " <td>0.030303</td>\n", | |
| " <td>0.434368</td>\n", | |
| " <td>0.469072</td>\n", | |
| " <td>0.451053</td>\n", | |
| " <td>0.577519</td>\n", | |
| " <td>0.462016</td>\n", | |
| " <td>0.770026</td>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.085714</td>\n", | |
| " <td>0.562500</td>\n", | |
| " <td>0.046392</td>\n", | |
| " <td>0.228918</td>\n", | |
| " <td>0.408299</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>100</td>\n", | |
| " <td>1.525800</td>\n", | |
| " <td>1.460845</td>\n", | |
| " <td>0.119403</td>\n", | |
| " <td>0.538462</td>\n", | |
| " <td>0.306701</td>\n", | |
| " <td>0.390805</td>\n", | |
| " <td>0.682927</td>\n", | |
| " <td>0.620253</td>\n", | |
| " <td>0.759690</td>\n", | |
| " <td>0.215962</td>\n", | |
| " <td>0.460000</td>\n", | |
| " <td>0.141104</td>\n", | |
| " <td>0.448802</td>\n", | |
| " <td>0.388679</td>\n", | |
| " <td>0.530928</td>\n", | |
| " <td>0.371580</td>\n", | |
| " <td>0.460581</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>150</td>\n", | |
| " <td>1.369000</td>\n", | |
| " <td>1.313050</td>\n", | |
| " <td>0.084211</td>\n", | |
| " <td>0.661905</td>\n", | |
| " <td>0.358247</td>\n", | |
| " <td>0.464883</td>\n", | |
| " <td>0.609756</td>\n", | |
| " <td>0.743494</td>\n", | |
| " <td>0.516796</td>\n", | |
| " <td>0.543326</td>\n", | |
| " <td>0.439394</td>\n", | |
| " <td>0.711656</td>\n", | |
| " <td>0.438486</td>\n", | |
| " <td>0.315909</td>\n", | |
| " <td>0.716495</td>\n", | |
| " <td>0.428132</td>\n", | |
| " <td>0.496266</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>200</td>\n", | |
| " <td>1.214200</td>\n", | |
| " <td>1.148946</td>\n", | |
| " <td>0.370044</td>\n", | |
| " <td>0.796875</td>\n", | |
| " <td>0.657216</td>\n", | |
| " <td>0.720339</td>\n", | |
| " <td>0.758270</td>\n", | |
| " <td>0.746867</td>\n", | |
| " <td>0.770026</td>\n", | |
| " <td>0.580460</td>\n", | |
| " <td>0.545946</td>\n", | |
| " <td>0.619632</td>\n", | |
| " <td>0.516129</td>\n", | |
| " <td>0.598639</td>\n", | |
| " <td>0.453608</td>\n", | |
| " <td>0.589048</td>\n", | |
| " <td>0.650622</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>250</td>\n", | |
| " <td>1.010600</td>\n", | |
| " <td>1.004799</td>\n", | |
| " <td>0.453988</td>\n", | |
| " <td>0.841270</td>\n", | |
| " <td>0.546392</td>\n", | |
| " <td>0.662500</td>\n", | |
| " <td>0.734993</td>\n", | |
| " <td>0.847973</td>\n", | |
| " <td>0.648579</td>\n", | |
| " <td>0.592593</td>\n", | |
| " <td>0.495868</td>\n", | |
| " <td>0.736196</td>\n", | |
| " <td>0.535645</td>\n", | |
| " <td>0.427692</td>\n", | |
| " <td>0.716495</td>\n", | |
| " <td>0.595944</td>\n", | |
| " <td>0.629876</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>300</td>\n", | |
| " <td>0.931000</td>\n", | |
| " <td>0.896634</td>\n", | |
| " <td>0.450980</td>\n", | |
| " <td>0.762136</td>\n", | |
| " <td>0.809278</td>\n", | |
| " <td>0.785000</td>\n", | |
| " <td>0.824632</td>\n", | |
| " <td>0.855556</td>\n", | |
| " <td>0.795866</td>\n", | |
| " <td>0.653061</td>\n", | |
| " <td>0.622222</td>\n", | |
| " <td>0.687117</td>\n", | |
| " <td>0.575949</td>\n", | |
| " <td>0.745902</td>\n", | |
| " <td>0.469072</td>\n", | |
| " <td>0.657925</td>\n", | |
| " <td>0.722822</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>350</td>\n", | |
| " <td>0.879800</td>\n", | |
| " <td>0.839245</td>\n", | |
| " <td>0.490196</td>\n", | |
| " <td>0.800000</td>\n", | |
| " <td>0.752577</td>\n", | |
| " <td>0.775564</td>\n", | |
| " <td>0.800000</td>\n", | |
| " <td>0.838527</td>\n", | |
| " <td>0.764858</td>\n", | |
| " <td>0.655949</td>\n", | |
| " <td>0.689189</td>\n", | |
| " <td>0.625767</td>\n", | |
| " <td>0.621891</td>\n", | |
| " <td>0.600962</td>\n", | |
| " <td>0.644330</td>\n", | |
| " <td>0.668720</td>\n", | |
| " <td>0.717842</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>400</td>\n", | |
| " <td>0.751500</td>\n", | |
| " <td>0.808259</td>\n", | |
| " <td>0.494624</td>\n", | |
| " <td>0.816456</td>\n", | |
| " <td>0.664948</td>\n", | |
| " <td>0.732955</td>\n", | |
| " <td>0.814921</td>\n", | |
| " <td>0.916129</td>\n", | |
| " <td>0.733850</td>\n", | |
| " <td>0.672131</td>\n", | |
| " <td>0.605911</td>\n", | |
| " <td>0.754601</td>\n", | |
| " <td>0.621444</td>\n", | |
| " <td>0.539924</td>\n", | |
| " <td>0.731959</td>\n", | |
| " <td>0.667215</td>\n", | |
| " <td>0.707884</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>450</td>\n", | |
| " <td>0.580500</td>\n", | |
| " <td>0.793636</td>\n", | |
| " <td>0.523810</td>\n", | |
| " <td>0.781170</td>\n", | |
| " <td>0.791237</td>\n", | |
| " <td>0.786172</td>\n", | |
| " <td>0.848322</td>\n", | |
| " <td>0.882682</td>\n", | |
| " <td>0.816537</td>\n", | |
| " <td>0.674221</td>\n", | |
| " <td>0.626316</td>\n", | |
| " <td>0.730061</td>\n", | |
| " <td>0.639118</td>\n", | |
| " <td>0.686391</td>\n", | |
| " <td>0.597938</td>\n", | |
| " <td>0.694329</td>\n", | |
| " <td>0.748548</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>500</td>\n", | |
| " <td>0.619300</td>\n", | |
| " <td>0.752609</td>\n", | |
| " <td>0.507177</td>\n", | |
| " <td>0.852843</td>\n", | |
| " <td>0.657216</td>\n", | |
| " <td>0.742358</td>\n", | |
| " <td>0.838174</td>\n", | |
| " <td>0.901786</td>\n", | |
| " <td>0.782946</td>\n", | |
| " <td>0.702857</td>\n", | |
| " <td>0.657754</td>\n", | |
| " <td>0.754601</td>\n", | |
| " <td>0.657596</td>\n", | |
| " <td>0.587045</td>\n", | |
| " <td>0.747423</td>\n", | |
| " <td>0.689633</td>\n", | |
| " <td>0.729461</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>550</td>\n", | |
| " <td>0.604800</td>\n", | |
| " <td>0.790841</td>\n", | |
| " <td>0.532374</td>\n", | |
| " <td>0.787879</td>\n", | |
| " <td>0.804124</td>\n", | |
| " <td>0.795918</td>\n", | |
| " <td>0.866920</td>\n", | |
| " <td>0.850746</td>\n", | |
| " <td>0.883721</td>\n", | |
| " <td>0.699422</td>\n", | |
| " <td>0.661202</td>\n", | |
| " <td>0.742331</td>\n", | |
| " <td>0.659091</td>\n", | |
| " <td>0.734177</td>\n", | |
| " <td>0.597938</td>\n", | |
| " <td>0.710745</td>\n", | |
| " <td>0.770124</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>600</td>\n", | |
| " <td>0.560500</td>\n", | |
| " <td>0.717147</td>\n", | |
| " <td>0.544503</td>\n", | |
| " <td>0.858491</td>\n", | |
| " <td>0.703608</td>\n", | |
| " <td>0.773371</td>\n", | |
| " <td>0.865789</td>\n", | |
| " <td>0.882038</td>\n", | |
| " <td>0.850129</td>\n", | |
| " <td>0.727811</td>\n", | |
| " <td>0.702857</td>\n", | |
| " <td>0.754601</td>\n", | |
| " <td>0.674699</td>\n", | |
| " <td>0.633484</td>\n", | |
| " <td>0.721649</td>\n", | |
| " <td>0.717235</td>\n", | |
| " <td>0.760996</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>650</td>\n", | |
| " <td>0.492000</td>\n", | |
| " <td>0.872252</td>\n", | |
| " <td>0.480000</td>\n", | |
| " <td>0.821023</td>\n", | |
| " <td>0.744845</td>\n", | |
| " <td>0.781081</td>\n", | |
| " <td>0.875989</td>\n", | |
| " <td>0.894879</td>\n", | |
| " <td>0.857881</td>\n", | |
| " <td>0.715084</td>\n", | |
| " <td>0.656410</td>\n", | |
| " <td>0.785276</td>\n", | |
| " <td>0.675991</td>\n", | |
| " <td>0.617021</td>\n", | |
| " <td>0.747423</td>\n", | |
| " <td>0.705629</td>\n", | |
| " <td>0.766805</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>700</td>\n", | |
| " <td>0.396000</td>\n", | |
| " <td>0.870288</td>\n", | |
| " <td>0.503817</td>\n", | |
| " <td>0.843844</td>\n", | |
| " <td>0.724227</td>\n", | |
| " <td>0.779473</td>\n", | |
| " <td>0.858942</td>\n", | |
| " <td>0.837838</td>\n", | |
| " <td>0.881137</td>\n", | |
| " <td>0.709877</td>\n", | |
| " <td>0.714286</td>\n", | |
| " <td>0.705521</td>\n", | |
| " <td>0.663636</td>\n", | |
| " <td>0.593496</td>\n", | |
| " <td>0.752577</td>\n", | |
| " <td>0.703149</td>\n", | |
| " <td>0.760166</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>750</td>\n", | |
| " <td>0.343400</td>\n", | |
| " <td>0.845435</td>\n", | |
| " <td>0.519481</td>\n", | |
| " <td>0.818697</td>\n", | |
| " <td>0.744845</td>\n", | |
| " <td>0.780027</td>\n", | |
| " <td>0.840499</td>\n", | |
| " <td>0.907186</td>\n", | |
| " <td>0.782946</td>\n", | |
| " <td>0.711111</td>\n", | |
| " <td>0.649746</td>\n", | |
| " <td>0.785276</td>\n", | |
| " <td>0.691244</td>\n", | |
| " <td>0.625000</td>\n", | |
| " <td>0.773196</td>\n", | |
| " <td>0.708472</td>\n", | |
| " <td>0.755187</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>800</td>\n", | |
| " <td>0.401300</td>\n", | |
| " <td>0.907402</td>\n", | |
| " <td>0.518519</td>\n", | |
| " <td>0.804178</td>\n", | |
| " <td>0.793814</td>\n", | |
| " <td>0.798962</td>\n", | |
| " <td>0.865409</td>\n", | |
| " <td>0.843137</td>\n", | |
| " <td>0.888889</td>\n", | |
| " <td>0.713846</td>\n", | |
| " <td>0.716049</td>\n", | |
| " <td>0.711656</td>\n", | |
| " <td>0.692708</td>\n", | |
| " <td>0.700000</td>\n", | |
| " <td>0.685567</td>\n", | |
| " <td>0.717889</td>\n", | |
| " <td>0.776763</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>850</td>\n", | |
| " <td>0.360000</td>\n", | |
| " <td>0.906668</td>\n", | |
| " <td>0.500000</td>\n", | |
| " <td>0.795918</td>\n", | |
| " <td>0.804124</td>\n", | |
| " <td>0.800000</td>\n", | |
| " <td>0.869677</td>\n", | |
| " <td>0.868557</td>\n", | |
| " <td>0.870801</td>\n", | |
| " <td>0.703226</td>\n", | |
| " <td>0.741497</td>\n", | |
| " <td>0.668712</td>\n", | |
| " <td>0.685851</td>\n", | |
| " <td>0.641256</td>\n", | |
| " <td>0.737113</td>\n", | |
| " <td>0.711751</td>\n", | |
| " <td>0.774274</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>900</td>\n", | |
| " <td>0.209300</td>\n", | |
| " <td>0.899608</td>\n", | |
| " <td>0.571429</td>\n", | |
| " <td>0.835165</td>\n", | |
| " <td>0.783505</td>\n", | |
| " <td>0.808511</td>\n", | |
| " <td>0.864230</td>\n", | |
| " <td>0.873351</td>\n", | |
| " <td>0.855297</td>\n", | |
| " <td>0.709877</td>\n", | |
| " <td>0.714286</td>\n", | |
| " <td>0.705521</td>\n", | |
| " <td>0.705314</td>\n", | |
| " <td>0.663636</td>\n", | |
| " <td>0.752577</td>\n", | |
| " <td>0.731872</td>\n", | |
| " <td>0.780083</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>950</td>\n", | |
| " <td>0.226400</td>\n", | |
| " <td>0.973592</td>\n", | |
| " <td>0.523810</td>\n", | |
| " <td>0.799479</td>\n", | |
| " <td>0.791237</td>\n", | |
| " <td>0.795337</td>\n", | |
| " <td>0.855643</td>\n", | |
| " <td>0.869333</td>\n", | |
| " <td>0.842377</td>\n", | |
| " <td>0.707965</td>\n", | |
| " <td>0.681818</td>\n", | |
| " <td>0.736196</td>\n", | |
| " <td>0.676399</td>\n", | |
| " <td>0.640553</td>\n", | |
| " <td>0.716495</td>\n", | |
| " <td>0.711831</td>\n", | |
| " <td>0.767635</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>1000</td>\n", | |
| " <td>0.240300</td>\n", | |
| " <td>0.921706</td>\n", | |
| " <td>0.551724</td>\n", | |
| " <td>0.814324</td>\n", | |
| " <td>0.791237</td>\n", | |
| " <td>0.802614</td>\n", | |
| " <td>0.851406</td>\n", | |
| " <td>0.883333</td>\n", | |
| " <td>0.821705</td>\n", | |
| " <td>0.685714</td>\n", | |
| " <td>0.641711</td>\n", | |
| " <td>0.736196</td>\n", | |
| " <td>0.684864</td>\n", | |
| " <td>0.660287</td>\n", | |
| " <td>0.711340</td>\n", | |
| " <td>0.715264</td>\n", | |
| " <td>0.765975</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>1050</td>\n", | |
| " <td>0.217900</td>\n", | |
| " <td>1.024940</td>\n", | |
| " <td>0.481203</td>\n", | |
| " <td>0.773956</td>\n", | |
| " <td>0.811856</td>\n", | |
| " <td>0.792453</td>\n", | |
| " <td>0.869340</td>\n", | |
| " <td>0.870466</td>\n", | |
| " <td>0.868217</td>\n", | |
| " <td>0.711656</td>\n", | |
| " <td>0.711656</td>\n", | |
| " <td>0.711656</td>\n", | |
| " <td>0.678851</td>\n", | |
| " <td>0.687831</td>\n", | |
| " <td>0.670103</td>\n", | |
| " <td>0.706701</td>\n", | |
| " <td>0.770954</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>1100</td>\n", | |
| " <td>0.196000</td>\n", | |
| " <td>1.072047</td>\n", | |
| " <td>0.487805</td>\n", | |
| " <td>0.778846</td>\n", | |
| " <td>0.835052</td>\n", | |
| " <td>0.805970</td>\n", | |
| " <td>0.869001</td>\n", | |
| " <td>0.872396</td>\n", | |
| " <td>0.865633</td>\n", | |
| " <td>0.704819</td>\n", | |
| " <td>0.692308</td>\n", | |
| " <td>0.717791</td>\n", | |
| " <td>0.689474</td>\n", | |
| " <td>0.704301</td>\n", | |
| " <td>0.675258</td>\n", | |
| " <td>0.711414</td>\n", | |
| " <td>0.777593</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>1150</td>\n", | |
| " <td>0.113100</td>\n", | |
| " <td>1.009274</td>\n", | |
| " <td>0.565789</td>\n", | |
| " <td>0.794344</td>\n", | |
| " <td>0.796392</td>\n", | |
| " <td>0.795367</td>\n", | |
| " <td>0.861809</td>\n", | |
| " <td>0.838631</td>\n", | |
| " <td>0.886305</td>\n", | |
| " <td>0.689189</td>\n", | |
| " <td>0.766917</td>\n", | |
| " <td>0.625767</td>\n", | |
| " <td>0.683805</td>\n", | |
| " <td>0.682051</td>\n", | |
| " <td>0.685567</td>\n", | |
| " <td>0.719192</td>\n", | |
| " <td>0.771784</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>1200</td>\n", | |
| " <td>0.160200</td>\n", | |
| " <td>1.015507</td>\n", | |
| " <td>0.567568</td>\n", | |
| " <td>0.782082</td>\n", | |
| " <td>0.832474</td>\n", | |
| " <td>0.806492</td>\n", | |
| " <td>0.867280</td>\n", | |
| " <td>0.882353</td>\n", | |
| " <td>0.852713</td>\n", | |
| " <td>0.727273</td>\n", | |
| " <td>0.696629</td>\n", | |
| " <td>0.760736</td>\n", | |
| " <td>0.674095</td>\n", | |
| " <td>0.733333</td>\n", | |
| " <td>0.623711</td>\n", | |
| " <td>0.728541</td>\n", | |
| " <td>0.780083</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>1250</td>\n", | |
| " <td>0.117900</td>\n", | |
| " <td>1.075018</td>\n", | |
| " <td>0.514706</td>\n", | |
| " <td>0.802667</td>\n", | |
| " <td>0.775773</td>\n", | |
| " <td>0.788991</td>\n", | |
| " <td>0.859671</td>\n", | |
| " <td>0.841584</td>\n", | |
| " <td>0.878553</td>\n", | |
| " <td>0.701538</td>\n", | |
| " <td>0.703704</td>\n", | |
| " <td>0.699387</td>\n", | |
| " <td>0.678481</td>\n", | |
| " <td>0.666667</td>\n", | |
| " <td>0.690722</td>\n", | |
| " <td>0.708677</td>\n", | |
| " <td>0.766805</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>1300</td>\n", | |
| " <td>0.103900</td>\n", | |
| " <td>1.106329</td>\n", | |
| " <td>0.516129</td>\n", | |
| " <td>0.791774</td>\n", | |
| " <td>0.793814</td>\n", | |
| " <td>0.792793</td>\n", | |
| " <td>0.871595</td>\n", | |
| " <td>0.875000</td>\n", | |
| " <td>0.868217</td>\n", | |
| " <td>0.703226</td>\n", | |
| " <td>0.741497</td>\n", | |
| " <td>0.668712</td>\n", | |
| " <td>0.668224</td>\n", | |
| " <td>0.611111</td>\n", | |
| " <td>0.737113</td>\n", | |
| " <td>0.710393</td>\n", | |
| " <td>0.770124</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>1350</td>\n", | |
| " <td>0.130300</td>\n", | |
| " <td>1.165310</td>\n", | |
| " <td>0.504348</td>\n", | |
| " <td>0.788413</td>\n", | |
| " <td>0.806701</td>\n", | |
| " <td>0.797452</td>\n", | |
| " <td>0.868895</td>\n", | |
| " <td>0.864450</td>\n", | |
| " <td>0.873385</td>\n", | |
| " <td>0.696486</td>\n", | |
| " <td>0.726667</td>\n", | |
| " <td>0.668712</td>\n", | |
| " <td>0.668258</td>\n", | |
| " <td>0.622222</td>\n", | |
| " <td>0.721649</td>\n", | |
| " <td>0.707088</td>\n", | |
| " <td>0.770954</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>1400</td>\n", | |
| " <td>0.113800</td>\n", | |
| " <td>1.137059</td>\n", | |
| " <td>0.515625</td>\n", | |
| " <td>0.776442</td>\n", | |
| " <td>0.832474</td>\n", | |
| " <td>0.803483</td>\n", | |
| " <td>0.863402</td>\n", | |
| " <td>0.861183</td>\n", | |
| " <td>0.865633</td>\n", | |
| " <td>0.723404</td>\n", | |
| " <td>0.716867</td>\n", | |
| " <td>0.730061</td>\n", | |
| " <td>0.664879</td>\n", | |
| " <td>0.692737</td>\n", | |
| " <td>0.639175</td>\n", | |
| " <td>0.714159</td>\n", | |
| " <td>0.775104</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>1450</td>\n", | |
| " <td>0.096600</td>\n", | |
| " <td>1.125042</td>\n", | |
| " <td>0.511628</td>\n", | |
| " <td>0.785000</td>\n", | |
| " <td>0.809278</td>\n", | |
| " <td>0.796954</td>\n", | |
| " <td>0.868895</td>\n", | |
| " <td>0.864450</td>\n", | |
| " <td>0.873385</td>\n", | |
| " <td>0.719512</td>\n", | |
| " <td>0.715152</td>\n", | |
| " <td>0.723926</td>\n", | |
| " <td>0.671835</td>\n", | |
| " <td>0.673575</td>\n", | |
| " <td>0.670103</td>\n", | |
| " <td>0.713765</td>\n", | |
| " <td>0.774274</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>1500</td>\n", | |
| " <td>0.157300</td>\n", | |
| " <td>1.149831</td>\n", | |
| " <td>0.516129</td>\n", | |
| " <td>0.787500</td>\n", | |
| " <td>0.811856</td>\n", | |
| " <td>0.799492</td>\n", | |
| " <td>0.870013</td>\n", | |
| " <td>0.866667</td>\n", | |
| " <td>0.873385</td>\n", | |
| " <td>0.721212</td>\n", | |
| " <td>0.712575</td>\n", | |
| " <td>0.730061</td>\n", | |
| " <td>0.670077</td>\n", | |
| " <td>0.664975</td>\n", | |
| " <td>0.675258</td>\n", | |
| " <td>0.715385</td>\n", | |
| " <td>0.775934</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>1550</td>\n", | |
| " <td>0.112300</td>\n", | |
| " <td>1.122491</td>\n", | |
| " <td>0.515625</td>\n", | |
| " <td>0.780247</td>\n", | |
| " <td>0.814433</td>\n", | |
| " <td>0.796974</td>\n", | |
| " <td>0.868557</td>\n", | |
| " <td>0.866324</td>\n", | |
| " <td>0.870801</td>\n", | |
| " <td>0.717325</td>\n", | |
| " <td>0.710843</td>\n", | |
| " <td>0.723926</td>\n", | |
| " <td>0.661458</td>\n", | |
| " <td>0.668421</td>\n", | |
| " <td>0.654639</td>\n", | |
| " <td>0.711988</td>\n", | |
| " <td>0.772614</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>1600</td>\n", | |
| " <td>0.108300</td>\n", | |
| " <td>1.167189</td>\n", | |
| " <td>0.508197</td>\n", | |
| " <td>0.779412</td>\n", | |
| " <td>0.819588</td>\n", | |
| " <td>0.798995</td>\n", | |
| " <td>0.867779</td>\n", | |
| " <td>0.862245</td>\n", | |
| " <td>0.873385</td>\n", | |
| " <td>0.716049</td>\n", | |
| " <td>0.720497</td>\n", | |
| " <td>0.711656</td>\n", | |
| " <td>0.668380</td>\n", | |
| " <td>0.666667</td>\n", | |
| " <td>0.670103</td>\n", | |
| " <td>0.711880</td>\n", | |
| " <td>0.774274</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>1650</td>\n", | |
| " <td>0.069700</td>\n", | |
| " <td>1.164802</td>\n", | |
| " <td>0.528000</td>\n", | |
| " <td>0.782716</td>\n", | |
| " <td>0.817010</td>\n", | |
| " <td>0.799496</td>\n", | |
| " <td>0.867779</td>\n", | |
| " <td>0.862245</td>\n", | |
| " <td>0.873385</td>\n", | |
| " <td>0.716049</td>\n", | |
| " <td>0.720497</td>\n", | |
| " <td>0.711656</td>\n", | |
| " <td>0.668380</td>\n", | |
| " <td>0.666667</td>\n", | |
| " <td>0.670103</td>\n", | |
| " <td>0.715941</td>\n", | |
| " <td>0.775104</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table><p>" | |
| ], | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "# Train the model with BERT base\n", | |
| "model, tokenizer = train_classifier(\n", | |
| " texts,\n", | |
| " labels,\n", | |
| " label_map,\n", | |
| " model_name='bert-base-uncased',\n", | |
| " batch_size=32,\n", | |
| " num_epochs=8,\n", | |
| " learning_rate=2e-5,\n", | |
| " weight_decay=0.01,\n", | |
| " warmup_ratio=0.2\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "\n", | |
| "Text: GM everyone! Hope you have a great day\n", | |
| "Prediction: Personal Life - Not Salient (confidence: 0.974)\n", | |
| "\n", | |
| "Text: Just bought more ETH, bullish on the merge\n", | |
| "Prediction: Crypto Content (confidence: 0.988)\n", | |
| "\n", | |
| "Text: This bot is spamming my feed with nonsense\n", | |
| "Prediction: App/Bot Interactions (confidence: 0.971)\n", | |
| "\n", | |
| "Text: I've been reflecting on my journey as a developer and realized that the best projects come from solving real problems I face daily\n", | |
| "Prediction: Personal Life - Salient (confidence: 0.860)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Batch prediction function\n", | |
| "def predict_batch(texts, model, tokenizer, device='cuda' if torch.cuda.is_available() else 'cpu'):\n", | |
| " \"\"\"\n", | |
| " Efficient batch prediction for single-label classification\n", | |
| " \"\"\"\n", | |
| " model.to(device)\n", | |
| " model.eval()\n", | |
| " \n", | |
| " inputs = tokenizer(\n", | |
| " texts,\n", | |
| " padding=True,\n", | |
| " truncation=True,\n", | |
| " max_length=512,\n", | |
| " return_tensors=\"pt\"\n", | |
| " ).to(device)\n", | |
| " \n", | |
| " with torch.no_grad():\n", | |
| " outputs = model(**inputs)\n", | |
| " # Apply softmax for single-label\n", | |
| " predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)\n", | |
| " predicted_labels = torch.argmax(predictions, dim=-1)\n", | |
| " \n", | |
| " return predicted_labels.cpu().numpy(), predictions.cpu().numpy()\n", | |
| "\n", | |
| "# Test on some examples\n", | |
| "test_texts = [\n", | |
| " \"GM everyone! Hope you have a great day\",\n", | |
| " \"Just bought more ETH, bullish on the merge\",\n", | |
| " \"This bot is spamming my feed with nonsense\",\n", | |
| " \"I've been reflecting on my journey as a developer and realized that the best projects come from solving real problems I face daily\"\n", | |
| "]\n", | |
| "\n", | |
| "pred_labels, pred_probs = predict_batch(test_texts, model, tokenizer)\n", | |
| "\n", | |
| "# Get reverse label mapping\n", | |
| "idx_to_label = {v: k for k, v in label_map.items()}\n", | |
| "\n", | |
| "for i, text in enumerate(test_texts):\n", | |
| " print(f\"\\nText: {text}\")\n", | |
| " predicted_class = idx_to_label[pred_labels[i]]\n", | |
| " confidence = pred_probs[i][pred_labels[i]]\n", | |
| " print(f\"Prediction: {predicted_class} (confidence: {confidence:.3f})\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Saving model to ./farcaster-classifier-final...\n", | |
| "✓ Model saved successfully!\n", | |
| "✓ Tokenizer saved successfully!\n", | |
| "✓ Label map saved successfully!\n", | |
| "\n", | |
| "Model location: ./farcaster-classifier-final/\n", | |
| "Files saved:\n", | |
| " - config.json\n", | |
| " - model.safetensors (or pytorch_model.bin)\n", | |
| " - tokenizer files\n", | |
| " - label_map.json\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Save the final model and configuration\n", | |
| "print(\"Saving model to ./farcaster-classifier-final...\")\n", | |
| "\n", | |
| "# Save model and tokenizer\n", | |
| "model.save_pretrained(\"./farcaster-classifier-final\")\n", | |
| "tokenizer.save_pretrained(\"./farcaster-classifier-final\")\n", | |
| "\n", | |
| "# Save label map\n", | |
| "with open(\"./farcaster-classifier-final/label_map.json\", 'w') as f:\n", | |
| " json.dump(label_map, f, indent=2)\n", | |
| "\n", | |
| "print(\"✓ Model saved successfully!\")\n", | |
| "print(\"✓ Tokenizer saved successfully!\")\n", | |
| "print(\"✓ Label map saved successfully!\")\n", | |
| "print(f\"\\nModel location: ./farcaster-classifier-final/\")\n", | |
| "print(f\"Files saved:\")\n", | |
| "print(f\" - config.json\")\n", | |
| "print(f\" - model.safetensors (or pytorch_model.bin)\")\n", | |
| "print(f\" - tokenizer files\")\n", | |
| "print(f\" - label_map.json\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Random Samples from Each Category with Model Predictions\n", | |
| "================================================================================\n", | |
| "\n", | |
| "\n", | |
| "================================================================================\n", | |
| "CATEGORY: Personal Life - Salient\n", | |
| "================================================================================\n", | |
| "\n", | |
| "1. ✓ Text:\n", | |
| " many such cases\n", | |
| "\n", | |
| "born and raised in LA\n", | |
| "\n", | |
| "26 years later have still never developed the muscle to check the weather daily\n", | |
| "\n", | |
| " Predicted: Personal Life - Salient (confidence: 0.991)\n", | |
| "\n", | |
| "\n", | |
| "2. ✓ Text:\n", | |
| " facebook just reminded me that 4 years ago i came 3rd at the diamond league\n", | |
| "\n", | |
| "1st and 2nd place in the photo are now olympic medalists!\n", | |
| "\n", | |
| "and they are both 6+ years older than me \n", | |
| "\n", | |
| "by the time LA 2028 c...\n", | |
| "\n", | |
| " Predicted: Personal Life - Salient (confidence: 0.991)\n", | |
| "\n", | |
| "\n", | |
| "3. ✓ Text:\n", | |
| " The Butterfly Nebula from Hubble. Stars can make beautiful patterns as they age -- sometimes similar to flowers or insects. NGC 6302, the Butterfly Nebula, is a notable example.\n", | |
| "\n", | |
| " Predicted: Personal Life - Salient (confidence: 0.983)\n", | |
| "\n", | |
| "\n", | |
| "4. ✓ Text:\n", | |
| " Và khi điều quan trọng nhất còn chưa được thực hiện, các mắt xích dậm chân tại chỗ, tổ chức chưa thể tiến về phía trước, sự bận rộn với những điều nhỏ nhặt đó cũng chẳng khác gì 1 sự lười biếng được v...\n", | |
| "\n", | |
| " Predicted: Personal Life - Salient (confidence: 0.988)\n", | |
| "\n", | |
| "\n", | |
| "5. ✓ Text:\n", | |
| " Republican love to lied. I sold insurance for years, and my wife used to works for Medicaid—you need proof of legal status to obtain insurance. The only type of insurance someone can get without that ...\n", | |
| "\n", | |
| " Predicted: Personal Life - Salient (confidence: 0.936)\n", | |
| "\n", | |
| "\n", | |
| "\n", | |
| "================================================================================\n", | |
| "CATEGORY: Personal Life - Not Salient\n", | |
| "================================================================================\n", | |
| "\n", | |
| "1. ✓ Text:\n", | |
| " “naomi can u make a video about ai”\n", | |
| "\n", | |
| "me: yes i can i am quite the expert \n", | |
| "\n", | |
| "also me:\n", | |
| "\n", | |
| " Predicted: Personal Life - Not Salient (confidence: 0.574)\n", | |
| "\n", | |
| "\n", | |
| "2. ✗ Text:\n", | |
| " すごい計画👏子ども達も体力あるのね👀❤️素晴らしい👏\n", | |
| "\n", | |
| " Predicted: Personal Life - Salient (confidence: 0.539)\n", | |
| " Top 3 predictions:\n", | |
| " - Personal Life - Salient: 0.539\n", | |
| " - Personal Life - Not Salient: 0.397\n", | |
| " - App/Bot Interactions: 0.023\n", | |
| "\n", | |
| "\n", | |
| "3. ✗ Text:\n", | |
| " ITAP\n", | |
| "I took this photo \n", | |
| "What do you think?\n", | |
| "\n", | |
| " Predicted: Not Enough Context To Tell (confidence: 0.879)\n", | |
| " Top 3 predictions:\n", | |
| " - Not Enough Context To Tell: 0.879\n", | |
| " - Personal Life - Not Salient: 0.104\n", | |
| " - App/Bot Interactions: 0.009\n", | |
| "\n", | |
| "\n", | |
| "4. ✓ Text:\n", | |
| " Happy tree Tuesday friend\n", | |
| "\n", | |
| " Predicted: Personal Life - Not Salient (confidence: 0.982)\n", | |
| "\n", | |
| "\n", | |
| "5. ✓ Text:\n", | |
| " Good morning, dear friend.\n", | |
| "\n", | |
| " Predicted: Personal Life - Not Salient (confidence: 0.984)\n", | |
| "\n", | |
| "\n", | |
| "\n", | |
| "================================================================================\n", | |
| "CATEGORY: Crypto Content\n", | |
| "================================================================================\n", | |
| "\n", | |
| "1. ✓ Text:\n", | |
| " so glad you like it mate!! 👍 \n", | |
| "https://www.empirebuilder.world/empire/0x0c22b5e951683f6fadebdcb931cdf801d2aaa70e\n", | |
| "\n", | |
| " Predicted: Crypto Content (confidence: 0.907)\n", | |
| "\n", | |
| "\n", | |
| "2. ✗ Text:\n", | |
| " Help me get Fuel by liking this cast!\n", | |
| "5 Likes = 1 Fuel🔋\n", | |
| "Support my mech battles in Wreck League Versus 🤖 by \n", | |
| "\n", | |
| " Predicted: App/Bot Interactions (confidence: 0.982)\n", | |
| " Top 3 predictions:\n", | |
| " - App/Bot Interactions: 0.982\n", | |
| " - Crypto Content: 0.012\n", | |
| " - Personal Life - Salient: 0.002\n", | |
| "\n", | |
| "\n", | |
| "3. ✓ Text:\n", | |
| " does token trade history in the app help?\n", | |
| "\n", | |
| "https://warpcast.com/rainbow/0x3c1185b9\n", | |
| "\n", | |
| " Predicted: Crypto Content (confidence: 0.930)\n", | |
| "\n", | |
| "\n", | |
| "4. ✓ Text:\n", | |
| " Cryptocurrency derivatives amplify trading opportunities, catering to experienced traders exclusively.\n", | |
| "\n", | |
| " Predicted: Crypto Content (confidence: 0.989)\n", | |
| "\n", | |
| "\n", | |
| "5. ✓ Text:\n", | |
| " Exit scams vanish with funds, betraying community trust completely.\n", | |
| "\n", | |
| " Predicted: Crypto Content (confidence: 0.989)\n", | |
| "\n", | |
| "\n", | |
| "\n", | |
| "================================================================================\n", | |
| "CATEGORY: App/Bot Interactions\n", | |
| "================================================================================\n", | |
| "\n", | |
| "1. ✓ Text:\n", | |
| " \\#Privasea\n", | |
| "\n", | |
| "Privasea Update : All tasks will be removed tomorrow March 31st🔶\n", | |
| "\n", | |
| "⭕️ Complete all tasks in ImHuman app ➡️ privasea.ai/download-app\n", | |
| "\n", | |
| "⏺ ref code ➡️ ZwkQxkV 👈\n", | |
| "\n", | |
| "📌You have also time to complete...\n", | |
| "\n", | |
| " Predicted: App/Bot Interactions (confidence: 0.960)\n", | |
| "\n", | |
| "\n", | |
| "2. ✓ Text:\n", | |
| " I currently rank #1000+ on The Leaderboard. Where do you rank?\n", | |
| "\n", | |
| " Predicted: App/Bot Interactions (confidence: 0.988)\n", | |
| "\n", | |
| "\n", | |
| "3. ✓ Text:\n", | |
| " 🌅 Just sent my daily GM on RISE!\n", | |
| "\n", | |
| "Join me in this amazing journey and discover a new way to connect with the community.\n", | |
| "\n", | |
| "Start your journey here: https://onchaingm.com?ref=0xB68B3e0BfA96fd90c1D889e6F8...\n", | |
| "\n", | |
| " Predicted: App/Bot Interactions (confidence: 0.977)\n", | |
| "\n", | |
| "\n", | |
| "4. ✓ Text:\n", | |
| " I currently rank #1000+ on The Leaderboard. Where do you rank?\n", | |
| "\n", | |
| " Predicted: App/Bot Interactions (confidence: 0.988)\n", | |
| "\n", | |
| "\n", | |
| "5. ✓ Text:\n", | |
| " has blocked \n", | |
| " has blocked \n", | |
| " has sneakily blocked \n", | |
| "\n", | |
| "\n", | |
| " Predicted: App/Bot Interactions (confidence: 0.813)\n", | |
| "\n", | |
| "\n", | |
| "\n", | |
| "================================================================================\n", | |
| "CATEGORY: Not Enough Context To Tell\n", | |
| "================================================================================\n", | |
| "\n", | |
| "1. ✓ Text:\n", | |
| " 11 👏 Tally-ho! Spectacular!\n", | |
| "\n", | |
| " Predicted: Not Enough Context To Tell (confidence: 0.700)\n", | |
| "\n", | |
| "\n", | |
| "2. ✓ Text:\n", | |
| " I couldn't have said it better myself. Well done!\n", | |
| "\n", | |
| " Predicted: Not Enough Context To Tell (confidence: 0.972)\n", | |
| "\n", | |
| "\n", | |
| "3. ✓ Text:\n", | |
| " This is exactly how I feel. Great perspective!\n", | |
| "\n", | |
| " Predicted: Not Enough Context To Tell (confidence: 0.977)\n", | |
| "\n", | |
| "\n", | |
| "4. ✓ Text:\n", | |
| " Yes! This is such an important topic, and you nailed it.\n", | |
| "\n", | |
| " Predicted: Not Enough Context To Tell (confidence: 0.966)\n", | |
| "\n", | |
| "\n", | |
| "5. ✓ Text:\n", | |
| " That's holds a deep meaning\n", | |
| "\n", | |
| " Predicted: Not Enough Context To Tell (confidence: 0.906)\n", | |
| "\n", | |
| "\n", | |
| "================================================================================\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Show 5 random samples from each category with predictions\n", | |
| "import random\n", | |
| "\n", | |
| "print(\"Random Samples from Each Category with Model Predictions\")\n", | |
| "print(\"=\"*80)\n", | |
| "\n", | |
| "# Get reverse label mapping\n", | |
| "idx_to_label = {v: k for k, v in label_map.items()}\n", | |
| "\n", | |
| "# For each category\n", | |
| "for category_name in sorted(label_map.keys(), key=lambda x: label_map[x]):\n", | |
| " category_idx = label_map[category_name]\n", | |
| " \n", | |
| " print(f\"\\n\\n{'='*80}\")\n", | |
| " print(f\"CATEGORY: {category_name}\")\n", | |
| " print('='*80)\n", | |
| " \n", | |
| " # Get all indices for this category\n", | |
| " category_indices = [i for i, l in enumerate(labels) if l == category_idx]\n", | |
| " \n", | |
| " # Sample 5 random examples\n", | |
| " sample_indices = random.sample(category_indices, min(5, len(category_indices)))\n", | |
| " sample_texts = [texts[i] for i in sample_indices]\n", | |
| " \n", | |
| " # Get predictions\n", | |
| " pred_labels, pred_probs = predict_batch(sample_texts, model, tokenizer)\n", | |
| " \n", | |
| " # Display each sample\n", | |
| " for i, (text, pred_label, probs) in enumerate(zip(sample_texts, pred_labels, pred_probs), 1):\n", | |
| " predicted_category = idx_to_label[pred_label]\n", | |
| " confidence = probs[pred_label]\n", | |
| " is_correct = \"✓\" if pred_label == category_idx else \"✗\"\n", | |
| " \n", | |
| " print(f\"\\n{i}. {is_correct} Text:\")\n", | |
| " print(f\" {text[:200]}{'...' if len(text) > 200 else ''}\")\n", | |
| " print(f\"\\n Predicted: {predicted_category} (confidence: {confidence:.3f})\")\n", | |
| " \n", | |
| " # Show top 3 predictions if wrong\n", | |
| " if pred_label != category_idx:\n", | |
| " top3_indices = np.argsort(probs)[-3:][::-1]\n", | |
| " print(f\" Top 3 predictions:\")\n", | |
| " for idx in top3_indices:\n", | |
| " print(f\" - {idx_to_label[idx]}: {probs[idx]:.3f}\")\n", | |
| " print()\n", | |
| "\n", | |
| "print(\"\\n\" + \"=\"*80)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "image/png": "", | |
| "text/plain": [ | |
| "<Figure size 1000x800 with 2 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "\n", | |
| "Normalized Confusion Matrix (row percentages):\n", | |
| "============================================================\n", | |
| "\n", | |
| "PL-Salient:\n", | |
| " -> PL-Salient: 60.3%\n", | |
| " -> PL-NotSalient: 24.7%\n", | |
| " -> Crypto: 2.7%\n", | |
| " -> App/Bot: 2.7%\n", | |
| " -> No Context: 9.6%\n", | |
| "\n", | |
| "PL-NotSalient:\n", | |
| " -> PL-Salient: 5.9%\n", | |
| " -> PL-NotSalient: 78.4%\n", | |
| " -> Crypto: 1.3%\n", | |
| " -> App/Bot: 3.9%\n", | |
| " -> No Context: 10.6%\n", | |
| "\n", | |
| "Crypto:\n", | |
| " -> PL-Salient: 1.8%\n", | |
| " -> PL-NotSalient: 2.6%\n", | |
| " -> Crypto: 85.5%\n", | |
| " -> App/Bot: 5.7%\n", | |
| " -> No Context: 4.4%\n", | |
| "\n", | |
| "App/Bot:\n", | |
| " -> PL-Salient: 2.5%\n", | |
| " -> PL-NotSalient: 3.1%\n", | |
| " -> Crypto: 18.4%\n", | |
| " -> App/Bot: 70.6%\n", | |
| " -> No Context: 5.5%\n", | |
| "\n", | |
| "No Context:\n", | |
| " -> PL-Salient: 1.5%\n", | |
| " -> PL-NotSalient: 13.9%\n", | |
| " -> Crypto: 5.7%\n", | |
| " -> App/Bot: 3.6%\n", | |
| " -> No Context: 75.3%\n", | |
| "\n", | |
| "============================================================\n", | |
| "Per-Class Accuracy:\n", | |
| "============================================================\n", | |
| "PL-Salient: 60.3%\n", | |
| "PL-NotSalient: 78.4%\n", | |
| "Crypto: 85.5%\n", | |
| "App/Bot: 70.6%\n", | |
| "No Context: 75.3%\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Generate confusion matrix on validation set\n", | |
| "import numpy as np\n", | |
| "import matplotlib.pyplot as plt\n", | |
| "import seaborn as sns\n", | |
| "from sklearn.metrics import confusion_matrix\n", | |
| "\n", | |
| "# Re-create the validation split. test_size / stratify / random_state must\n", | |
| "# match the split used for training (presumably an earlier cell), otherwise\n", | |
| "# these rows may include examples the model was trained on.\n", | |
| "train_texts, val_texts, train_labels, val_labels = train_test_split(\n", | |
| "    texts, labels, test_size=0.15, stratify=labels, random_state=42\n", | |
| ")\n", | |
| "\n", | |
| "# Predict on validation set\n", | |
| "val_pred_labels, val_pred_probs = predict_batch(val_texts, model, tokenizer)\n", | |
| "\n", | |
| "# Label names ordered by id, plus the ids themselves, so matrix rows/columns\n", | |
| "# line up with the tick labels below.\n", | |
| "# NOTE(review): assumes val_labels / val_pred_labels hold label_map ids.\n", | |
| "label_names = sorted(label_map.keys(), key=lambda x: label_map[x])\n", | |
| "label_ids = [label_map[name] for name in label_names]\n", | |
| "# Short display names; must be listed in the same order as label_names.\n", | |
| "label_names_short = [\n", | |
| "    \"PL-Salient\",\n", | |
| "    \"PL-NotSalient\",\n", | |
| "    \"Crypto\",\n", | |
| "    \"App/Bot\",\n", | |
| "    \"No Context\"\n", | |
| "]\n", | |
| "if len(label_names_short) != len(label_names):\n", | |
| "    raise ValueError(\n", | |
| "        f\"{len(label_names_short)} short names given for {len(label_names)} labels: {label_names}\"\n", | |
| "    )\n", | |
| "\n", | |
| "# labels= keeps the matrix square and in label_map order even when a class\n", | |
| "# appears in neither the true labels nor the predictions.\n", | |
| "cm = confusion_matrix(val_labels, val_pred_labels, labels=label_ids)\n", | |
| "\n", | |
| "# Plot confusion matrix\n", | |
| "plt.figure(figsize=(10, 8))\n", | |
| "sns.heatmap(\n", | |
| "    cm,\n", | |
| "    annot=True,\n", | |
| "    fmt='d',\n", | |
| "    cmap='Blues',\n", | |
| "    xticklabels=label_names_short,\n", | |
| "    yticklabels=label_names_short,\n", | |
| "    cbar_kws={'label': 'Count'}\n", | |
| ")\n", | |
| "plt.title('Confusion Matrix - Validation Set', fontsize=16, pad=20)\n", | |
| "plt.ylabel('True Label', fontsize=12)\n", | |
| "plt.xlabel('Predicted Label', fontsize=12)\n", | |
| "plt.xticks(rotation=45, ha='right')\n", | |
| "plt.yticks(rotation=0)\n", | |
| "plt.tight_layout()\n", | |
| "plt.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight')\n", | |
| "plt.show()\n", | |
| "\n", | |
| "# Print normalized confusion matrix (percentages).\n", | |
| "# Rows with no true samples stay 0 instead of producing 0/0 = NaN.\n", | |
| "row_totals = cm.sum(axis=1, keepdims=True)\n", | |
| "cm_normalized = np.divide(\n", | |
| "    cm, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals > 0\n", | |
| ")\n", | |
| "print(\"\\nNormalized Confusion Matrix (row percentages):\")\n", | |
| "print(\"=\"*60)\n", | |
| "for i, true_label in enumerate(label_names_short):\n", | |
| "    print(f\"\\n{true_label}:\")\n", | |
| "    for j, pred_label in enumerate(label_names_short):\n", | |
| "        percentage = cm_normalized[i, j] * 100\n", | |
| "        if percentage > 0:\n", | |
| "            print(f\" -> {pred_label}: {percentage:.1f}%\")\n", | |
| "\n", | |
| "# Print per-class accuracy. This is per-class recall (the diagonal of the\n", | |
| "# row-normalized matrix); a class with no validation samples reports 0.0%.\n", | |
| "print(\"\\n\" + \"=\"*60)\n", | |
| "print(\"Per-Class Accuracy:\")\n", | |
| "print(\"=\"*60)\n", | |
| "for i, label_name in enumerate(label_names_short):\n", | |
| "    class_accuracy = cm_normalized[i, i] * 100\n", | |
| "    print(f\"{label_name}: {class_accuracy:.1f}%\")" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "My Project (uv)", | |
| "language": "python", | |
| "name": "myproject" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.10.12" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 4 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment