Making and running a BERT classifier on Farcaster casts
#!/usr/bin/env python3
"""
Classify Farcaster casts using Google Gemini Batch API.

Classifies each cast into one of 5 categories:
- Personal Life
- Spam
- App/Bot Interactions
- Crypto Content
- Not Enough Context To Tell
"""
import json
import random
import time

from google import genai
from google.genai import types


def create_classification_prompt(cast_text):
    """Create the prompt for classifying the cast."""
    return f"""Classify this Farcaster cast into exactly ONE of these categories:
1. Personal Life
2. Spam
3. App/Bot Interactions
4. Crypto Content
5. Not Enough Context To Tell
Cast: "{cast_text}"
Respond in this exact format:
Category: [category name]
Confidence: [number from 0-10]"""
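
# Illustrative only (made-up values): the reply format the parsing code in
# main() expects from the model looks like this:
#   Category: Crypto Content
#   Confidence: 8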


def main():
    # Initialize the client
    client = genai.Client()

    # Load the filtered casts from JSON
    print("Loading data/casts_filtered.json...")
    with open('data/casts_filtered.json', 'r') as f:
        casts = []
        for line in f:
            casts.append(json.loads(line))
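
    # The input file is JSONL despite the .json name: one JSON object per line.
    # Hypothetical example line (values are made up; the field names match
    # what this script reads later):
    #   {"Fid": 1234, "Hash": "0xabc...", "Text": "gm everyone", "Timestamp": 1700000000}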

    # Sample to 10k records
    if len(casts) > 10000:
        print(f"Loaded {len(casts)} casts, sampling 10,000...")
        random.seed(42)  # For reproducibility
        casts = random.sample(casts, 10000)

    print(f"Processing {len(casts)} casts")

    # Create JSONL file with batch requests
    jsonl_filename = 'classification_batch_requests.jsonl'
    print(f"Creating {jsonl_filename}...")
    with open(jsonl_filename, 'w') as f:
        for i, cast in enumerate(casts):
            cast_hash = cast['Hash']
            cast_text = cast['Text']

            # Create the classification prompt
            prompt = create_classification_prompt(cast_text)

            # Create the request object
            request = {
                "key": f"cast-{i}-{cast_hash}",
                "request": {
                    "contents": [
                        {
                            "parts": [{"text": prompt}],
                            "role": "user"
                        }
                    ],
                    "generation_config": {
                        "temperature": 0.1  # Very low temperature for consistent classification
                    }
                }
            }
            f.write(json.dumps(request) + '\n')

    print(f"Created {jsonl_filename} with {len(casts)} requests")

    # Upload the file to the File API
    print("Uploading JSONL file to Google File API...")
    uploaded_file = client.files.upload(
        file=jsonl_filename,
        config=types.UploadFileConfig(
            display_name='farcaster-cast-classification',
            mime_type='application/jsonl'
        )
    )
    print(f"Uploaded file: {uploaded_file.name}")

    # Create the batch job
    print("Creating batch job...")
    batch_job = client.batches.create(
        model="gemini-flash-latest",
        src=uploaded_file.name,
        config={
            'display_name': "farcaster-cast-classification",
        },
    )
    job_name = batch_job.name
    print(f"Created batch job: {job_name}")

    # Poll for job completion
    print("\nPolling job status...")
    completed_states = {
        'JOB_STATE_SUCCEEDED',
        'JOB_STATE_FAILED',
        'JOB_STATE_CANCELLED',
        'JOB_STATE_EXPIRED',
    }
    while batch_job.state.name not in completed_states:
        print(f"Current state: {batch_job.state.name}")
        time.sleep(30)  # Wait 30 seconds before polling again
        batch_job = client.batches.get(name=job_name)

    print(f"\nJob finished with state: {batch_job.state.name}")

    # Handle results
    if batch_job.state.name == 'JOB_STATE_SUCCEEDED':
        if batch_job.dest and batch_job.dest.file_name:
            result_file_name = batch_job.dest.file_name
            print(f"Results are in file: {result_file_name}")

            # Download the results
            print("Downloading results...")
            file_content = client.files.download(file=result_file_name)

            # Save raw results (bulk file)
            raw_results_filename = 'cast_classification_raw.jsonl'
            with open(raw_results_filename, 'wb') as f:
                f.write(file_content)
            print(f"Saved raw results to {raw_results_filename}")

            # Parse and format results
            print("Parsing results...")
            results = []
            for line in file_content.decode('utf-8').strip().split('\n'):
                result = json.loads(line)
                results.append(result)
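
            # Assumed rough shape of each result line, inferred from the
            # fields accessed below (not a verbatim sample):
            #   {"key": "cast-0-<hash>",
            #    "response": {"candidates": [{"content": {"parts": [
            #      {"text": "Category: Spam\nConfidence: 9"}]}}]}}
            # Failed requests carry a "status" object instead of "response".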

            # Create a structured output with the original cast data + classification.
            # Index results by key once rather than scanning the list per cast.
            results_by_key = {r.get('key'): r for r in results}
            structured_results = []
            for i, cast in enumerate(casts):
                # Find matching result
                key = f"cast-{i}-{cast['Hash']}"
                matching_result = results_by_key.get(key)

                output = {
                    'Fid': cast['Fid'],
                    'Hash': cast['Hash'],
                    'Text': cast['Text'],
                    'Timestamp': cast['Timestamp'],
                    'Classification': None,
                    'Confidence': None,
                    'error': None
                }

                if matching_result:
                    if 'response' in matching_result:
                        # Extract the classification text
                        response_text = matching_result['response']['candidates'][0]['content']['parts'][0]['text'].strip()

                        # Parse the response to extract category and confidence
                        for line in response_text.split('\n'):
                            if line.startswith('Category:'):
                                output['Classification'] = line.replace('Category:', '').strip()
                            elif line.startswith('Confidence:'):
                                try:
                                    output['Confidence'] = int(line.replace('Confidence:', '').strip())
                                except ValueError:
                                    output['Confidence'] = None
                    elif 'status' in matching_result:
                        output['error'] = matching_result['status']

                structured_results.append(output)

            # Save structured results (filtered file for analysis)
            output_filename = 'cast_classification_results.json'
            with open(output_filename, 'w') as f:
                json.dump(structured_results, f, indent=2)
            print(f"Saved structured results to {output_filename}")

            # Print summary by category
            successful = sum(1 for r in structured_results if r['Classification'])
            failed = sum(1 for r in structured_results if r['error'])

            # Count by category
            category_counts = {}
            for r in structured_results:
                if r['Classification']:
                    cat = r['Classification']
                    category_counts[cat] = category_counts.get(cat, 0) + 1

            # Collect confidences per category (averaged when printing below)
            category_confidence = {}
            for r in structured_results:
                if r['Classification'] and r['Confidence'] is not None:
                    cat = r['Classification']
                    category_confidence.setdefault(cat, []).append(r['Confidence'])

            print(f"\nSummary:")
            print(f"  Total casts: {len(structured_results)}")
            print(f"  Successful: {successful}")
            print(f"  Failed: {failed}")
            print(f"\nClassification breakdown:")
            for cat, count in sorted(category_counts.items(), key=lambda x: x[1], reverse=True):
                confidences = category_confidence.get(cat)
                avg_conf = sum(confidences) / len(confidences) if confidences else 0
                print(f"  {cat}: {count} (avg confidence: {avg_conf:.1f})")
    elif batch_job.state.name == 'JOB_STATE_FAILED':
        print(f"Error: {batch_job.error}")

    print("\nDone!")


if __name__ == "__main__":
    main()
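
A minimal follow-up sketch, separate from the script above: one way to sanity-check the labels it writes before using them to train the BERT classifier mentioned in the title. It assumes only the cast_classification_results.json file produced by the script; the confidence cutoff of 5 is arbitrary.

#!/usr/bin/env python3
"""Sketch: surface low-confidence classifications for manual review."""
import json

with open('cast_classification_results.json', 'r') as f:
    results = json.load(f)

# Casts the model labeled, but with confidence below an arbitrary cutoff of 5;
# worth eyeballing before treating the labels as training data.
uncertain = [
    r for r in results
    if r['Classification'] and r['Confidence'] is not None and r['Confidence'] < 5
]

print(f"{len(uncertain)} of {len(results)} casts have confidence < 5")
for r in uncertain[:10]:
    print(f"  [{r['Confidence']}] {r['Classification']}: {r['Text'][:80]!r}")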