system-prompt-learning-jokes
#!/usr/bin/env python3
"""
joker_bayesian.py – List-wise evolutionary prompt tuner for joke generation
Usage: python joker_bayesian.py prompts/seed.txt [GENERATIONS] [POPULATION]
* **Judge**     3x Listwise LLM-as-a-judge evaluates jokes on 8 ternary rubrics (goal, originality,
                length, surprise, timing, relatability, clarity, memorability).
* **Fitness**   Thompson-sampled Beta(α, β) updated with tournament rankings
                (wins = K-1-rank, losses = rank) across multiple judge runs.
* **Variation** EditAgent mutates prompts using feedback history, crossover
                blends parents, skip preserves good performers, champ revival.
* **State**     Each prompt: .txt file + .hist.json feedback + α,β parameters.
Models: Gemini 2.5 Flash Lite for the Joke agent, Gemini 2.5 Flash for everything else (Judge, Edit, Crossover)
Dependencies
------------
uv add pydantic_ai texprompts loguru python-dotenv tqdm tenacity logfire
"""
# ─────────────────────────── Imports ──────────────────────────
import asyncio, random, math, json, sys, os, csv
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple, Dict, Optional
from datetime import datetime
from dotenv import load_dotenv
from loguru import logger
from tqdm import trange
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
# pydantic-ai imports
from pydantic import BaseModel, Field
from pydantic_ai import Agent, RunContext
from pydantic_ai.models.gemini import GeminiModelSettings, ThinkingConfig
from pydantic_ai.exceptions import UnexpectedModelBehavior
from pydantic_ai.usage import UsageLimits
import logfire
# ─────────────────────────── Configuration ──────────────────────────────────────────────
load_dotenv(override=True)
random.seed(42)
logfire.configure(token=os.getenv("LOGFIRE_TOKEN"), scrubbing=False)
logfire.instrument_pydantic_ai()
# Evolution parameters
TOPICS = ["bananas","quantum computing","cloud computing","coffee","football",
          "time travel","space exploration","blockchain","dinosaurs","robots"]
POP_SIZE, GENERATIONS, ARENA_SIZE, JUDGE_RUNS = 18, 12, 9, 3
SKIP_PROB, XOVER_PROB, ELITE = 0.2, 0.3, 2
T0, T_DECAY, HIST_DEPTH = 1.0, 0.9, 6
TOOLS_THINK_BUDGET, CHAMP_MUTATE_PROB, PRI_DECAY = 256, 0.25, 0.6
PROMPT_DIR = Path("prompts")
PROMPT_DIR.mkdir(exist_ok=True)
TOURNAMENT_DATA = []
USAGE_LIMITS = UsageLimits(
    request_limit=10
)
# ─────────────────────────── Logging Utilities ───────────────────────────────────────────
def log_step(msg: str, level="info", sep=True):
    """Unified logging with optional separator."""
    if sep: logger.info(f"\n{'='*60}")
    getattr(logger, level)(msg)
def log_arena_selection(gen: int, topic: str, draw: List[Tuple], arena: List):
    """Log arena selection details."""
    log_step(f"Gen {gen} Arena Selection (Topic: {topic}):")
    for pm, sampled in sorted(draw, key=lambda x: x[1], reverse=True):
        status = "→ IN ARENA" if pm in arena else ""
        logger.info(f"  {pm.path.name:30} α={pm.alpha:6.2f} β={pm.beta:6.2f} "
                    f"μ={pm.mean():5.3f} sampled={sampled:5.3f} {status}")
def log_tournament_results(arena: List, wins: List[int], losses: List[int]):
    """Log tournament results."""
    logger.info("\nTournament Results:")
    for pm, w, l in zip(arena, wins, losses):
        logger.info(f"  {pm.path.name:30} wins={w:3} losses={l:3} net={w-l:+4} "
                    f"α={pm.alpha:6.2f} β={pm.beta:6.2f}")
def log_generation_summary(gen: int, population: List):
    """Log generation summary."""
    best = max(population, key=lambda p: p.mean())
    log_step(f"Gen {gen} Summary:", "success", sep=False)
    logger.success(f"Best: {best.path.name} (μ={best.mean():.3f})")
    logger.info("Final Rankings:")
    for i, pm in enumerate(sorted(population, key=lambda p: p.mean(), reverse=True), 1):
        logger.info(f"  {i}. {pm.path.name:30} α={pm.alpha:6.2f} β={pm.beta:6.2f} μ={pm.mean():5.3f}")
# ─────────────────────────── Judge Models ────────────────────────────────────────────────
class JokeEval(BaseModel):
    """Ternary evaluation scores for each joke."""
    goal: int = Field(description="Achieves humorous goal: -1=weak, 0=good/great, 1=exceptional")
    originality: int = Field(description="Originality: -1=weak, 0=good/great, 1=exceptional")
    length: int = Field(description="Length appropriateness: -1=weak, 0=good/great, 1=exceptional")
    surprise: int = Field(description="Surprise/unexpectedness factor: -1=predictable, 0=some surprise, 1=genuinely unexpected")
    timing: int = Field(description="Setup and payoff timing: -1=poor pacing, 0=good rhythm, 1=perfect comedic timing")
    relatability: int = Field(description="Audience connection: -1=obscure/alienating, 0=accessible, 1=universally relatable")
    clarity: int = Field(description="Joke comprehension, not confusing or overcomplicated: -1=confusing/ambiguous, 0=clear, 1=crystal clear and elegant")
    memorability: int = Field(description="Sticks in memory: -1=instantly forgettable, 0=decent recall, 1=highly memorable")
    suggestions: str = Field(description="Suggestions for improvement. Max 1-2 sentences.")
class JudgeOut(BaseModel):
    """Judge output with reasoning and rankings."""
    reasoning: str = Field(..., description="concise internal chain-of-thought")
    evaluations: List[JokeEval] = Field(..., description="one entry per joke")
    ranking: List[int] = Field(..., description="1-indexed best→worst list")
# System prompts (preserved exactly as original)
JUDGE_SYS = """
You are an expert comedy critic. Do **three** things in *private* chain-of-thought first:
1. For EACH joke decide ternary scores {-1=weak,0=good/great,1=exceptional - top decile of jokes!} on these rubrics:
   (a) Achieves its humorous goal - actually funny, makes you laugh
   (b) Originality (SEVERELY penalize duplicates and known jokes, reward fresh creativity)
   (c) Length that's not too long or too short - just right for the joke type
   (d) Surprise/unexpectedness - subverts expectations, catches audience off-guard
   (e) Setup and payoff timing - proper comedic rhythm and pacing
   (f) Audience relatability - accessible premise people can connect with
   (g) Clarity - easy to understand, not confusing or overcomplicated
   (h) Memorability - sticks with you, quotable, shareable quality
2. For each joke, provide specific, actionable suggestions for improvement (1-2 sentences max).
3. Combine the rubric into a scalar (sum) and sort jokes best→worst.
   • If scalars tie, break ties via your reasoning.
   • IMPORTANT: Duplicate jokes should be ranked lowest, regardless of other qualities.
4. Only THEN reveal JSON:
{
  reasoning: "...",
  evaluations:[ {goal:#,originality:#,length:#,surprise:#,timing:#,relatability:#,clarity:#,memorability:#,suggestions:"..."}, ... ],
  ranking:[best_index, ...]   # 1-indexed matching the joke order provided
}
**Never** leak private reasoning outside the `reasoning` field.
""".strip()
EDIT_SYS = """
You will refine the SYSTEM PROMPT for a joke generator based on feedback.
The current prompt is provided in the message below. No need to use `read_prompt()` again.
GUIDELINES:
- First, decide on the final form of the prompt.
- Then, plan the changes to the prompt in your reasoning chain.
- Make as few focused edits as possible to achieve the goal.
- Address specific feedback issues (weak humor, verbosity, etc.)
- Use `replace_text(old_text, new_text)` for precise edits. If it fails twice, read the current prompt `read_prompt()` before trying again.
- Use `create_new_prompt(new)` only for complete rewrites
GOOD PROMPTS STRUCTURE:
- CONTEXT: Set the scenario and role
- TASK: Describe what good outcome looks like
- INSTRUCTIONS: What to do / not to do (be specific)
- OUTPUT FORMAT: How to structure the response
- EXAMPLES: Show what success looks like (if helpful)
Be bold but focused. One clear improvement per edit. Focus on NEW FEEDBACK first.
""".strip()
XOVER_SYS = """
Merge two SYSTEM PROMPTS (PARENT_A, PARENT_B). Keep their strongest constraints,
eliminate redundancy, and maintain similar length. Return ONLY the merged prompt.
""".strip()
# ─────────────────────────── Agent Factory ───────────────────────────────────────────────
def make_agent(model: str, output_type, system_prompt: str, tools=None, temp=None, thinking=TOOLS_THINK_BUDGET):
    """Create agent with standard settings."""
    settings = GeminiModelSettings(thinking_config=ThinkingConfig(thinking_budget=thinking), temperature=temp)
    # if temp is not None: settings.temperature = temp
    return Agent(model, output_type=output_type, deps_type=str,
                 system_prompt=system_prompt, tools=tools or [], model_settings=settings)
# Create agents
JudgeAgent = make_agent("google-gla:gemini-2.5-flash", JudgeOut, JUDGE_SYS)
# EditAgent will be created after tools are defined
CrossoverAgent = make_agent("google-gla:gemini-2.5-flash", str, XOVER_SYS)
# ─────────────────────────── File Operations (compressed) ────────────────────────────────
def file_tool(ctx: RunContext[str], op: str, **kwargs) -> str:
    """Unified file operations tool."""
    try:
        p = Path(ctx.deps)
        if op == 'read':
            return p.read_text() if p.exists() else f"ERROR: File {p} not found"
        elif op == 'replace':
            txt = p.read_text()
            old, new = kwargs['old'], kwargs['new']
            count = txt.count(old)
            if count == 0: return f"ERROR: '{old[:50]}...' not found"
            if count > 1: return f"ERROR: '{old[:50]}...' found {count} times - must be unique"
            p.write_text(txt.replace(old, new, 1))
            return "SUCCESS: Edit applied"
        elif op == 'create':
            existed = p.exists()  # check before writing, otherwise the file always "exists"
            p.write_text(kwargs['content'])
            return f"SUCCESS: {'Overwrote' if existed else 'Created'} {p.name}"
        else:
            return f"ERROR: Unknown operation '{op}'"
    except Exception as e:
        return f"ERROR: {type(e).__name__}: {str(e)}"
# Tool wrappers
def read_prompt(ctx: RunContext[str]): return file_tool(ctx, 'read')
def replace_text(ctx: RunContext[str], old: str, new: str): return file_tool(ctx, 'replace', old=old, new=new)
def create_new_prompt(ctx: RunContext[str], new: str): return file_tool(ctx, 'create', content=new)
# Create EditAgent with tools
EditAgent = make_agent("google-gla:gemini-2.5-flash", str, EDIT_SYS,
                       tools=[replace_text, create_new_prompt, read_prompt])
# ─────────────────────────── Retry Wrapper ───────────────────────────────────────────────
async def run_with_retry(agent: Agent, message: str, deps: str = "", return_none_on_error: bool = False):
    """Run agent with retry on API failures.
    Args:
        agent: The agent to run
        message: The message to send
        deps: Dependencies for the agent
        return_none_on_error: If True, return None on failure instead of raising
    """
    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=1, max=10),
           retry=retry_if_exception_type((UnexpectedModelBehavior, ConnectionError, TimeoutError)))
    async def _run():
        return await agent.run(message, deps=deps, usage_limits=USAGE_LIMITS)
    try:
        return await _run()
    except Exception as e:
        logger.error(f"Agent run failed after retries: {type(e).__name__}: {str(e)[:100]}")
        if return_none_on_error:
            return None
        raise
# ─────────────────────────── Core Data Structures ────────────────────────────────────────
@dataclass
class PromptMeta:
    """Prompt with Beta distribution parameters."""
    path: Path
    alpha: float = 1.0
    beta: float = 1.0
    def mean(self) -> float:
        return self.alpha / (self.alpha + self.beta)
    def copy(self) -> "PromptMeta":
        return PromptMeta(self.path, self.alpha, self.beta)
# History management
hist_path = lambda p: p.with_suffix(".hist.json")
load_history = lambda p: json.loads(hist_path(p).read_text()) if hist_path(p).exists() else []
save_history = lambda p, e: hist_path(p).write_text(json.dumps((load_history(p) + [e])[-50:], indent=1))
def beta_update(pm: PromptMeta, wins: int, losses: int):
    """Update Beta distribution parameters."""
    pm.alpha += wins
    pm.beta += losses
# ─────────────────────────── Judge Arena (compressed but clear) ──────────────────────────
async def judge_arena(arena: List[PromptMeta], topic: str, champions: set[Path], gen: int):
    """Run tournament and update Beta parameters."""
    K = len(arena)
    # 1. Generate jokes concurrently
    logger.info(f"Generating {K} jokes about '{topic}'...")
    tasks = [run_with_retry(make_agent("google-gla:gemini-2.5-flash-lite-preview-06-17",
                                       str, pm.path.read_text(), temp=0.8, thinking=0),
                            f"Generate a one-line joke about: {topic}", return_none_on_error=True)
             for pm in arena]
    results = await asyncio.gather(*tasks)
    # Process results
    jokes = []
    valid_count = 0
    for i, result in enumerate(results):
        if result is None:
            jokes.append("[FAILED TO GENERATE]")
        else:
            jokes.append(result.output.strip().replace('\n', ' '))
            valid_count += 1
    # If all joke generation failed, skip this arena
    if valid_count == 0:
        logger.error("All joke generations failed, skipping arena")
        return
    # 2. Judge with multiple shuffles
    num_runs = min(JUDGE_RUNS, K)
    logger.info(f"Running {num_runs} judge evaluations...")
    judge_tasks, orders = [], []
    for _ in range(num_runs):
        order = list(range(K))
        random.shuffle(order)
        orders.append(order)
        joke_block = "\n".join(f"{i+1}. {jokes[order[i]]}" for i in range(len(order)))
        judge_tasks.append(run_with_retry(JudgeAgent, f"TOPIC:{topic}\n\nJOKES:\n{joke_block}", return_none_on_error=True))
    judge_results = await asyncio.gather(*judge_tasks)
    # Filter out failed judge evaluations
    valid_judge_results = []
    valid_orders = []
    for result, order in zip(judge_results, orders):
        if result is not None:
            valid_judge_results.append(result)
            valid_orders.append(order)
    # If all judge evaluations failed, assign neutral scores
    if not valid_judge_results:
        logger.error("All judge evaluations failed, assigning neutral scores")
        for pm in arena:
            beta_update(pm, 0, 0)
        return
    judge_results = valid_judge_results
    orders = valid_orders
    # 3. Process rankings and feedback
    all_rankings = []
    issues_map = {'goal': 'weak humor', 'originality': 'lacks originality', 'length': 'poor length',
                  'surprise': 'too predictable', 'timing': 'poor pacing',
                  'relatability': 'not relatable', 'clarity': 'confusing', 'memorability': 'forgettable'}
    for out, order in zip(judge_results, orders):
        # Convert shuffled positions to arena indices
        arena_ranking = [order[pos-1] for pos in out.output.ranking]
        all_rankings.append(arena_ranking)
        # Generate feedback for each joke
        for i, eval in enumerate(out.output.evaluations):
            issues = [v for k, v in issues_map.items() if getattr(eval, k) < 0]
            feedback_parts = []
            if issues:
                feedback_parts.append(f"Issues: {', '.join(issues)}")
            if eval.suggestions:
                feedback_parts.append(f"Suggestions: {eval.suggestions}")
            if feedback_parts:
                save_history(arena[order[i]].path, " | ".join(feedback_parts))
    # 4. Calculate average ranks and update Beta
    avg_ranks = [sum(ranking.index(i) for ranking in all_rankings) / len(all_rankings)
                 for i in range(K)]
    ranked_items = sorted(enumerate(avg_ranks), key=lambda x: x[1])
    # wins/losses are indexed by arena position so they line up with `arena` below
    wins, losses = [0] * K, [0] * K
    for final_rank, (arena_idx, _) in enumerate(ranked_items):
        w, l = K - 1 - final_rank, final_rank
        wins[arena_idx], losses[arena_idx] = w, l
        beta_update(arena[arena_idx], w, l)
        # Track tournament data
        TOURNAMENT_DATA.append({
            'generation': gen, 'topic': topic,
            'prompt_file': arena[arena_idx].path.name,
            'joke': jokes[arena_idx], 'avg_rank': avg_ranks[arena_idx],
            'wins': w, 'losses': l, 'net_score': w - l,
            'alpha': arena[arena_idx].alpha, 'beta': arena[arena_idx].beta,
            'mean': arena[arena_idx].mean()
        })
    log_tournament_results(arena, wins, losses)
    # 5. Track champion
    best_idx = max(range(K), key=lambda i: wins[i] - losses[i])
    champions.add(arena[best_idx].path)
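# Illustrative aggregation example (hypothetical numbers, not from a real run):
# with K = 3 arena prompts and three judge rankings [0, 2, 1], [0, 1, 2] and
# [2, 0, 1] (arena indices, best→worst), the average ranks are 0.33, 1.67 and 1.00,
# so the final order is 0, 2, 1 and the resulting (wins, losses) updates are
# (2, 0) for prompt 0, (0, 2) for prompt 1 and (1, 1) for prompt 2.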
# ─────────────────────────── Variation Operations ────────────────────────────────────────
async def mutate(parent: PromptMeta, gen: int) -> PromptMeta:
    """Create mutation with possible skip."""
    if random.random() < SKIP_PROB:
        logger.info(f"Skipping mutation for {parent.path.name}")
        return parent.copy()
    child_path = parent.path.with_name(f"{parent.path.stem}_m{gen}.txt")
    child_path.write_text(parent.path.read_text())
    # Prepare feedback
    all_feedback = load_history(parent.path)
    recent_feedback = all_feedback[-HIST_DEPTH:] if all_feedback else []
    feedback_msg = f"Current prompt:\n{child_path.read_text()}\n\n"
    if recent_feedback:
        feedback_msg += f"RECENT FEEDBACK:\n" + "\n".join(recent_feedback)
    logger.info(f"Mutating {parent.path.name} → {child_path.name}")
    result = await run_with_retry(EditAgent, feedback_msg, deps=str(child_path), return_none_on_error=True)
    if result is None:
        # Return parent unchanged if mutation fails
        return parent.copy()
    return PromptMeta(child_path, parent.alpha, parent.beta)
async def crossover(pa: PromptMeta, pb: PromptMeta, gen: int) -> PromptMeta:
    """Create crossover child."""
    logger.info(f"Crossover: {pa.path.name} × {pb.path.name}")
    out = await run_with_retry(CrossoverAgent,
                               f"PARENT_A:\n{pa.path.read_text()}\n\nPARENT_B:\n{pb.path.read_text()}",
                               return_none_on_error=True)
    if out is None:
        # Return first parent if crossover fails
        return pa.copy()
    child_path = pa.path.with_name(f"{pa.path.stem}_{pb.path.stem}_x{gen}.txt")
    child_path.write_text(out.output.strip())
    # Inherit optimistic priors
    return PromptMeta(child_path, max(pa.alpha, pb.alpha), max(pa.beta, pb.beta))
# ─────────────────────────── Bootstrap Population ────────────────────────────────────────
async def bootstrap_population(seed: Path, target_size: int) -> List[PromptMeta]:
    """Create initial population variants."""
    population = [PromptMeta(seed)]
    hints = ["Make more concise", "Add humor style specifics", "Emphasize originality",
             "Improve joke structure", "Increase clarity"]
    logger.info(f"Bootstrapping population from {seed.name}")
    # Create mutations in parallel
    mutation_tasks = []
    child_paths = []
    for i in range(min(len(hints), target_size - 1)):
        child_path = seed.with_name(f"{seed.stem}_init{i}.txt")
        child_path.write_text(seed.read_text())
        child_paths.append(child_path)
        logger.info(f"Creating variant {i+1}: {hints[i]}")
        mutation_tasks.append(run_with_retry(EditAgent,
            f"CURRENT PROMPT:\n---\n{seed.read_text()}\n---\n\nTASK: {hints[i]}. Make the changes now!",
            deps=str(child_path), return_none_on_error=True))
    await asyncio.gather(*mutation_tasks)
    # Add all children (mutations may have failed but files still exist with original content)
    for path in child_paths:
        population.append(PromptMeta(path))
    # Fill remaining with crossovers in parallel
    if len(population) < target_size and len(population) >= 2:
        crossover_tasks = []
        num_crossovers = target_size - len(population)
        for i in range(num_crossovers):
            pa, pb = random.sample(population, 2)
            logger.info(f"Bootstrap crossover {len(population)+i+1}")
            crossover_tasks.append(crossover(pa, pb, 0))
        if crossover_tasks:
            crossover_results = await asyncio.gather(*crossover_tasks)
            for result in crossover_results:
                if result is not None:
                    population.append(result)
    return population
# ─────────────────────────── CSV Export ──────────────────────────────────────────────────
def save_tournament_csv():
    """Export tournament data to CSV files."""
    if not TOURNAMENT_DATA:
        logger.warning("No tournament data to save")
        return
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # Save both regular and sorted versions
    for name, sort_key in [("tournament", None),
                           ("best", lambda x: (-x.get('avg_total', 0), x['avg_rank']))]:
        path = PROMPT_DIR / f"{name}_results_{timestamp}.csv"
        data = sorted(TOURNAMENT_DATA, key=sort_key) if sort_key else TOURNAMENT_DATA
        with open(path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=list(data[0].keys()))
            writer.writeheader()
            writer.writerows(data)
        logger.success(f"Saved {name} data: {path}")
# ─────────────────────────── Main Evolution Loop ─────────────────────────────────────────
async def evolve(seed: Path, generations: int, pop: int):
    """Main evolutionary loop."""
    try:
        population = await bootstrap_population(seed, pop)
    except Exception as e:
        logger.error(f"Failed to bootstrap population: {e}")
        logger.error("Creating minimal population with seed only")
        population = [PromptMeta(seed)]
    T, last_champ, champions = T0, None, set()
    for gen in trange(1, generations + 1, desc="Generation"):
        log_step(f"Starting Generation {gen}")
        # Optional champion mutation
        if last_champ and random.random() < CHAMP_MUTATE_PROB:
            logger.info(f"Adding decayed champion variant")
            path = last_champ.path.with_name(f"{last_champ.path.stem}_decay{gen}.txt")
            path.write_text(last_champ.path.read_text())
            population.append(PromptMeta(path,
                                         last_champ.alpha * PRI_DECAY + 1,
                                         last_champ.beta * PRI_DECAY + 1))
        # Thompson sampling for arena selection
        draw = [(pm, random.betavariate(pm.alpha, pm.beta)) for pm in population]
        # Cap arena size to available population
        actual_arena_size = min(ARENA_SIZE, len(population))
        arena = [pm for pm, _ in sorted(draw, key=lambda x: x[1], reverse=True)[:actual_arena_size]]
        topic = random.choice(TOPICS)
        log_arena_selection(gen, topic, draw, arena)
        await judge_arena(arena, topic, champions, gen)
        # Selection and offspring
        ranked = sorted(population, key=lambda p: p.mean(), reverse=True)
        survivors = ranked[:ELITE]
        # Probabilistic selection for remaining spots
        for i, pm in enumerate(ranked[ELITE:], ELITE + 1):
            if random.random() < math.exp(-(i - ELITE) / T):
                survivors.append(pm)
        logger.info(f"Selected {len(survivors)} survivors")
        # Generate offspring in parallel
        mutation_tasks = [mutate(s, gen) for s in survivors]
        children = await asyncio.gather(*mutation_tasks)
        if len(survivors) >= 2 and random.random() < XOVER_PROB:
            child = await crossover(*random.sample(survivors, 2), gen)
            children.append(child)
        # Update population
        population = (survivors + children)[:pop]
        T *= T_DECAY
        last_champ = max(population, key=lambda p: p.mean())
        log_generation_summary(gen, population)
    # Final results
    champion = max(population, key=lambda p: p.mean())
    log_step("EVOLUTION COMPLETE", "success")
    logger.success(f"Best prompt: {champion.path.name}")
    logger.success(f"Final stats: α={champion.alpha:.2f}, β={champion.beta:.2f}, μ={champion.mean():.3f}")
    logger.success("\n★★ BEST PROMPT ★★\n" + champion.path.read_text())
    save_tournament_csv()
# ─────────────────────────── CLI Entry Point ─────────────────────────────────────────────
def main():
    """CLI entry point."""
    if len(sys.argv) < 2:
        sys.exit("Usage: python joker_bayesian.py prompts/seed.txt [GENS] [POP]")
    seed_file = Path(sys.argv[1])
    gens = int(sys.argv[2]) if len(sys.argv) > 2 else GENERATIONS
    pop = int(sys.argv[3]) if len(sys.argv) > 3 else POP_SIZE
    if not seed_file.exists():
        sys.exit(f"Seed file not found: {seed_file}")
    if ARENA_SIZE > pop:
        logger.warning(f"ARENA_SIZE ({ARENA_SIZE}) > POP_SIZE ({pop}), will be capped")
    log_step(f"Starting evolution: {seed_file.name}, {gens} generations, population {pop}")
    asyncio.run(evolve(seed_file, gens, pop))
if __name__ == "__main__":
    main()
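Running the script requires a seed system prompt on disk, which the gist does not include. A minimal, hypothetical seed saved as prompts/seed.txt could be as simple as: "You are a sharp stand-up comedian. Given a topic, reply with exactly one short, original one-line joke and nothing else." With that file in place, python joker_bayesian.py prompts/seed.txt runs with the defaults (GENERATIONS = 12, POP_SIZE = 18); a Gemini API key for the google-gla provider (typically GEMINI_API_KEY, plus optionally LOGFIRE_TOKEN for tracing) is expected in the environment or a .env file.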
metalearner_joke.py
#!/usr/bin/env python
"""Evolutionary prompt tuner for one-line topic jokes with async parallel processing."""
from __future__ import annotations
import asyncio
import csv
import os, json
from dataclasses import dataclass
from pathlib import Path
import sys
from typing import List
from loguru import logger
from pydantic import BaseModel, Field
from pydantic_ai import Agent, RunContext
from pydantic_ai.models.gemini import GeminiModelSettings, ThinkingConfig
from tqdm import trange
from dotenv import load_dotenv
import logfire
load_dotenv(override=True)
logfire.configure(token=os.getenv("LOGFIRE_API_KEY"))
logfire.instrument_pydantic_ai()
TOPICS = ["bananas", "quantum computing", "cloud computing", "coffee", "football",
          "time travel", "space exploration", "blockchain", "dinosaurs", "robots"]
class JokeEval(BaseModel):
    feedback: str = Field(description="Feedback on the joke and what to improve. 2-3 sentences max.")
    humour: int = Field(description="How funny is this joke? -1 for less than average, 0 for average, 1 for exceptional")
    originality: int = Field(description="How original is this joke? -1 for less than average, 0 for average, 1 for exceptional")
    relevance: int = Field(description="How relevant is this joke? -1 for less than average, 0 for average, 1 for exceptional")
    safety: int = Field(description="How safe is this joke? -1 for less than average, 0 for average, 1 for exceptional")
    total: int = Field(description="Total score (sum of humour, originality, relevance, safety)")
@dataclass
class ScoredJoke:
    topic: str
    joke: str
    eval: JokeEval
@dataclass
class PromptEvaluation:
    prompt_file: Path
    scores: List[ScoredJoke]
    avg_score: float
    feedbacks: List[str]
def search_and_replace(ctx: RunContext[str], needle: str, replacement: str) -> str:
    file_path = Path(ctx.deps)
    text = file_path.read_text()
    if text.count(needle) != 1:
        return f"ERROR:{text.count(needle)} matches"
    file_path.write_text(text.replace(needle, replacement, 1))
    return "Replaced! New prompt: " + file_path.read_text()
def create_new_prompt(ctx: RunContext[str], new_prompt: str) -> str:
    Path(ctx.deps).write_text(new_prompt)
    return "Created! New prompt: " + new_prompt
def read_prompt(ctx: RunContext[str]) -> str:
    """Read the current prompt."""
    return Path(ctx.deps).read_text()
logger.info("Initialising agents")
JudgeAgent = Agent(
    "google-gla:gemini-2.5-flash", deps_type=str, output_type=JokeEval,
    system_prompt=("Score TOPIC/JOKE: {feedback, humour (-1,0,1), originality (-1,0,1), "
                   "relevance (-1,0,1), safety (-1,0,1), total (sum)}"),
    model_settings=GeminiModelSettings(thinking_config=ThinkingConfig(thinking_budget=256))
)
EditAgent = Agent(
    "google-gla:gemini-2.5-flash", deps_type=str, output_type=str,
    system_prompt=("Improve joke prompts. Input: PROMPT and resulting SCORES+FEEDBACK. "
                   "Call search_and_replace OR create_new_prompt."),
    tools=[search_and_replace, create_new_prompt, read_prompt],
    model_settings=GeminiModelSettings(thinking_config=ThinkingConfig(thinking_budget=256))
)
def build_joker(prompt: str) -> Agent:
    return Agent(
        "google-gla:gemini-2.5-flash-lite-preview-06-17",
        deps_type=str, output_type=str, system_prompt=prompt,
        model_settings=GeminiModelSettings(thinking_config=ThinkingConfig(thinking_budget=0))
    )
LOG_FILE = Path("joke_evolution_log.csv")
if not LOG_FILE.exists():
    LOG_FILE.write_text("generation,prompt_file,topic,joke,humour,originality,"
                        "relevance,safety,total,feedback\n")
async def evaluate_prompt(prompt_path: Path, generation: int) -> PromptEvaluation:
    prompt_text = prompt_path.read_text()
    joker = build_joker(prompt_text)
    joke_tasks = [joker.run(f"Generate a one-line joke about the topic: {topic}")
                  for topic in TOPICS]
    joke_results = await asyncio.gather(*joke_tasks)
    eval_tasks = [JudgeAgent.run(f"TOPIC:{topic}\nJOKE:{result.output}", deps="")
                  for topic, result in zip(TOPICS, joke_results)]
    eval_results = await asyncio.gather(*eval_tasks)
    scored_jokes = []
    total = 0.0
    with LOG_FILE.open("a", newline="") as fp:
        writer = csv.writer(fp)
        for topic, joke_result, eval_result in zip(TOPICS, joke_results, eval_results):
            joke, scored = joke_result.output, eval_result.output
            writer.writerow([generation, prompt_path.name, topic, joke,
                             scored.humour, scored.originality, scored.relevance,
                             scored.safety, scored.total, scored.feedback])
            scored_jokes.append(ScoredJoke(topic, joke, scored))
            total += scored.total
    return PromptEvaluation(prompt_path, scored_jokes, total/len(TOPICS),
                            [sj.eval.feedback for sj in scored_jokes])
async def spawn_child(parent: PromptEvaluation, gen: int) -> Path:
    child = parent.prompt_file.with_name(parent.prompt_file.stem + f"_g{gen}.txt")
    child.write_text(parent.prompt_file.read_text())
    sorted_scores = sorted(parent.scores, key=lambda sj: sj.eval.total)
    feedback_data = "\n".join([f"Topic: {sj.topic} (score: {sj.eval.total}): {sj.eval.feedback}"
                               for sj in sorted_scores])
    await EditAgent.run(f"Please make edits based on the provided feedback.\n\nPROMPT:\n{child.read_text()}\nSCORES+FEEDBACK:\n{feedback_data}", deps=str(child))
    return child
async def evolve(seed_file: Path, generations: int = 6, pop_size: int = 4) -> None:
    population = [seed_file]
    logger.info(f"Starting evolution: gens={generations} pop={pop_size}")
    for gen in trange(1, generations + 1, desc="Gen"):
        eval_tasks = [evaluate_prompt(pf, gen) for pf in population]
        evaluations = await asyncio.gather(*eval_tasks)
        evaluations.sort(key=lambda e: e.avg_score, reverse=True)
        avg_all = sum(e.avg_score for e in evaluations) / len(evaluations)
        logger.success(f"Gen {gen} avg={avg_all:.2f} best={evaluations[0].avg_score:.2f} "
                       f"file={evaluations[0].prompt_file.name}")
        survivors = evaluations[:max(2, pop_size // 2)]
        child_tasks = [spawn_child(parent, gen) for parent in survivors]
        children = await asyncio.gather(*child_tasks)
        population = [e.prompt_file for e in survivors] + children
        population = population[:pop_size]
    logger.success(f"Best prompt:\n{evaluations[0].prompt_file.read_text()}")
if __name__ == "__main__":
    if len(sys.argv) < 2:
        sys.exit("Usage: python metalearner_joke.py prompts/joke_seed.txt [GENS] [POP]")
    seed_path = Path(sys.argv[1])
    gens = int(sys.argv[2]) if len(sys.argv) > 2 else 6
    pop = int(sys.argv[3]) if len(sys.argv) > 3 else 4
    asyncio.run(evolve(seed_path, gens, pop))
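This simpler, point-wise variant is invoked the same way, e.g. python metalearner_joke.py prompts/joke_seed.txt 6 4 (generations and population default to 6 and 4 when omitted). Each generation it scores every prompt's jokes across all ten topics with a single judge call per joke and appends every scored joke to joke_evolution_log.csv.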