Last active
October 10, 2025 22:04
-
-
Save jc4p/334d52de25624e6419aa21f4ccd35f10 to your computer and use it in GitHub Desktop.
Making and running a BERT classifier on Farcaster casts
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| """ | |
| Classify Farcaster casts using Google Gemini Batch API. | |
| Classifies each cast into one of 5 categories: | |
| - Personal Life | |
| - Spam | |
| - App/Bot Interactions | |
| - Crypto Content | |
| - Not Enough Context To Tell | |
| """ | |
| import json | |
| import time | |
| from google import genai | |
| from google.genai import types | |
def create_classification_prompt(cast_text):
    """Build the prompt asking the model to put one cast into exactly one category.

    The cast text is embedded as a JSON string literal (``json.dumps`` with
    ``ensure_ascii=False``). Quotes, backslashes and newlines inside the cast
    therefore stay escaped on the single ``Cast:`` line. They cannot end the
    quoted region early or add extra lines to the prompt, such as a fake
    ``Category:`` line. For ordinary text the result is identical to wrapping
    it in double quotes, and non-ASCII characters such as emoji are kept as-is.

    Args:
        cast_text: Text of the cast to classify. Non-string values are
            converted with ``str()`` first, as the previous f-string did.

    Returns:
        The prompt string. It ends with the two-line ``Category:`` /
        ``Confidence:`` response format that the result parser expects.
    """
    quoted_text = json.dumps(str(cast_text), ensure_ascii=False)
    return f"""Classify this Farcaster cast into exactly ONE of these categories:
1. Personal Life
2. Spam
3. App/Bot Interactions
4. Crypto Content
5. Not Enough Context To Tell
Cast: {quoted_text}
Respond in this exact format:
Category: [category name]
Confidence: [number from 0-10]"""
def _request_key(index, cast):
    """Return the batch request key used to pair a cast with its result."""
    return f"cast-{index}-{cast['Hash']}"


def _load_casts(path, sample_size=10000, seed=42):
    """Load casts from a JSON Lines file and reproducibly down-sample them.

    Args:
        path: Path to a file containing one JSON cast object per line.
        sample_size: Maximum number of casts to keep.
        seed: Seed for the sampling RNG, so reruns select the same casts.

    Returns:
        A list of cast dicts, at most ``sample_size`` long.
    """
    import random

    print(f"Loading {path}...")
    with open(path, 'r', encoding='utf-8') as f:
        # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.
        casts = [json.loads(line) for line in f if line.strip()]

    if len(casts) > sample_size:
        print(f"Loaded {len(casts)} casts, sampling {sample_size:,}...")
        # A private Random(seed) draws the same sample as random.seed(seed) followed
        # by random.sample(), without reseeding the module-global RNG.
        casts = random.Random(seed).sample(casts, sample_size)
    print(f"Processing {len(casts)} casts")
    return casts


def _write_batch_requests(casts, jsonl_filename):
    """Write one Gemini batch request per cast to ``jsonl_filename`` (JSON Lines)."""
    print(f"Creating {jsonl_filename}...")
    with open(jsonl_filename, 'w', encoding='utf-8') as f:
        for i, cast in enumerate(casts):
            request = {
                "key": _request_key(i, cast),
                "request": {
                    "contents": [
                        {
                            "parts": [{"text": create_classification_prompt(cast['Text'])}],
                            "role": "user",
                        }
                    ],
                    "generation_config": {
                        # Very low temperature for consistent classification
                        "temperature": 0.1,
                    },
                },
            }
            f.write(json.dumps(request) + '\n')
    print(f"Created {jsonl_filename} with {len(casts)} requests")


def _submit_and_wait(client, jsonl_filename, poll_seconds=30):
    """Upload the request file, create the batch job and poll until it ends.

    Returns:
        The batch job object as last fetched, whose state is one of the
        terminal states listed below.
    """
    print("Uploading JSONL file to Google File API...")
    uploaded_file = client.files.upload(
        file=jsonl_filename,
        config=types.UploadFileConfig(
            display_name='farcaster-cast-classification',
            mime_type='application/jsonl',
        ),
    )
    print(f"Uploaded file: {uploaded_file.name}")

    print("Creating batch job...")
    batch_job = client.batches.create(
        model="gemini-flash-latest",
        src=uploaded_file.name,
        config={'display_name': "farcaster-cast-classification"},
    )
    job_name = batch_job.name
    print(f"Created batch job: {job_name}")

    print("\nPolling job status...")
    completed_states = {
        'JOB_STATE_SUCCEEDED',
        'JOB_STATE_FAILED',
        'JOB_STATE_CANCELLED',
        'JOB_STATE_EXPIRED',
    }
    while batch_job.state.name not in completed_states:
        print(f"Current state: {batch_job.state.name}")
        time.sleep(poll_seconds)
        batch_job = client.batches.get(name=job_name)
    return batch_job


def _parse_result_lines(file_content):
    """Decode the downloaded result bytes into a ``{request key: result dict}`` map.

    A dict lookup replaces a linear search per cast, which was O(n^2) overall.
    If a key appears more than once, the first occurrence is kept.
    """
    results_by_key = {}
    for line in file_content.decode('utf-8').splitlines():
        if not line.strip():
            continue
        result = json.loads(line)
        results_by_key.setdefault(result.get('key'), result)
    return results_by_key


def _parse_classification(response_text):
    """Extract ``(category, confidence)`` from a 'Category: …' / 'Confidence: …' reply.

    Each value is None when its line is missing. Confidence is also None when
    it is not an integer. If a label appears more than once, the last one wins.
    """
    category = None
    confidence = None
    for raw_line in response_text.splitlines():
        # Tolerate indentation and markdown bold such as "**Category:** Spam".
        line = raw_line.strip().replace('*', '')
        if line.startswith('Category:'):
            category = line[len('Category:'):].strip() or None
        elif line.startswith('Confidence:'):
            try:
                confidence = int(line[len('Confidence:'):].strip())
            except ValueError:
                confidence = None
    return category, confidence


def _build_structured_results(casts, results_by_key):
    """Pair every cast with its batch result and parsed classification.

    Every record whose 'Classification' is None gets a non-empty 'error',
    so the Successful and Failed counts in the summary add up to the total.
    """
    structured_results = []
    for i, cast in enumerate(casts):
        output = {
            'Fid': cast['Fid'],
            'Hash': cast['Hash'],
            'Text': cast['Text'],
            'Timestamp': cast['Timestamp'],
            'Classification': None,
            'Confidence': None,
            'error': None,
        }
        result = results_by_key.get(_request_key(i, cast))
        if result is None:
            output['error'] = 'No result returned for this request'
        elif 'response' in result:
            try:
                response_text = (
                    result['response']['candidates'][0]['content']['parts'][0]['text'].strip()
                )
            except (KeyError, IndexError, TypeError, AttributeError):
                # Missing/empty candidates or parts must not abort the whole run.
                output['error'] = 'Response contained no candidate text'
            else:
                category, confidence = _parse_classification(response_text)
                output['Classification'] = category
                output['Confidence'] = confidence
                if category is None:
                    output['error'] = 'Could not find a "Category:" line in the response'
        else:
            # Failed requests are read from 'status'. The 'error' fallback is an
            # assumption about the output format. TODO confirm against Batch API docs.
            output['error'] = (
                result.get('status')
                or result.get('error')
                or 'Result contained neither a response nor a status'
            )
        structured_results.append(output)
    return structured_results


def _print_summary(structured_results):
    """Print success/failure counts, then per-category counts and mean confidence."""
    from collections import Counter, defaultdict

    successful = sum(1 for r in structured_results if r['Classification'])
    failed = sum(1 for r in structured_results if r['error'])

    category_counts = Counter(
        r['Classification'] for r in structured_results if r['Classification']
    )
    category_confidence = defaultdict(list)
    for r in structured_results:
        if r['Classification'] and r['Confidence'] is not None:
            category_confidence[r['Classification']].append(r['Confidence'])

    print("\nSummary:")
    print(f" Total casts: {len(structured_results)}")
    print(f" Successful: {successful}")
    print(f" Failed: {failed}")
    print("\nClassification breakdown:")
    # most_common() sorts by count, descending; ties keep first-seen order.
    for cat, count in category_counts.most_common():
        confs = category_confidence.get(cat)
        avg_conf = sum(confs) / len(confs) if confs else 0
        print(f" {cat}: {count} (avg confidence: {avg_conf:.1f})")


def main():
    """Classify a sample of Farcaster casts with the Gemini Batch API.

    Steps:
    1. Load casts from data/casts_filtered.json and sample at most 10,000 (seed 42).
    2. Write classification_batch_requests.jsonl, upload it and run a batch job.
    3. On success, download the result file.
    4. Save the raw result bytes to cast_classification_raw.jsonl.
    5. Save per-cast structured results to cast_classification_results.json.
    6. Print a summary.
    """
    client = genai.Client()

    casts = _load_casts('data/casts_filtered.json')

    jsonl_filename = 'classification_batch_requests.jsonl'
    _write_batch_requests(casts, jsonl_filename)

    batch_job = _submit_and_wait(client, jsonl_filename)
    print(f"\nJob finished with state: {batch_job.state.name}")

    if batch_job.state.name == 'JOB_STATE_SUCCEEDED':
        if batch_job.dest and batch_job.dest.file_name:
            result_file_name = batch_job.dest.file_name
            print(f"Results are in file: {result_file_name}")

            print("Downloading results...")
            file_content = client.files.download(file=result_file_name)

            # Save the raw bytes first so they survive any later parsing problem.
            raw_results_filename = 'cast_classification_raw.jsonl'
            with open(raw_results_filename, 'wb') as f:
                f.write(file_content)
            print(f"Saved raw results to {raw_results_filename}")

            print("Parsing results...")
            results_by_key = _parse_result_lines(file_content)
            structured_results = _build_structured_results(casts, results_by_key)

            output_filename = 'cast_classification_results.json'
            with open(output_filename, 'w', encoding='utf-8') as f:
                json.dump(structured_results, f, indent=2)
            print(f"Saved structured results to {output_filename}")

            _print_summary(structured_results)
        else:
            print("Job succeeded but batch_job.dest has no result file name; nothing to download.")
    elif batch_job.state.name == 'JOB_STATE_FAILED':
        print(f"Error: {batch_job.error}")

    print("\nDone!")


if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Farcaster Cast Classification Inference\n", | |
| "\n", | |
| "Load the trained classifier and run inference on new casts." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Imports\n", | |
| "import torch\n", | |
| "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n", | |
| "import json\n", | |
| "import numpy as np" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Load the trained model, tokenizer, and label map\n", | |
| "MODEL_PATH = \"./farcaster-classifier-final\"\n", | |
| "\n", | |
| "print(\"Loading model and tokenizer...\")\n", | |
| "tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)\n", | |
| "model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)\n", | |
| "\n", | |
| "# Load label map\n", | |
| "with open(f\"{MODEL_PATH}/label_map.json\", 'r') as f:\n", | |
| " label_map = json.load(f)\n", | |
| "\n", | |
| "# Create reverse mapping (index -> label name)\n", | |
| "idx_to_label = {v: k for k, v in label_map.items()}\n", | |
| "\n", | |
| "# Move model to GPU if available\n", | |
| "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", | |
| "model.to(device)\n", | |
| "model.eval()\n", | |
| "\n", | |
| "print(f\"✓ Model loaded successfully!\")\n", | |
| "print(f\"✓ Device: {device}\")\n", | |
| "print(f\"✓ Labels: {list(label_map.keys())}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Define prediction function\n", | |
| "def classify_cast(text, show_all_probs=False):\n", | |
| " \"\"\"\n", | |
| " Classify a single cast text.\n", | |
| " \n", | |
| " Args:\n", | |
| " text: The cast text to classify\n", | |
| " show_all_probs: If True, show probabilities for all categories\n", | |
| " \n", | |
| " Returns:\n", | |
| " Dictionary with prediction results\n", | |
| " \"\"\"\n", | |
| " # Tokenize\n", | |
| " inputs = tokenizer(\n", | |
| " text,\n", | |
| " padding=True,\n", | |
| " truncation=True,\n", | |
| " max_length=512,\n", | |
| " return_tensors=\"pt\"\n", | |
| " ).to(device)\n", | |
| " \n", | |
| " # Predict\n", | |
| " with torch.no_grad():\n", | |
| " outputs = model(**inputs)\n", | |
| " probs = torch.nn.functional.softmax(outputs.logits, dim=-1)\n", | |
| " predicted_idx = torch.argmax(probs, dim=-1).item()\n", | |
| " confidence = probs[0][predicted_idx].item()\n", | |
| " \n", | |
| " # Get predicted category\n", | |
| " predicted_category = idx_to_label[predicted_idx]\n", | |
| " \n", | |
| " result = {\n", | |
| " 'text': text,\n", | |
| " 'predicted_category': predicted_category,\n", | |
| " 'confidence': confidence\n", | |
| " }\n", | |
| " \n", | |
| " if show_all_probs:\n", | |
| " all_probs = {idx_to_label[i]: probs[0][i].item() for i in range(len(idx_to_label))}\n", | |
| " result['all_probabilities'] = all_probs\n", | |
| " \n", | |
| " return result\n", | |
| "\n", | |
| "\n", | |
| "def classify_batch(texts, show_all_probs=False):\n", | |
| " \"\"\"\n", | |
| " Classify multiple casts at once (more efficient).\n", | |
| " \n", | |
| " Args:\n", | |
| " texts: List of cast texts to classify\n", | |
| " show_all_probs: If True, show probabilities for all categories\n", | |
| " \n", | |
| " Returns:\n", | |
| " List of prediction dictionaries\n", | |
| " \"\"\"\n", | |
| " # Tokenize all texts\n", | |
| " inputs = tokenizer(\n", | |
| " texts,\n", | |
| " padding=True,\n", | |
| " truncation=True,\n", | |
| " max_length=512,\n", | |
| " return_tensors=\"pt\"\n", | |
| " ).to(device)\n", | |
| " \n", | |
| " # Predict\n", | |
| " with torch.no_grad():\n", | |
| " outputs = model(**inputs)\n", | |
| " probs = torch.nn.functional.softmax(outputs.logits, dim=-1)\n", | |
| " predicted_indices = torch.argmax(probs, dim=-1).cpu().numpy()\n", | |
| " probs_np = probs.cpu().numpy()\n", | |
| " \n", | |
| " # Format results\n", | |
| " results = []\n", | |
| " for i, (text, pred_idx) in enumerate(zip(texts, predicted_indices)):\n", | |
| " predicted_category = idx_to_label[pred_idx]\n", | |
| " confidence = probs_np[i][pred_idx]\n", | |
| " \n", | |
| " result = {\n", | |
| " 'text': text,\n", | |
| " 'predicted_category': predicted_category,\n", | |
| " 'confidence': confidence\n", | |
| " }\n", | |
| " \n", | |
| " if show_all_probs:\n", | |
| " all_probs = {idx_to_label[j]: probs_np[i][j] for j in range(len(idx_to_label))}\n", | |
| " result['all_probabilities'] = all_probs\n", | |
| " \n", | |
| " results.append(result)\n", | |
| " \n", | |
| " return results\n", | |
| "\n", | |
| "\n", | |
| "def print_prediction(result):\n", | |
| " \"\"\"\n", | |
| " Pretty print a prediction result.\n", | |
| " \"\"\"\n", | |
| " print(f\"\\nText: {result['text'][:200]}{'...' if len(result['text']) > 200 else ''}\")\n", | |
| " print(f\"Predicted Category: {result['predicted_category']}\")\n", | |
| " print(f\"Confidence: {result['confidence']:.1%}\")\n", | |
| " \n", | |
| " if 'all_probabilities' in result:\n", | |
| " print(\"\\nAll probabilities:\")\n", | |
| " sorted_probs = sorted(result['all_probabilities'].items(), key=lambda x: x[1], reverse=True)\n", | |
| " for category, prob in sorted_probs:\n", | |
| " print(f\" {category}: {prob:.1%}\")\n", | |
| "\n", | |
| "print(\"✓ Prediction functions defined!\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Example: Classify a single cast\n", | |
| "example_text = \"GM everyone! Hope you have a great day\"\n", | |
| "\n", | |
| "result = classify_cast(example_text, show_all_probs=True)\n", | |
| "print_prediction(result)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Example: Classify multiple casts\n", | |
| "example_casts = [\n", | |
| " \"GM everyone! Hope you have a great day\",\n", | |
| " \"Just bought more ETH, bullish on the merge\",\n", | |
| " \"This bot is spamming my feed with nonsense\",\n", | |
| " \"I've been reflecting on my journey as a developer and realized that the best projects come from solving real problems I face daily\",\n", | |
| " \"Check out this new DeFi protocol, yields are crazy!\"\n", | |
| "]\n", | |
| "\n", | |
| "results = classify_batch(example_casts, show_all_probs=True)\n", | |
| "\n", | |
| "for i, result in enumerate(results, 1):\n", | |
| " print(f\"\\n{'='*80}\")\n", | |
| " print(f\"Example {i}:\")\n", | |
| " print('='*80)\n", | |
| " print_prediction(result)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Interactive: Classify your own text\n", | |
| "# Run this cell and enter a cast to classify\n", | |
| "\n", | |
| "your_text = input(\"Enter a cast to classify: \")\n", | |
| "result = classify_cast(your_text, show_all_probs=True)\n", | |
| "print_prediction(result)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Load and classify from a JSON file\n", | |
| "# Example: Load casts from data/casts_filtered.json and classify them\n", | |
| "\n", | |
| "def classify_from_json(json_file, sample_size=10):\n", | |
| " \"\"\"\n", | |
| " Load casts from a JSON file and classify a sample.\n", | |
| " \n", | |
| " Args:\n", | |
| " json_file: Path to JSON file (JSONL format)\n", | |
| " sample_size: Number of casts to classify\n", | |
| " \"\"\"\n", | |
| " import random\n", | |
| " \n", | |
| " print(f\"Loading casts from {json_file}...\")\n", | |
| " \n", | |
| " # Load casts\n", | |
| " casts = []\n", | |
| " with open(json_file, 'r') as f:\n", | |
| " for line in f:\n", | |
| " cast = json.loads(line)\n", | |
| " if 'Text' in cast:\n", | |
| " casts.append(cast)\n", | |
| " \n", | |
| " print(f\"Loaded {len(casts)} casts\")\n", | |
| " \n", | |
| " # Sample randomly\n", | |
| " sample = random.sample(casts, min(sample_size, len(casts)))\n", | |
| " texts = [c['Text'] for c in sample]\n", | |
| " \n", | |
| " print(f\"Classifying {len(texts)} casts...\")\n", | |
| " \n", | |
| " # Classify\n", | |
| " results = classify_batch(texts, show_all_probs=False)\n", | |
| " \n", | |
| " # Print results\n", | |
| " for i, (cast, result) in enumerate(zip(sample, results), 1):\n", | |
| " print(f\"\\n{'='*80}\")\n", | |
| " print(f\"Cast {i}:\")\n", | |
| " print('='*80)\n", | |
| " print_prediction(result)\n", | |
| " \n", | |
| " # Summary\n", | |
| " category_counts = {}\n", | |
| " for result in results:\n", | |
| " cat = result['predicted_category']\n", | |
| " category_counts[cat] = category_counts.get(cat, 0) + 1\n", | |
| " \n", | |
| " print(f\"\\n{'='*80}\")\n", | |
| " print(\"Category Distribution:\")\n", | |
| " print('='*80)\n", | |
| " for cat, count in sorted(category_counts.items(), key=lambda x: x[1], reverse=True):\n", | |
| " print(f\" {cat}: {count} ({count/len(results)*100:.1f}%)\")\n", | |
| "\n", | |
| "# Example usage (uncomment to run):\n", | |
| "# classify_from_json('data/casts_filtered.json', sample_size=20)" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3 (ipykernel)", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.10.12" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 4 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment