@svilupp
Last active July 27, 2025 15:49
system-prompt-learning-jokes
#!/usr/bin/env python3
"""
joker_bayesian.py – List-wise evolutionary prompt tuner for joke generation
Usage: python joker_bayesian.py prompts/seed.txt [GENERATIONS] [POPULATION]
* **Judge** 3x Listwise LLM-as-a-judge evaluates jokes on 8 ternary rubrics (goal, originality,
  length, surprise, timing, relatability, clarity, memorability).
* **Fitness** Thompson-sampled Beta(α, β) updated with tournament rankings
  (wins = K-1-rank, losses = rank) across multiple judge runs.
* **Variation** EditAgent mutates prompts using feedback history, crossover
  blends parents, skip preserves good performers, champ revival.
* **State** Each prompt: .txt file + .hist.json feedback + α,β parameters.
Models: Gemini 2.5 Flash Lite for the Joke agent, Gemini 2.5 Flash for everything else (Judge, Edit, Crossover)
Dependencies
------------
uv add pydantic_ai texprompts logfire loguru python-dotenv tqdm tenacity
"""
# ─────────────────────────── Imports ──────────────────────────
import asyncio, random, math, json, sys, os, csv
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple, Dict, Optional
from datetime import datetime
from dotenv import load_dotenv
from loguru import logger
from tqdm import trange
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
# pydantic-ai imports
from pydantic import BaseModel, Field
from pydantic_ai import Agent, RunContext
from pydantic_ai.models.gemini import GeminiModelSettings, ThinkingConfig
from pydantic_ai.exceptions import UnexpectedModelBehavior
from pydantic_ai.usage import UsageLimits
import logfire
# ─────────────────────────── Configuration ──────────────────────────────────────────────
load_dotenv(override=True)
random.seed(42)
logfire.configure(token=os.getenv("LOGFIRE_TOKEN"), scrubbing=False)
logfire.instrument_pydantic_ai()
# Evolution parameters
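# How the constants below are used (summarised from the code):
#   ARENA_SIZE        prompts with the highest Thompson samples enter each tournament (capped at population size)
#   JUDGE_RUNS        independent judge passes per tournament, each over a reshuffled joke list (capped at K)
#   SKIP_PROB         chance a survivor skips mutation and is carried forward unchanged
#   XOVER_PROB        chance of adding one crossover child per generation
#   ELITE             number of top-mean prompts guaranteed to survive
#   T0, T_DECAY       initial temperature and per-generation decay of the exp(-(i-ELITE)/T) survival probability
#   HIST_DEPTH        number of most recent feedback entries shown to the EditAgent during mutation
#   TOOLS_THINK_BUDGET  Gemini thinking-token budget for the Judge/Edit/Crossover agents (joke agent uses 0)
#   CHAMP_MUTATE_PROB chance of re-injecting the previous champion as a variant
#   PRI_DECAY         decay applied to that variant's priors (α, β → PRI_DECAY·α + 1, PRI_DECAY·β + 1)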
TOPICS = ["bananas","quantum computing","cloud computing","coffee","football",
"time travel","space exploration","blockchain","dinosaurs","robots"]
POP_SIZE, GENERATIONS, ARENA_SIZE, JUDGE_RUNS = 18, 12, 9, 3
SKIP_PROB, XOVER_PROB, ELITE = 0.2, 0.3, 2
T0, T_DECAY, HIST_DEPTH = 1.0, 0.9, 6
TOOLS_THINK_BUDGET, CHAMP_MUTATE_PROB, PRI_DECAY = 256, 0.25, 0.6
PROMPT_DIR = Path("prompts")
PROMPT_DIR.mkdir(exist_ok=True)
TOURNAMENT_DATA = []
USAGE_LIMITS = UsageLimits(
    request_limit=10
)
# ─────────────────────────── Logging Utilities ───────────────────────────────────────────
def log_step(msg: str, level="info", sep=True):
    """Unified logging with optional separator."""
    if sep: logger.info(f"\n{'='*60}")
    getattr(logger, level)(msg)
def log_arena_selection(gen: int, topic: str, draw: List[Tuple], arena: List):
    """Log arena selection details."""
    log_step(f"Gen {gen} Arena Selection (Topic: {topic}):")
    for pm, sampled in sorted(draw, key=lambda x: x[1], reverse=True):
        status = "→ IN ARENA" if pm in arena else ""
        logger.info(f" {pm.path.name:30} α={pm.alpha:6.2f} β={pm.beta:6.2f} "
                    f"μ={pm.mean():5.3f} sampled={sampled:5.3f} {status}")
def log_tournament_results(arena: List, wins: List[int], losses: List[int]):
    """Log tournament results."""
    logger.info("\nTournament Results:")
    for pm, w, l in zip(arena, wins, losses):
        logger.info(f" {pm.path.name:30} wins={w:3} losses={l:3} net={w-l:+4} "
                    f"α={pm.alpha:6.2f} β={pm.beta:6.2f}")
def log_generation_summary(gen: int, population: List):
    """Log generation summary."""
    best = max(population, key=lambda p: p.mean())
    log_step(f"Gen {gen} Summary:", "success", sep=False)
    logger.success(f"Best: {best.path.name} (μ={best.mean():.3f})")
    logger.info("Final Rankings:")
    for i, pm in enumerate(sorted(population, key=lambda p: p.mean(), reverse=True), 1):
        logger.info(f" {i}. {pm.path.name:30} α={pm.alpha:6.2f} β={pm.beta:6.2f} μ={pm.mean():5.3f}")
# ─────────────────────────── Judge Models ────────────────────────────────────────────────
class JokeEval(BaseModel):
    """Ternary evaluation scores for each joke."""
    goal: int = Field(description="Achieves humorous goal: -1=weak, 0=good/great, 1=exceptional")
    originality: int = Field(description="Originality: -1=weak, 0=good/great, 1=exceptional")
    length: int = Field(description="Length appropriateness: -1=weak, 0=good/great, 1=exceptional")
    surprise: int = Field(description="Surprise/unexpectedness factor: -1=predictable, 0=some surprise, 1=genuinely unexpected")
    timing: int = Field(description="Setup and payoff timing: -1=poor pacing, 0=good rhythm, 1=perfect comedic timing")
    relatability: int = Field(description="Audience connection: -1=obscure/alienating, 0=accessible, 1=universally relatable")
    clarity: int = Field(description="Joke comprehension, not confusing or overcomplicated: -1=confusing/ambiguous, 0=clear, 1=crystal clear and elegant")
    memorability: int = Field(description="Sticks in memory: -1=instantly forgettable, 0=decent recall, 1=highly memorable")
    suggestions: str = Field(description="Suggestions for improvement. Max 1-2 sentences.")
class JudgeOut(BaseModel):
    """Judge output with reasoning and rankings."""
    reasoning: str = Field(..., description="concise internal chain-of-thought")
    evaluations: List[JokeEval] = Field(..., description="one entry per joke")
    ranking: List[int] = Field(..., description="1-indexed best→worst list")
# System prompts (preserved exactly as original)
JUDGE_SYS = """
You are an expert comedy critic. Do **three** things in *private* chain-of-thought first:
1. For EACH joke decide ternary scores {-1=weak,0=good/great,1=exceptional - top decile of jokes!} on these rubrics:
(a) Achieves its humorous goal - actually funny, makes you laugh
(b) Originality (SEVERELY penalize duplicates and known jokes, reward fresh creativity)
(c) Length that's not too long or too short - just right for the joke type
(d) Surprise/unexpectedness - subverts expectations, catches audience off-guard
(e) Setup and payoff timing - proper comedic rhythm and pacing
(f) Audience relatability - accessible premise people can connect with
(g) Clarity - easy to understand, not confusing or overcomplicated
(h) Memorability - sticks with you, quotable, shareable quality
2. For each joke, provide specific, actionable suggestions for improvement (1-2 sentences max).
3. Combine the rubric into a scalar (sum) and sort jokes best→worst.
• If scalars tie, break ties via your reasoning.
• IMPORTANT: Duplicate jokes should be ranked lowest, regardless of other qualities.
4. Only THEN reveal JSON:
{
reasoning: "...",
evaluations:[ {goal:#,originality:#,length:#,surprise:#,timing:#,relatability:#,clarity:#,memorability:#,suggestions:"..."}, ... ],
ranking:[best_index, ...] # 1-indexed matching the joke order provided
}
**Never** leak private reasoning outside the `reasoning` field.
""".strip()
EDIT_SYS = """
You will refine the SYSTEM PROMPT for a joke generator based on feedback.
The current prompt is provided in the message below. No need to use `read_prompt()` again.
GUIDELINES:
- First, decide on the final form of the prompt.
- Then, plan the changes to the prompt in your reasoning chain.
- Make as few focused edits as possible to achieve the goal.
- Address specific feedback issues (weak humor, verbosity, etc.)
- Use `replace_text(old_text, new_text)` for precise edits. If it fails twice, read the current prompt `read_prompt()` before trying again.
- Use `create_new_prompt(new)` only for complete rewrites
GOOD PROMPTS STRUCTURE:
- CONTEXT: Set the scenario and role
- TASK: Describe what good outcome looks like
- INSTRUCTIONS: What to do / not to do (be specific)
- OUTPUT FORMAT: How to structure the response
- EXAMPLES: Show what success looks like (if helpful)
Be bold but focused. One clear improvement per edit. Focus on NEW FEEDBACK first.
""".strip()
XOVER_SYS = """
Merge two SYSTEM PROMPTS (PARENT_A, PARENT_B). Keep their strongest constraints,
eliminate redundancy, and maintain similar length. Return ONLY the merged prompt.
""".strip()
# ─────────────────────────── Agent Factory ───────────────────────────────────────────────
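# Model assignment (per the module docstring): gemini-2.5-flash powers the Judge, Edit and
# Crossover agents; the cheaper gemini-2.5-flash-lite preview generates the jokes themselves
# (built per prompt inside judge_arena with temperature 0.8 and no thinking budget).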
def make_agent(model: str, output_type, system_prompt: str, tools=None, temp=None, thinking=TOOLS_THINK_BUDGET):
    """Create agent with standard settings."""
    settings = GeminiModelSettings(thinking_config=ThinkingConfig(thinking_budget=thinking), temperature=temp)
    # if temp is not None: settings.temperature = temp
    return Agent(model, output_type=output_type, deps_type=str,
                 system_prompt=system_prompt, tools=tools or [], model_settings=settings)
# Create agents
JudgeAgent = make_agent("google-gla:gemini-2.5-flash", JudgeOut, JUDGE_SYS)
# EditAgent will be created after tools are defined
CrossoverAgent = make_agent("google-gla:gemini-2.5-flash", str, XOVER_SYS)
# ─────────────────────────── File Operations (compressed) ────────────────────────────────
def file_tool(ctx: RunContext[str], op: str, **kwargs) -> str:
    """Unified file operations tool."""
    try:
        p = Path(ctx.deps)
        if op == 'read':
            return p.read_text() if p.exists() else f"ERROR: File {p} not found"
        elif op == 'replace':
            txt = p.read_text()
            old, new = kwargs['old'], kwargs['new']
            count = txt.count(old)
            if count == 0: return f"ERROR: '{old[:50]}...' not found"
            if count > 1: return f"ERROR: '{old[:50]}...' found {count} times - must be unique"
            p.write_text(txt.replace(old, new, 1))
            return "SUCCESS: Edit applied"
        elif op == 'create':
            existed = p.exists()  # check before writing so the message reports Created vs Overwrote correctly
            p.write_text(kwargs['content'])
            return f"SUCCESS: {'Overwrote' if existed else 'Created'} {p.name}"
        return f"ERROR: Unknown operation '{op}'"
    except Exception as e:
        return f"ERROR: {type(e).__name__}: {str(e)}"
# Tool wrappers
def read_prompt(ctx: RunContext[str]): return file_tool(ctx, 'read')
def replace_text(ctx: RunContext[str], old: str, new: str): return file_tool(ctx, 'replace', old=old, new=new)
def create_new_prompt(ctx: RunContext[str], new: str): return file_tool(ctx, 'create', content=new)
# Create EditAgent with tools
EditAgent = make_agent("google-gla:gemini-2.5-flash", str, EDIT_SYS,
                       tools=[replace_text, create_new_prompt, read_prompt])
# ─────────────────────────── Retry Wrapper ───────────────────────────────────────────────
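# Retry policy (tenacity): up to 5 attempts with exponential backoff between 1 s and 10 s,
# retrying only on UnexpectedModelBehavior, ConnectionError and TimeoutError; anything else
# (or exhausted retries) is logged and either re-raised or converted to None.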
async def run_with_retry(agent: Agent, message: str, deps: str = "", return_none_on_error: bool = False):
    """Run agent with retry on API failures.
    Args:
        agent: The agent to run
        message: The message to send
        deps: Dependencies for the agent
        return_none_on_error: If True, return None on failure instead of raising
    """
    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=1, max=10),
           retry=retry_if_exception_type((UnexpectedModelBehavior, ConnectionError, TimeoutError)))
    async def _run():
        return await agent.run(message, deps=deps, usage_limits=USAGE_LIMITS)
    try:
        return await _run()
    except Exception as e:
        logger.error(f"Agent run failed after retries: {type(e).__name__}: {str(e)[:100]}")
        if return_none_on_error:
            return None
        raise
# ─────────────────────────── Core Data Structures ────────────────────────────────────────
@dataclass
class PromptMeta:
"""Prompt with Beta distribution parameters."""
path: Path
alpha: float = 1.0
beta: float = 1.0
def mean(self) -> float:
return self.alpha / (self.alpha + self.beta)
def copy(self) -> "PromptMeta":
return PromptMeta(self.path, self.alpha, self.beta)
# History management
hist_path = lambda p: p.with_suffix(".hist.json")
load_history = lambda p: json.loads(hist_path(p).read_text()) if hist_path(p).exists() else []
save_history = lambda p, e: hist_path(p).write_text(json.dumps((load_history(p) + [e])[-50:], indent=1))
def beta_update(pm: PromptMeta, wins: int, losses: int):
"""Update Beta distribution parameters."""
pm.alpha += wins
pm.beta += losses
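# Worked example (illustrative only): in a K=9 arena the prompt with the best average rank
# gets wins=8, losses=0, moving a fresh Beta(1, 1) prior to Beta(9, 1) (posterior mean 0.90);
# the worst-ranked prompt gets wins=0, losses=8, i.e. Beta(1, 9) (mean 0.10).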
# ─────────────────────────── Judge Arena (compressed but clear) ──────────────────────────
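# Tournament flow: each of the K arena prompts generates one joke; the judge ranks the full
# list JUDGE_RUNS times with the joke order reshuffled per run (which should average out any
# position bias in the judge); ranks are averaged and converted to wins/losses per prompt.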
async def judge_arena(arena: List[PromptMeta], topic: str, champions: set[Path], gen: int):
    """Run tournament and update Beta parameters."""
    K = len(arena)
    # 1. Generate jokes concurrently
    logger.info(f"Generating {K} jokes about '{topic}'...")
    tasks = [run_with_retry(make_agent("google-gla:gemini-2.5-flash-lite-preview-06-17",
                                       str, pm.path.read_text(), temp=0.8, thinking=0),
                            f"Generate a one-line joke about: {topic}", return_none_on_error=True)
             for pm in arena]
    results = await asyncio.gather(*tasks)
    # Process results
    jokes = []
    valid_count = 0
    for i, result in enumerate(results):
        if result is None:
            jokes.append("[FAILED TO GENERATE]")
        else:
            jokes.append(result.output.strip().replace('\n', ' '))
            valid_count += 1
    # If all joke generation failed, skip this arena
    if valid_count == 0:
        logger.error("All joke generations failed, skipping arena")
        return
    # 2. Judge with multiple shuffles
    num_runs = min(JUDGE_RUNS, K)
    logger.info(f"Running {num_runs} judge evaluations...")
    judge_tasks, orders = [], []
    for _ in range(num_runs):
        order = list(range(K))
        random.shuffle(order)
        orders.append(order)
        joke_block = "\n".join(f"{i+1}. {jokes[order[i]]}" for i in range(len(order)))
        judge_tasks.append(run_with_retry(JudgeAgent, f"TOPIC:{topic}\n\nJOKES:\n{joke_block}", return_none_on_error=True))
    judge_results = await asyncio.gather(*judge_tasks)
    # Filter out failed judge evaluations
    valid_judge_results = []
    valid_orders = []
    for result, order in zip(judge_results, orders):
        if result is not None:
            valid_judge_results.append(result)
            valid_orders.append(order)
    # If all judge evaluations failed, assign neutral scores
    if not valid_judge_results:
        logger.error("All judge evaluations failed, assigning neutral scores")
        for pm in arena:
            beta_update(pm, 0, 0)
        return
    judge_results = valid_judge_results
    orders = valid_orders
    # 3. Process rankings and feedback
    all_rankings = []
    issues_map = {'goal': 'weak humor', 'originality': 'lacks originality', 'length': 'poor length',
                  'surprise': 'too predictable', 'timing': 'poor pacing',
                  'relatability': 'not relatable', 'clarity': 'confusing', 'memorability': 'forgettable'}
    for out, order in zip(judge_results, orders):
        # Convert shuffled positions to arena indices
        arena_ranking = [order[pos-1] for pos in out.output.ranking]
        all_rankings.append(arena_ranking)
        # Generate feedback for each joke
        for i, eval in enumerate(out.output.evaluations):
            issues = [v for k, v in issues_map.items() if getattr(eval, k) < 0]
            feedback_parts = []
            if issues:
                feedback_parts.append(f"Issues: {', '.join(issues)}")
            if eval.suggestions:
                feedback_parts.append(f"Suggestions: {eval.suggestions}")
            if feedback_parts:
                save_history(arena[order[i]].path, " | ".join(feedback_parts))
    # 4. Calculate average ranks and update Beta
    avg_ranks = [sum(ranking.index(i) for ranking in all_rankings) / len(all_rankings)
                 for i in range(K)]
    ranked_items = sorted(enumerate(avg_ranks), key=lambda x: x[1])
    # Keep wins/losses indexed by arena position so logging and champion tracking line up
    wins, losses = [0] * K, [0] * K
    for final_rank, (arena_idx, _) in enumerate(ranked_items):
        w, l = K - 1 - final_rank, final_rank
        wins[arena_idx], losses[arena_idx] = w, l
        beta_update(arena[arena_idx], w, l)
        # Track tournament data
        TOURNAMENT_DATA.append({
            'generation': gen, 'topic': topic,
            'prompt_file': arena[arena_idx].path.name,
            'joke': jokes[arena_idx], 'avg_rank': avg_ranks[arena_idx],
            'wins': w, 'losses': l, 'net_score': w - l,
            'alpha': arena[arena_idx].alpha, 'beta': arena[arena_idx].beta,
            'mean': arena[arena_idx].mean()
        })
    log_tournament_results(arena, wins, losses)
    # 5. Track champion
    best_idx = max(range(K), key=lambda i: wins[i] - losses[i])
    champions.add(arena[best_idx].path)
# ─────────────────────────── Variation Operations ────────────────────────────────────────
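# Three variation operators act on survivors: skip (keep the prompt unchanged with probability
# SKIP_PROB), mutation (EditAgent rewrites a copy using recent judge feedback), and crossover
# (CrossoverAgent merges two parents). Failed LLM calls fall back to a copy of the parent.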
async def mutate(parent: PromptMeta, gen: int) -> PromptMeta:
    """Create mutation with possible skip."""
    if random.random() < SKIP_PROB:
        logger.info(f"Skipping mutation for {parent.path.name}")
        return parent.copy()
    child_path = parent.path.with_name(f"{parent.path.stem}_m{gen}.txt")
    child_path.write_text(parent.path.read_text())
    # Prepare feedback
    all_feedback = load_history(parent.path)
    recent_feedback = all_feedback[-HIST_DEPTH:] if all_feedback else []
    feedback_msg = f"Current prompt:\n{child_path.read_text()}\n\n"
    if recent_feedback:
        feedback_msg += f"RECENT FEEDBACK:\n" + "\n".join(recent_feedback)
    logger.info(f"Mutating {parent.path.name} → {child_path.name}")
    result = await run_with_retry(EditAgent, feedback_msg, deps=str(child_path), return_none_on_error=True)
    if result is None:
        # Return parent unchanged if mutation fails
        return parent.copy()
    return PromptMeta(child_path, parent.alpha, parent.beta)
async def crossover(pa: PromptMeta, pb: PromptMeta, gen: int) -> PromptMeta:
    """Create crossover child."""
    logger.info(f"Crossover: {pa.path.name} × {pb.path.name}")
    out = await run_with_retry(CrossoverAgent,
                               f"PARENT_A:\n{pa.path.read_text()}\n\nPARENT_B:\n{pb.path.read_text()}",
                               return_none_on_error=True)
    if out is None:
        # Return first parent if crossover fails
        return pa.copy()
    child_path = pa.path.with_name(f"{pa.path.stem}_{pb.path.stem}_x{gen}.txt")
    child_path.write_text(out.output.strip())
    # Inherit optimistic priors
    return PromptMeta(child_path, max(pa.alpha, pb.alpha), max(pa.beta, pb.beta))
# ─────────────────────────── Bootstrap Population ────────────────────────────────────────
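# The initial population is the seed plus up to five hint-guided EditAgent variants
# ("Make more concise", "Emphasize originality", ...), topped up with random crossovers
# until target_size prompts exist.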
async def bootstrap_population(seed: Path, target_size: int) -> List[PromptMeta]:
    """Create initial population variants."""
    population = [PromptMeta(seed)]
    hints = ["Make more concise", "Add humor style specifics", "Emphasize originality",
             "Improve joke structure", "Increase clarity"]
    logger.info(f"Bootstrapping population from {seed.name}")
    # Create mutations in parallel
    mutation_tasks = []
    child_paths = []
    for i in range(min(len(hints), target_size - 1)):
        child_path = seed.with_name(f"{seed.stem}_init{i}.txt")
        child_path.write_text(seed.read_text())
        child_paths.append(child_path)
        logger.info(f"Creating variant {i+1}: {hints[i]}")
        mutation_tasks.append(run_with_retry(EditAgent,
            f"CURRENT PROMPT:\n---\n{seed.read_text()}\n---\n\nTASK: {hints[i]}. Make the changes now!",
            deps=str(child_path), return_none_on_error=True))
    await asyncio.gather(*mutation_tasks)
    # Add all children (mutations may have failed but files still exist with original content)
    for path in child_paths:
        population.append(PromptMeta(path))
    # Fill remaining with crossovers in parallel
    if len(population) < target_size and len(population) >= 2:
        crossover_tasks = []
        num_crossovers = target_size - len(population)
        for i in range(num_crossovers):
            pa, pb = random.sample(population, 2)
            logger.info(f"Bootstrap crossover {len(population)+i+1}")
            crossover_tasks.append(crossover(pa, pb, 0))
        if crossover_tasks:
            crossover_results = await asyncio.gather(*crossover_tasks)
            for result in crossover_results:
                if result is not None:
                    population.append(result)
    return population
# ─────────────────────────── CSV Export ──────────────────────────────────────────────────
def save_tournament_csv():
    """Export tournament data to CSV files."""
    if not TOURNAMENT_DATA:
        logger.warning("No tournament data to save")
        return
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # Save both regular and sorted versions
    for name, sort_key in [("tournament", None),
                           ("best", lambda x: x['avg_rank'])]:  # lowest average rank first
        path = PROMPT_DIR / f"{name}_results_{timestamp}.csv"
        data = sorted(TOURNAMENT_DATA, key=sort_key) if sort_key else TOURNAMENT_DATA
        with open(path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=list(data[0].keys()))
            writer.writeheader()
            writer.writerows(data)
        logger.success(f"Saved {name} data: {path}")
# ─────────────────────────── Main Evolution Loop ─────────────────────────────────────────
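# Per-generation loop: optionally re-inject a decayed champion → Thompson-sample Beta(α, β)
# for every prompt and send the top ARENA_SIZE samples into a tournament → keep ELITE prompts
# plus a temperature-annealed tail of the ranking → mutate survivors in parallel (plus an
# occasional crossover) → truncate to the population cap and decay the temperature.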
async def evolve(seed: Path, generations: int, pop: int):
    """Main evolutionary loop."""
    try:
        population = await bootstrap_population(seed, pop)
    except Exception as e:
        logger.error(f"Failed to bootstrap population: {e}")
        logger.error("Creating minimal population with seed only")
        population = [PromptMeta(seed)]
    T, last_champ, champions = T0, None, set()
    for gen in trange(1, generations + 1, desc="Generation"):
        log_step(f"Starting Generation {gen}")
        # Optional champion mutation
        if last_champ and random.random() < CHAMP_MUTATE_PROB:
            logger.info(f"Adding decayed champion variant")
            path = last_champ.path.with_name(f"{last_champ.path.stem}_decay{gen}.txt")
            path.write_text(last_champ.path.read_text())
            population.append(PromptMeta(path,
                                         last_champ.alpha * PRI_DECAY + 1,
                                         last_champ.beta * PRI_DECAY + 1))
        # Thompson sampling for arena selection
        draw = [(pm, random.betavariate(pm.alpha, pm.beta)) for pm in population]
        # Cap arena size to available population
        actual_arena_size = min(ARENA_SIZE, len(population))
        arena = [pm for pm, _ in sorted(draw, key=lambda x: x[1], reverse=True)[:actual_arena_size]]
        topic = random.choice(TOPICS)
        log_arena_selection(gen, topic, draw, arena)
        await judge_arena(arena, topic, champions, gen)
        # Selection and offspring
        ranked = sorted(population, key=lambda p: p.mean(), reverse=True)
        survivors = ranked[:ELITE]
        # Probabilistic selection for remaining spots
        for i, pm in enumerate(ranked[ELITE:], ELITE + 1):
            if random.random() < math.exp(-(i - ELITE) / T):
                survivors.append(pm)
        logger.info(f"Selected {len(survivors)} survivors")
        # Generate offspring in parallel
        mutation_tasks = [mutate(s, gen) for s in survivors]
        children = await asyncio.gather(*mutation_tasks)
        if len(survivors) >= 2 and random.random() < XOVER_PROB:
            child = await crossover(*random.sample(survivors, 2), gen)
            children.append(child)
        # Update population
        population = (survivors + children)[:pop]
        T *= T_DECAY
        last_champ = max(population, key=lambda p: p.mean())
        log_generation_summary(gen, population)
    # Final results
    champion = max(population, key=lambda p: p.mean())
    log_step("EVOLUTION COMPLETE", "success")
    logger.success(f"Best prompt: {champion.path.name}")
    logger.success(f"Final stats: α={champion.alpha:.2f}, β={champion.beta:.2f}, μ={champion.mean():.3f}")
    logger.success("\n★★ BEST PROMPT ★★\n" + champion.path.read_text())
    save_tournament_csv()
# ─────────────────────────── CLI Entry Point ─────────────────────────────────────────────
def main():
    """CLI entry point."""
    if len(sys.argv) < 2:
        sys.exit("Usage: python joker_bayesian.py prompts/seed.txt [GENS] [POP]")
    seed_file = Path(sys.argv[1])
    gens = int(sys.argv[2]) if len(sys.argv) > 2 else GENERATIONS
    pop = int(sys.argv[3]) if len(sys.argv) > 3 else POP_SIZE
    if not seed_file.exists():
        sys.exit(f"Seed file not found: {seed_file}")
    if ARENA_SIZE > pop:
        logger.warning(f"ARENA_SIZE ({ARENA_SIZE}) > POP_SIZE ({pop}), will be capped")
    log_step(f"Starting evolution: {seed_file.name}, {gens} generations, population {pop}")
    asyncio.run(evolve(seed_file, gens, pop))
if __name__ == "__main__":
    main()
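
metalearner_joke.py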
#!/usr/bin/env python
"""Evolutionary prompt tuner for one-line topic jokes with async parallel processing."""
from __future__ import annotations
import asyncio
import csv
import os, json
from dataclasses import dataclass
from pathlib import Path
import sys
from typing import List
from loguru import logger
from pydantic import BaseModel, Field
from pydantic_ai import Agent, RunContext
from pydantic_ai.models.gemini import GeminiModelSettings, ThinkingConfig
from tqdm import trange
from dotenv import load_dotenv
import logfire
load_dotenv(override=True)
logfire.configure(token=os.getenv("LOGFIRE_API_KEY"))
logfire.instrument_pydantic_ai()
TOPICS = ["bananas", "quantum computing", "cloud computing", "coffee", "football",
"time travel", "space exploration", "blockchain", "dinosaurs", "robots"]
class JokeEval(BaseModel):
feedback: str = Field(description="Feedback on the joke and what to improve. 2-3 sentences max.")
humour: int = Field(description="How funny is this joke? -1 for less than average, 0 for average, 1 for exceptional")
originality: int = Field(description="How original is this joke? -1 for less than average, 0 for average, 1 for exceptional")
relevance: int = Field(description="How relevant is this joke? -1 for less than average, 0 for average, 1 for exceptional")
safety: int = Field(description="How safe is this joke? -1 for less than average, 0 for average, 1 for exceptional")
total: int = Field(description="Total score (sum of humour, originality, relevance, safety)")
@dataclass
class ScoredJoke:
topic: str
joke: str
eval: JokeEval
@dataclass
class PromptEvaluation:
prompt_file: Path
scores: List[ScoredJoke]
avg_score: float
feedbacks: List[str]
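# Each total is the sum of four ternary scores, so it ranges from -4 to +4; avg_score is the
# mean total across the 10 TOPICS and drives survivor selection in evolve().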
def search_and_replace(ctx: RunContext[str], needle: str, replacement: str) -> str:
    file_path = Path(ctx.deps)
    text = file_path.read_text()
    if text.count(needle) != 1:
        return f"ERROR:{text.count(needle)} matches"
    file_path.write_text(text.replace(needle, replacement, 1))
    return "Replaced! New prompt: " + file_path.read_text()
def create_new_prompt(ctx: RunContext[str], new_prompt: str) -> str:
    Path(ctx.deps).write_text(new_prompt)
    return "Created! New prompt: " + new_prompt
def read_prompt(ctx: RunContext[str]) -> str:
    """Read the current prompt."""
    return Path(ctx.deps).read_text()
logger.info("Initialising agents")
JudgeAgent = Agent(
    "google-gla:gemini-2.5-flash", deps_type=str, output_type=JokeEval,
    system_prompt=("Score TOPIC/JOKE: {feedback, humour (-1,0,1), originality (-1,0,1), "
                   "relevance (-1,0,1), safety (-1,0,1), total (sum)}"),
    model_settings=GeminiModelSettings(thinking_config=ThinkingConfig(thinking_budget=256))
)
EditAgent = Agent(
    "google-gla:gemini-2.5-flash", deps_type=str, output_type=str,
    system_prompt=("Improve joke prompts. Input: PROMPT and resulting SCORES+FEEDBACK. "
                   "Call search_and_replace OR create_new_prompt."),
    tools=[search_and_replace, create_new_prompt, read_prompt],
    model_settings=GeminiModelSettings(thinking_config=ThinkingConfig(thinking_budget=256))
)
def build_joker(prompt: str) -> Agent:
    return Agent(
        "google-gla:gemini-2.5-flash-lite-preview-06-17",
        deps_type=str, output_type=str, system_prompt=prompt,
        model_settings=GeminiModelSettings(thinking_config=ThinkingConfig(thinking_budget=0))
    )
LOG_FILE = Path("joke_evolution_log.csv")
if not LOG_FILE.exists():
    LOG_FILE.write_text("generation,prompt_file,topic,joke,humour,originality,"
                        "relevance,safety,total,feedback\n")
async def evaluate_prompt(prompt_path: Path, generation: int) -> PromptEvaluation:
    prompt_text = prompt_path.read_text()
    joker = build_joker(prompt_text)
    joke_tasks = [joker.run(f"Generate a one-line joke about the topic: {topic}")
                  for topic in TOPICS]
    joke_results = await asyncio.gather(*joke_tasks)
    eval_tasks = [JudgeAgent.run(f"TOPIC:{topic}\nJOKE:{result.output}", deps="")
                  for topic, result in zip(TOPICS, joke_results)]
    eval_results = await asyncio.gather(*eval_tasks)
    scored_jokes = []
    total = 0.0
    with LOG_FILE.open("a", newline="") as fp:
        writer = csv.writer(fp)
        for topic, joke_result, eval_result in zip(TOPICS, joke_results, eval_results):
            joke, scored = joke_result.output, eval_result.output
            writer.writerow([generation, prompt_path.name, topic, joke,
                             scored.humour, scored.originality, scored.relevance,
                             scored.safety, scored.total, scored.feedback])
            scored_jokes.append(ScoredJoke(topic, joke, scored))
            total += scored.total
    return PromptEvaluation(prompt_path, scored_jokes, total/len(TOPICS),
                            [sj.eval.feedback for sj in scored_jokes])
async def spawn_child(parent: PromptEvaluation, gen: int) -> Path:
    child = parent.prompt_file.with_name(parent.prompt_file.stem + f"_g{gen}.txt")
    child.write_text(parent.prompt_file.read_text())
    sorted_scores = sorted(parent.scores, key=lambda sj: sj.eval.total)
    feedback_data = "\n".join([f"Topic: {sj.topic} (score: {sj.eval.total}): {sj.eval.feedback}"
                               for sj in sorted_scores])
    await EditAgent.run(f"Please make edits based on the provided feedback.\n\nPROMPT:\n{child.read_text()}\nSCORES+FEEDBACK:\n{feedback_data}", deps=str(child))
    return child
async def evolve(seed_file: Path, generations: int = 6, pop_size: int = 4) -> None:
    population = [seed_file]
    logger.info(f"Starting evolution: gens={generations} pop={pop_size}")
    for gen in trange(1, generations + 1, desc="Gen"):
        eval_tasks = [evaluate_prompt(pf, gen) for pf in population]
        evaluations = await asyncio.gather(*eval_tasks)
        evaluations.sort(key=lambda e: e.avg_score, reverse=True)
        avg_all = sum(e.avg_score for e in evaluations) / len(evaluations)
        logger.success(f"Gen {gen} avg={avg_all:.2f} best={evaluations[0].avg_score:.2f} "
                       f"file={evaluations[0].prompt_file.name}")
        survivors = evaluations[:max(2, pop_size // 2)]
        child_tasks = [spawn_child(parent, gen) for parent in survivors]
        children = await asyncio.gather(*child_tasks)
        population = [e.prompt_file for e in survivors] + children
        population = population[:pop_size]
    logger.success(f"Best prompt:\n{evaluations[0].prompt_file.read_text()}")
if __name__ == "__main__":
    if len(sys.argv) < 2:
        sys.exit("Usage: python metalearner_joke.py prompts/joke_seed.txt [GENS] [POP]")
    seed_path = Path(sys.argv[1])
    gens = int(sys.argv[2]) if len(sys.argv) > 2 else 6
    pop = int(sys.argv[3]) if len(sys.argv) > 3 else 4
    asyncio.run(evolve(seed_path, gens, pop))