|
#!/usr/bin/env python3 |
|
from __future__ import annotations |
|
|
|
import argparse |
|
import hashlib |
|
import json |
|
import os |
|
import re |
|
import shutil |
|
import sqlite3 |
|
import sys |
|
import textwrap |
|
from dataclasses import dataclass |
|
from datetime import datetime, timezone |
|
from pathlib import Path |
|
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple |
|
|
|
import numpy as np |
|
|
|
try: |
|
import orjson # type: ignore |
|
except Exception: |
|
orjson = None |
|
|
|
try: |
|
import faiss # type: ignore |
|
except Exception: |
|
faiss = None |
|
|
|
try: |
|
from sentence_transformers import CrossEncoder, SentenceTransformer # type: ignore |
|
except Exception: |
|
CrossEncoder = None |
|
SentenceTransformer = None |
|
|
|
try: |
|
import mlx.core as mx # type: ignore |
|
from mlx_embeddings.utils import load as mlx_load # type: ignore |
|
except Exception: |
|
mx = None |
|
mlx_load = None |
|
|
|
try: |
|
import psycopg # type: ignore |
|
except Exception: |
|
psycopg = None |
|
|
|
|
|
DEFAULT_EMBED_MODEL = "mlx-community/Qwen3-Embedding-0.6B-mxfp8" |
|
DEFAULT_RERANK_MODEL = "mlx-community/Qwen3-Reranker-0.6B-mxfp8" |
|
DEFAULT_MAX_CHUNK_CHARS = 2600 |
|
DEFAULT_OVERLAP_CHARS = 260 |
|
DEFAULT_TOOL_OUTPUT_PREVIEW = 500 |
|
DEFAULT_TOOL_OUTPUT_CHUNK_CHARS = 3200 |
|
DEFAULT_BATCH_SIZE = 64 |
|
DEFAULT_OUT_DIR = "./codex_rag_store" |
|
|
|
RELEVANT_SQLITE_TABLE_TOKENS = ( |
|
"thread", |
|
"session", |
|
"turn", |
|
"message", |
|
"history", |
|
"item", |
|
"tag", |
|
"meta", |
|
"rollout", |
|
) |
|
|
|
SKIP_RESPONSE_ROLES = {"developer", "system"} |
|
SKIP_MESSAGE_PREFIXES = ( |
|
"<environment_context>", |
|
"<collaboration_mode>", |
|
"# Context from my IDE setup:", |
|
) |
|
|
|
|
|
@dataclass
class Chunk:
    """One indexable unit of text plus the provenance needed to cite it."""

    chunk_id: str  # stable id derived from source path + base label + index
    session_id: Optional[str]  # Codex session id, when one could be recovered
    source_kind: str  # e.g. "rollout_jsonl", "state_sqlite", "history_jsonl"
    source_path: str  # path of the originating file
    chunk_type: str  # "session_overview" | "turn" | "tool_output" | "sqlite_row" | "jsonl_group"
    cwd: Optional[str]  # working directory recorded for the turn/session, if any
    created_at: Optional[str]  # timestamp string, when available
    turn_index: Optional[int]  # turn ordinal within the session, if applicable
    chunk_index: int  # position among sub-chunks split from the same base text
    metadata: Dict[str, Any]  # structured extras, stored as JSON in the DB
    text: str  # the content that gets embedded and searched
|
|
|
|
|
class Store:
    """Local chunk store: SQLite metadata + FTS5 lexical mirror + FAISS file.

    One row per chunk lives in ``chunks``; the same text is mirrored into the
    ``chunk_fts`` virtual table for keyword search, and dense vectors are kept
    in a sibling FAISS index file whose entries are keyed by ``vector_id``.
    """

    def __init__(self, out_dir: Path):
        self.out_dir = out_dir
        self.db_path = out_dir / "chunks.sqlite"
        # The FAISS index file lives next to the DB; rows reference it via vector_id.
        self.index_path = out_dir / "dense.index"
        out_dir.mkdir(parents=True, exist_ok=True)
        self.conn = sqlite3.connect(self.db_path)
        self.conn.row_factory = sqlite3.Row
        # WAL + NORMAL: faster bulk ingest and concurrent reads at a small
        # durability cost — acceptable for a rebuildable cache.
        self.conn.execute("PRAGMA journal_mode=WAL;")
        self.conn.execute("PRAGMA synchronous=NORMAL;")
        self._init_schema()

    def _init_schema(self) -> None:
        """Create tables and indexes if absent (idempotent)."""
        self.conn.executescript(
            """
            CREATE TABLE IF NOT EXISTS chunks (
                vector_id INTEGER PRIMARY KEY,
                chunk_id TEXT NOT NULL UNIQUE,
                session_id TEXT,
                source_kind TEXT NOT NULL,
                source_path TEXT NOT NULL,
                chunk_type TEXT NOT NULL,
                cwd TEXT,
                created_at TEXT,
                turn_index INTEGER,
                chunk_index INTEGER NOT NULL,
                metadata_json TEXT NOT NULL,
                text TEXT NOT NULL,
                content_hash TEXT NOT NULL UNIQUE
            );

            CREATE INDEX IF NOT EXISTS idx_chunks_source_path ON chunks(source_path);
            CREATE INDEX IF NOT EXISTS idx_chunks_session_id ON chunks(session_id);
            CREATE INDEX IF NOT EXISTS idx_chunks_chunk_type ON chunks(chunk_type);

            CREATE TABLE IF NOT EXISTS sources (
                source_path TEXT PRIMARY KEY,
                source_kind TEXT NOT NULL,
                file_size INTEGER NOT NULL,
                mtime_ns INTEGER NOT NULL,
                ingested_at TEXT NOT NULL
            );

            CREATE TABLE IF NOT EXISTS settings (
                key TEXT PRIMARY KEY,
                value TEXT NOT NULL
            );

            CREATE VIRTUAL TABLE IF NOT EXISTS chunk_fts USING fts5(
                chunk_id UNINDEXED,
                vector_id UNINDEXED,
                text,
                tokenize = 'unicode61 remove_diacritics 2 tokenchars ''_-./:'''
            );
            """
        )
        self.conn.commit()

    def close(self) -> None:
        """Flush pending writes and close the connection."""
        self.conn.commit()
        self.conn.close()

    def reset(self) -> None:
        """Wipe the entire store directory and re-create an empty store."""
        self.close()
        if self.out_dir.exists():
            shutil.rmtree(self.out_dir)
        self.out_dir.mkdir(parents=True, exist_ok=True)
        # Re-run __init__ to rebuild the connection and schema in place,
        # so existing references to this Store object stay usable.
        self.__init__(self.out_dir)

    def get_setting(self, key: str) -> Optional[str]:
        """Return the stored value for *key*, or None if unset."""
        row = self.conn.execute("SELECT value FROM settings WHERE key = ?", (key,)).fetchone()
        return None if row is None else row[0]

    def set_setting(self, key: str, value: str) -> None:
        """Upsert a key/value pair in the settings table."""
        self.conn.execute(
            "INSERT INTO settings(key, value) VALUES(?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            (key, value),
        )
        self.conn.commit()

    def source_is_unchanged(self, path: Path, kind: str) -> bool:
        """True when *path* matches the recorded size/mtime/kind (skip re-ingest)."""
        stat = path.stat()
        row = self.conn.execute(
            "SELECT file_size, mtime_ns, source_kind FROM sources WHERE source_path = ?",
            (str(path),),
        ).fetchone()
        return bool(
            row
            and int(row["file_size"]) == int(stat.st_size)
            and int(row["mtime_ns"]) == int(stat.st_mtime_ns)
            and row["source_kind"] == kind
        )

    def mark_source(self, path: Path, kind: str) -> None:
        """Record (or refresh) *path*'s size/mtime so future runs can skip it."""
        stat = path.stat()
        self.conn.execute(
            """
            INSERT INTO sources(source_path, source_kind, file_size, mtime_ns, ingested_at)
            VALUES(?, ?, ?, ?, ?)
            ON CONFLICT(source_path) DO UPDATE SET
                source_kind = excluded.source_kind,
                file_size = excluded.file_size,
                mtime_ns = excluded.mtime_ns,
                ingested_at = excluded.ingested_at
            """,
            (
                str(path),
                kind,
                int(stat.st_size),
                int(stat.st_mtime_ns),
                utc_now_iso(),
            ),
        )
        self.conn.commit()

    def remove_source_chunks(self, source_path: str, index: Any | None) -> int:
        """Delete every chunk for *source_path* from SQLite, FTS, and FAISS.

        *index* is the in-memory FAISS index (or None); vectors are only
        removed when FAISS is available and the index is non-empty.
        Returns the number of chunks removed.
        """
        rows = self.conn.execute(
            "SELECT vector_id FROM chunks WHERE source_path = ? ORDER BY vector_id",
            (source_path,),
        ).fetchall()
        if not rows:
            return 0
        ids = np.array([int(r["vector_id"]) for r in rows], dtype=np.int64)
        if index is not None and faiss is not None and index.ntotal > 0:
            index.remove_ids(ids)
        placeholders = ",".join(["?"] * len(ids))
        self.conn.execute(f"DELETE FROM chunk_fts WHERE vector_id IN ({placeholders})", tuple(int(i) for i in ids.tolist()))
        self.conn.execute("DELETE FROM chunks WHERE source_path = ?", (source_path,))
        self.conn.commit()
        return len(ids)

    def insert_chunks(self, chunks: Sequence[Chunk]) -> List[int]:
        """Insert *chunks*, mirror their text into FTS, and return new vector_ids.

        The UNIQUE constraints on chunk_id and content_hash act as a dedupe
        guard: re-inserting an identical chunk raises sqlite3.IntegrityError.
        """
        vector_ids: List[int] = []
        for chunk in chunks:
            content_hash = sha256_hex(f"{chunk.source_path}\n{chunk.chunk_id}\n{chunk.text}")
            cur = self.conn.execute(
                """
                INSERT INTO chunks(
                    chunk_id,
                    session_id,
                    source_kind,
                    source_path,
                    chunk_type,
                    cwd,
                    created_at,
                    turn_index,
                    chunk_index,
                    metadata_json,
                    text,
                    content_hash
                ) VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    chunk.chunk_id,
                    chunk.session_id,
                    chunk.source_kind,
                    chunk.source_path,
                    chunk.chunk_type,
                    chunk.cwd,
                    chunk.created_at,
                    chunk.turn_index,
                    chunk.chunk_index,
                    dumps_json(chunk.metadata),
                    chunk.text,
                    content_hash,
                ),
            )
            # AUTOINCREMENT-style rowid doubles as the FAISS vector id.
            vector_id = int(cur.lastrowid)
            self.conn.execute(
                "INSERT INTO chunk_fts(chunk_id, vector_id, text) VALUES(?, ?, ?)",
                (chunk.chunk_id, vector_id, chunk.text),
            )
            vector_ids.append(vector_id)
        self.conn.commit()
        return vector_ids

    def fetch_rows_by_ids(self, vector_ids: Sequence[int]) -> List[sqlite3.Row]:
        """Return chunk rows for *vector_ids*, preserving input order.

        Ids with no matching row are silently skipped.
        """
        if not vector_ids:
            return []
        placeholders = ",".join(["?"] * len(vector_ids))
        rows = self.conn.execute(
            f"SELECT * FROM chunks WHERE vector_id IN ({placeholders})",
            tuple(int(v) for v in vector_ids),
        ).fetchall()
        row_map = {int(r["vector_id"]): r for r in rows}
        return [row_map[int(v)] for v in vector_ids if int(v) in row_map]

    def dense_index_exists(self) -> bool:
        """True when a FAISS index file has been written for this store."""
        return self.index_path.exists()
|
|
|
|
|
class PgVectorMirror:
    """Optional Postgres/pgvector mirror of the chunk store.

    Maintains a ``codex_chunks`` table carrying the same metadata as the
    SQLite store plus the embedding itself, so ANN search can run
    server-side via pgvector's ``<=>`` distance operator.
    """

    def __init__(self, dsn: str, dim: int):
        if psycopg is None:
            raise RuntimeError("psycopg is not installed. Install psycopg[binary] to use pgvector.")
        self.dsn = dsn
        self.dim = dim  # embedding dimensionality, baked into the column type
        self.conn = psycopg.connect(dsn)
        self.conn.autocommit = True
        self._init_schema()

    def close(self) -> None:
        """Close the Postgres connection."""
        self.conn.close()

    def _init_schema(self) -> None:
        """Create the vector extension, table, and index if absent (idempotent)."""
        with self.conn.cursor() as cur:
            cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
            # self.dim is an int supplied by us, so the f-string is not an
            # injection risk here.
            cur.execute(
                f"""
                CREATE TABLE IF NOT EXISTS codex_chunks (
                    chunk_id TEXT PRIMARY KEY,
                    vector_id BIGINT UNIQUE NOT NULL,
                    session_id TEXT,
                    source_kind TEXT NOT NULL,
                    source_path TEXT NOT NULL,
                    chunk_type TEXT NOT NULL,
                    cwd TEXT,
                    created_at TEXT,
                    turn_index INTEGER,
                    chunk_index INTEGER NOT NULL,
                    metadata_json JSONB NOT NULL,
                    text TEXT NOT NULL,
                    embedding vector({self.dim}) NOT NULL
                )
                """
            )
            cur.execute(
                "CREATE INDEX IF NOT EXISTS codex_chunks_source_path_idx ON codex_chunks(source_path)"
            )

    def remove_source(self, source_path: str) -> None:
        """Delete all mirrored chunks that came from *source_path*."""
        with self.conn.cursor() as cur:
            cur.execute("DELETE FROM codex_chunks WHERE source_path = %s", (source_path,))

    def upsert_batch(self, vector_ids: Sequence[int], chunks: Sequence[Chunk], embeddings: np.ndarray) -> None:
        """Upsert one row per (vector_id, chunk, embedding) triple.

        Rows are keyed by chunk_id; an existing row is fully overwritten.
        Note: executes one statement per row — simple, but not the fastest
        path for very large batches.
        """
        with self.conn.cursor() as cur:
            for vector_id, chunk, embedding in zip(vector_ids, chunks, embeddings):
                cur.execute(
                    """
                    INSERT INTO codex_chunks(
                        chunk_id, vector_id, session_id, source_kind, source_path, chunk_type,
                        cwd, created_at, turn_index, chunk_index, metadata_json, text, embedding
                    )
                    VALUES(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s::jsonb, %s, %s::vector)
                    ON CONFLICT(chunk_id) DO UPDATE SET
                        vector_id = excluded.vector_id,
                        session_id = excluded.session_id,
                        source_kind = excluded.source_kind,
                        source_path = excluded.source_path,
                        chunk_type = excluded.chunk_type,
                        cwd = excluded.cwd,
                        created_at = excluded.created_at,
                        turn_index = excluded.turn_index,
                        chunk_index = excluded.chunk_index,
                        metadata_json = excluded.metadata_json,
                        text = excluded.text,
                        embedding = excluded.embedding
                    """,
                    (
                        chunk.chunk_id,
                        int(vector_id),
                        chunk.session_id,
                        chunk.source_kind,
                        chunk.source_path,
                        chunk.chunk_type,
                        chunk.cwd,
                        chunk.created_at,
                        chunk.turn_index,
                        int(chunk.chunk_index),
                        dumps_json(chunk.metadata),
                        chunk.text,
                        vector_to_pg_literal(embedding),
                    ),
                )

    def search(self, query_vector: np.ndarray, limit: int) -> List[Tuple[int, float]]:
        """Return (vector_id, score) for the *limit* nearest chunks.

        Expects the query in row 0 of *query_vector* (i.e. a (1, dim)
        matrix). Score is ``1 - (embedding <=> query)`` — cosine similarity
        given pgvector's ``<=>`` cosine-distance operator.
        """
        with self.conn.cursor() as cur:
            cur.execute(
                """
                SELECT vector_id, 1 - (embedding <=> %s::vector) AS score
                FROM codex_chunks
                ORDER BY embedding <=> %s::vector
                LIMIT %s
                """,
                (vector_to_pg_literal(query_vector[0]), vector_to_pg_literal(query_vector[0]), int(limit)),
            )
            return [(int(row[0]), float(row[1])) for row in cur.fetchall()]
|
|
|
|
|
def utc_now_iso() -> str:
    """Return the current UTC time as a second-precision ISO-8601 string."""
    now = datetime.now(timezone.utc)
    return now.replace(microsecond=0).isoformat()
|
|
|
|
|
def dumps_json(obj: Any) -> str:
    """Serialize *obj* to JSON with sorted keys, preferring orjson when available."""
    if orjson is None:
        return json.dumps(obj, sort_keys=True, ensure_ascii=False)
    return orjson.dumps(obj, option=orjson.OPT_SORT_KEYS).decode("utf-8")
|
|
|
|
|
def loads_json(text: str) -> Any:
    """Parse a JSON string, preferring orjson when available."""
    if orjson is None:
        return json.loads(text)
    return orjson.loads(text)
|
|
|
|
|
def sha256_hex(text: str) -> str:
    """Hex SHA-256 digest of *text*'s UTF-8 bytes (unencodable chars dropped)."""
    encoded = text.encode("utf-8", errors="ignore")
    return hashlib.sha256(encoded).hexdigest()
|
|
|
|
|
def eprint(message: str) -> None:
    """Write a diagnostic line to stderr, keeping stdout clean for output."""
    sys.stderr.write(message + "\n")
|
|
|
|
|
def normalize_whitespace(text: str) -> str:
    """Canonicalize newlines, strip trailing spaces, collapse blank-line runs."""
    unified = text.replace("\r\n", "\n").replace("\r", "\n")
    trimmed = [line.rstrip() for line in unified.split("\n")]
    collapsed = re.sub(r"\n{3,}", "\n\n", "\n".join(trimmed))
    return collapsed.strip()
|
|
|
|
|
def compact_text(text: str, limit: Optional[int] = None) -> str:
    """Normalize whitespace and, when *limit* is given, truncate with a marker."""
    normalized = normalize_whitespace(text)
    if limit is None or len(normalized) <= limit:
        return normalized
    return normalized[:limit].rstrip() + "\n\n[truncated]"
|
|
|
|
|
def flatten_content(content: Any) -> str:
    """Collapse a Codex message ``content`` value into one normalized string.

    Lists of content parts are rendered item by item; text parts keep their
    text, images become placeholders, and anything unrecognized is JSON-dumped.
    """
    if content is None:
        return ""

    def render(item: Any) -> str:
        # Render a single content part to text ("" means: contribute nothing).
        if isinstance(item, str):
            return item
        if not isinstance(item, dict):
            return str(item)
        kind = item.get("type")
        if kind in {"input_text", "output_text", "text"}:
            value = item.get("text")
            return value if isinstance(value, str) else ""
        if kind == "input_image":
            return "[image]"
        if kind == "local_image":
            return f"[local_image: {item.get('path', '')}]"
        return dumps_json(item)

    if isinstance(content, list):
        pieces = [render(item) for item in content]
    elif isinstance(content, dict):
        pieces = [dumps_json(content)]
    else:
        pieces = [str(content)]
    return compact_text("\n".join(piece for piece in pieces if piece))
|
|
|
|
|
def maybe_parse_json_string(text: str) -> str:
    """Pretty-print *text* when it parses as JSON; otherwise return it stripped."""
    stripped = text.strip()
    if not stripped:
        return ""
    if not stripped.startswith(("[", "{")):
        return stripped
    try:
        parsed = loads_json(stripped)
    except Exception:
        # Not actually JSON despite the leading bracket — keep it verbatim.
        return stripped
    return json.dumps(parsed, indent=2, ensure_ascii=False, sort_keys=True)
|
|
|
|
|
def split_text_smart(text: str, max_chars: int, overlap_chars: int) -> List[str]:
    """Split *text* into overlapping chunks of at most *max_chars* characters.

    Prefers to cut at the latest paragraph break, newline, or sentence
    boundary in the window, provided it lies past 55% of *max_chars*, so
    chunks avoid mid-sentence cuts.

    Args:
        text: Raw text; whitespace is normalized before splitting.
        max_chars: Maximum characters per chunk.
        overlap_chars: Trailing characters repeated at the start of the
            next chunk for context continuity.

    Returns:
        Non-empty chunk strings in document order.
    """
    text = normalize_whitespace(text)
    if len(text) <= max_chars:
        return [text]

    chunks: List[str] = []
    start = 0
    text_len = len(text)
    while start < text_len:
        end = min(start + max_chars, text_len)
        if end < text_len:
            window = text[start:end]
            # Latest paragraph / line / sentence boundary within the window.
            split_at = max(window.rfind("\n\n"), window.rfind("\n"), window.rfind(". "))
            if split_at > int(max_chars * 0.55):
                # Keep the ". " with the left chunk so sentences stay intact.
                end = start + split_at + (2 if window[split_at:split_at + 2] == ". " else 0)
        piece = text[start:end].strip()
        if piece:
            chunks.append(piece)
        if end >= text_len:
            break
        # Guarantee forward progress: with overlap_chars >= ~55% of max_chars,
        # the previous `start = max(0, end - overlap_chars)` could repeat the
        # same window forever (infinite loop).
        start = max(start + 1, end - overlap_chars)
    return chunks
|
|
|
|
|
def sanitize_for_chunk_id(text: str) -> str:
    """Reduce *text* to a short id-safe slug; fall back to "chunk" when empty."""
    slug = re.sub(r"[^A-Za-z0-9_.-]+", "-", text)
    slug = slug.strip("-")[:80]
    return slug if slug else "chunk"
|
|
|
|
|
def vector_to_pg_literal(vector: np.ndarray) -> str:
    """Format an embedding as a pgvector literal, e.g. ``[0.10000000,0.20000000]``.

    Accepts any array-like (ndarray of any shape, list, or tuple); values
    are flattened to 1-D first. The previous version only reshaped inputs
    that were already ndarrays and raised AttributeError on plain sequences
    (no ``.ndim``); converting via ``np.asarray`` up front generalizes it
    without changing output for ndarray inputs.
    """
    flat = np.asarray(vector, dtype=np.float64).reshape(-1)
    return "[" + ",".join(f"{value:.8f}" for value in flat.tolist()) + "]"
|
|
|
|
|
def make_chunk_id(source_path: str, base: str, index: int) -> str:
    """Build a stable, human-readable chunk id: ``<slug>-<index>-<digest>``.

    The digest binds the id to (source_path, base, index), so identical base
    labels from different files never collide.
    """
    digest = sha256_hex(f"{source_path}\n{base}\n{index}")[:16]
    slug = sanitize_for_chunk_id(base)
    return f"{slug}-{index:04d}-{digest}"
|
|
|
|
|
def discover_codex_home(explicit: Optional[str]) -> Path:
    """Locate the Codex home directory.

    Candidates in priority order: the explicit argument, $CODEX_HOME,
    ~/.codex, %USERPROFILE%/.codex, and /app/.codex. The first existing
    candidate wins; when none exist, the highest-priority candidate is
    returned anyway.
    """
    ordered: List[Path] = []
    if explicit:
        ordered.append(Path(explicit).expanduser())
    env_value = os.getenv("CODEX_HOME")
    if env_value:
        ordered.append(Path(env_value).expanduser())
    ordered.append(Path.home() / ".codex")
    profile_dir = os.getenv("USERPROFILE")
    if profile_dir:
        ordered.append(Path(profile_dir) / ".codex")
    ordered.append(Path("/app/.codex"))

    visited: set[str] = set()
    for candidate in ordered:
        key = str(candidate)
        if key in visited:
            continue
        visited.add(key)
        if candidate.exists():
            return candidate
    return ordered[0]
|
|
|
|
|
def discover_sources(
    codex_home: Path,
    include_archived: bool,
    include_sqlite: bool,
    include_aux_jsonl: bool,
) -> List[Tuple[str, Path]]:
    """Enumerate ingestable files under *codex_home* as (kind, path) pairs.

    Order: live rollouts, archived rollouts (optional), auxiliary JSONL
    files (optional), then state*.sqlite databases (optional).
    """
    found: List[Tuple[str, Path]] = []

    def register(kind: str, candidate: Path) -> None:
        # Silently skip anything missing or not a regular file.
        if candidate.exists() and candidate.is_file():
            found.append((kind, candidate))

    rollout_dirs = [("rollout_jsonl", codex_home / "sessions")]
    if include_archived:
        rollout_dirs.append(("archived_rollout_jsonl", codex_home / "archived_sessions"))
    for kind, directory in rollout_dirs:
        if directory.exists():
            for candidate in sorted(directory.rglob("rollout-*.jsonl")):
                register(kind, candidate)

    if include_aux_jsonl:
        register("history_jsonl", codex_home / "history.jsonl")
        register("session_tags_jsonl", codex_home / "session-tags.jsonl")

    if include_sqlite:
        for candidate in sorted(codex_home.glob("state*.sqlite")):
            register("state_sqlite", candidate)

    return found
|
|
|
|
|
def extract_user_message(payload: Dict[str, Any]) -> str:
    """Return the normalized user message text, or "" when absent/non-string."""
    message = payload.get("message")
    return compact_text(message) if isinstance(message, str) else ""
|
|
|
|
|
def should_skip_response_text(role: str, text: str) -> bool:
    """True for roles or boilerplate prefixes that should not be indexed."""
    if role in SKIP_RESPONSE_ROLES:
        return True
    # startswith accepts a tuple of prefixes directly.
    return text.lstrip().startswith(SKIP_MESSAGE_PREFIXES)
|
|
|
|
|
def summarize_turn_context(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Reduce a turn_context payload to the handful of fields worth keeping."""
    sandbox_mode = None
    network_access = None
    sandbox = payload.get("sandbox_policy") or {}
    if isinstance(sandbox, dict):
        # Accept either a "type" or a "mode" key for the sandbox mode.
        sandbox_mode = sandbox.get("type") or sandbox.get("mode")
        network_access = sandbox.get("network_access")
    summary: Dict[str, Any] = {
        "cwd": payload.get("cwd"),
        "approval_policy": payload.get("approval_policy"),
        "sandbox_mode": sandbox_mode,
        "network_access": network_access,
        "model": payload.get("model"),
        "summary": payload.get("summary"),
        "effort": payload.get("effort"),
    }
    return summary
|
|
|
|
|
def render_overview_chunk(
    source_path: str,
    source_kind: str,
    session_meta: Dict[str, Any],
    seen_session_ids: Sequence[str],
    turn_count: int,
    tool_count: int,
    first_user: str,
) -> Chunk:
    """Build the single per-session "overview" chunk.

    The text is a flat key/value listing (ids, paths, cwd, counters, and the
    first user message) so lexical search can match on any of them.
    """
    session_id = session_meta.get("id")
    text = "\n".join(
        [
            "Session overview",
            f"session_id: {session_id}",
            f"source_kind: {source_kind}",
            f"source_path: {source_path}",
            f"created_at: {session_meta.get('timestamp')}",
            f"cwd: {session_meta.get('cwd')}",
            f"originator: {session_meta.get('originator')}",
            f"cli_version: {session_meta.get('cli_version')}",
            f"source: {session_meta.get('source')}",
            f"model_provider: {session_meta.get('model_provider')}",
            f"seen_session_ids: {', '.join(seen_session_ids) if seen_session_ids else 'none'}",
            f"turn_count: {turn_count}",
            f"tool_call_count: {tool_count}",
            "first_user_message:",
            first_user or "[none]",
        ]
    ).strip()
    metadata = {
        "kind": "session_overview",
        "session_meta": session_meta,
        "seen_session_ids": list(seen_session_ids),
        "turn_count": turn_count,
        "tool_call_count": tool_count,
    }
    return Chunk(
        chunk_id=make_chunk_id(source_path, f"{session_id or 'session'}-overview", 0),
        session_id=session_id,
        source_kind=source_kind,
        source_path=source_path,
        chunk_type="session_overview",
        cwd=session_meta.get("cwd"),
        created_at=session_meta.get("timestamp"),
        turn_index=None,  # overview spans the whole session, not one turn
        chunk_index=0,
        metadata=metadata,
        text=compact_text(text),
    )
|
|
|
|
|
def render_turn_preview(turn: Dict[str, Any], tool_output_preview_chars: int) -> str:
    """Render one accumulated turn as labeled text sections.

    Sections (each emitted only when present): user, assistant, tool_calls
    (arguments pretty-printed, capped at 1200 chars), tool_outputs_preview
    (each output capped at *tool_output_preview_chars*), and at most 8 raw
    events.
    """
    lines: List[str] = []
    if turn.get("user_messages"):
        lines.append("user:")
        lines.append(compact_text("\n\n".join(turn["user_messages"])))
    if turn.get("assistant_messages"):
        lines.append("")
        lines.append("assistant:")
        lines.append(compact_text("\n\n".join(turn["assistant_messages"])))
    if turn.get("tool_calls"):
        lines.append("")
        lines.append("tool_calls:")
        for call in turn["tool_calls"]:
            lines.append(f"- {call.get('name') or 'tool'}")
            args_text = maybe_parse_json_string(str(call.get("arguments") or ""))
            if args_text:
                lines.append(textwrap.indent(compact_text(args_text, 1200), "  "))
    if turn.get("tool_outputs"):
        lines.append("")
        lines.append("tool_outputs_preview:")
        for output in turn["tool_outputs"]:
            call_name = output.get("tool_name") or "tool"
            call_id = output.get("call_id") or ""
            label = f"- {call_name} ({call_id})" if call_id else f"- {call_name}"
            lines.append(label)
            output_text = compact_text(str(output.get("output") or ""), tool_output_preview_chars)
            if output_text:
                lines.append(textwrap.indent(output_text, "  "))
    if turn.get("events"):
        lines.append("")
        lines.append("events:")
        # Cap events to keep previews from being dominated by noise.
        for event in turn["events"][:8]:
            lines.append(f"- {event}")
    return "\n".join(lines).strip()
|
|
|
|
|
def render_turn_context(turn: Dict[str, Any]) -> str:
    """Render a turn's context dict as a short bullet list; "" when empty."""
    ctx = turn.get("context") or {}
    if not ctx:
        return ""
    keys = ("cwd", "approval_policy", "sandbox_mode", "network_access", "model", "summary", "effort")
    bullets = [
        f"- {key}: {ctx.get(key)}"
        for key in keys
        if ctx.get(key) not in (None, "", [])
    ]
    return "\n".join(["turn_context:", *bullets])
|
|
|
|
|
def parse_rollout_file(
    path: Path,
    source_kind: str,
    max_chunk_chars: int,
    overlap_chars: int,
    tool_output_preview_chars: int,
    tool_output_chunk_chars: int,
) -> Iterator[Chunk]:
    """Parse a Codex rollout JSONL file into indexable chunks.

    Yields, in order: one session-overview chunk (when any metadata or
    turns were found), then per-turn chunks (split to *max_chunk_chars*),
    then standalone tool-output chunks (split to *tool_output_chunk_chars*).
    Unparseable or non-dict lines are skipped; unrecognized record types
    are preserved as truncated "events" on the current turn.
    """
    session_meta: Dict[str, Any] = {}  # first session_meta payload wins
    seen_session_ids: List[str] = []   # every distinct session id observed
    turns: List[Dict[str, Any]] = []
    current_turn: Optional[Dict[str, Any]] = None
    turn_counter = 0

    def start_turn() -> Dict[str, Any]:
        # Fresh accumulator for one user -> assistant exchange.
        return {
            "turn_index": turn_counter,
            "user_messages": [],
            "assistant_messages": [],
            "tool_calls": [],
            "tool_outputs": [],
            "events": [],
            "context": {},
            "start_ts": None,
        }

    def finalize_turn() -> None:
        # Commit the in-progress turn only if it captured anything; empty
        # turns are discarded without consuming a turn index.
        nonlocal current_turn, turn_counter
        if current_turn is None:
            return
        has_content = any(
            current_turn.get(key)
            for key in ("user_messages", "assistant_messages", "tool_calls", "tool_outputs", "events")
        )
        if has_content:
            turns.append(current_turn)
            turn_counter += 1
        current_turn = None

    def ensure_turn() -> Dict[str, Any]:
        # Lazily open a turn so records that arrive before any user message
        # still have somewhere to go.
        nonlocal current_turn
        if current_turn is None:
            current_turn = start_turn()
        return current_turn

    with path.open("r", encoding="utf-8", errors="replace") as handle:
        for raw_line in handle:
            line = raw_line.strip()
            if not line:
                continue
            try:
                obj = loads_json(line)
            except Exception:
                # Tolerate corrupt or truncated lines.
                continue
            if not isinstance(obj, dict):
                continue

            item_type = obj.get("type")
            payload = obj.get("payload") or {}
            timestamp = obj.get("timestamp")

            if item_type == "session_meta" and isinstance(payload, dict):
                if not session_meta:
                    session_meta = dict(payload)
                session_id = payload.get("id")
                if isinstance(session_id, str) and session_id not in seen_session_ids:
                    seen_session_ids.append(session_id)
                continue

            if item_type == "turn_context" and isinstance(payload, dict):
                turn = ensure_turn()
                turn["context"] = summarize_turn_context(payload)
                if turn["start_ts"] is None:
                    turn["start_ts"] = timestamp
                continue

            if item_type == "event_msg" and isinstance(payload, dict):
                event_type = payload.get("type")
                if event_type == "user_message":
                    # A user message is the turn boundary: close the previous
                    # turn and open a new one.
                    finalize_turn()
                    current_turn = start_turn()
                    current_turn["start_ts"] = timestamp
                    msg = extract_user_message(payload)
                    if msg:
                        current_turn["user_messages"].append(msg)
                    continue
                if event_type == "agent_message":
                    turn = ensure_turn()
                    if turn["start_ts"] is None:
                        turn["start_ts"] = timestamp
                    message = payload.get("message")
                    if isinstance(message, str):
                        turn["assistant_messages"].append(compact_text(message))
                    continue
                # Other event kinds are kept verbatim (truncated) for context.
                turn = ensure_turn()
                if turn["start_ts"] is None:
                    turn["start_ts"] = timestamp
                turn["events"].append(compact_text(dumps_json(payload), 1200))
                continue

            if item_type == "response_item" and isinstance(payload, dict):
                response_type = payload.get("type")
                if response_type == "message":
                    role = str(payload.get("role") or "")
                    text = flatten_content(payload.get("content"))
                    if not text:
                        continue
                    if should_skip_response_text(role, text):
                        continue
                    turn = ensure_turn()
                    if turn["start_ts"] is None:
                        turn["start_ts"] = timestamp
                    if role == "user":
                        # Avoid duplicating the text when the event_msg stream
                        # already recorded the same user message.
                        if not turn["user_messages"] or turn["user_messages"][-1] != text:
                            turn["user_messages"].append(text)
                    elif role == "assistant":
                        turn["assistant_messages"].append(text)
                    else:
                        turn["events"].append(f"response_role={role}: {compact_text(text, 1200)}")
                    continue
                if response_type == "function_call":
                    turn = ensure_turn()
                    if turn["start_ts"] is None:
                        turn["start_ts"] = timestamp
                    turn["tool_calls"].append(
                        {
                            "name": payload.get("name"),
                            "call_id": payload.get("call_id"),
                            "arguments": payload.get("arguments"),
                        }
                    )
                    continue
                if response_type == "function_call_output":
                    turn = ensure_turn()
                    if turn["start_ts"] is None:
                        turn["start_ts"] = timestamp
                    call_id = payload.get("call_id")
                    tool_name = None
                    # Resolve the tool name from the most recent matching call.
                    for call in reversed(turn["tool_calls"]):
                        if call.get("call_id") == call_id:
                            tool_name = call.get("name")
                            break
                    turn["tool_outputs"].append(
                        {
                            "call_id": call_id,
                            "tool_name": tool_name,
                            "output": payload.get("output"),
                        }
                    )
                    continue
                turn = ensure_turn()
                turn["events"].append(compact_text(dumps_json(payload), 1200))
                continue

            # Unknown top-level record types are preserved as events too.
            turn = ensure_turn()
            turn["events"].append(compact_text(dumps_json(obj), 1200))

    finalize_turn()

    session_id = session_meta.get("id") if session_meta else None
    total_tool_calls = sum(len(turn.get("tool_calls", [])) for turn in turns)
    first_user = next((msg for turn in turns for msg in turn.get("user_messages", []) if msg), "")
    if session_meta or turns:
        yield render_overview_chunk(
            source_path=str(path),
            source_kind=source_kind,
            session_meta=session_meta,
            seen_session_ids=seen_session_ids,
            turn_count=len(turns),
            tool_count=total_tool_calls,
            first_user=first_user,
        )

    for turn in turns:
        turn_index = int(turn.get("turn_index") or 0)
        context_block = render_turn_context(turn)
        preview_text = render_turn_preview(turn, tool_output_preview_chars)
        # Fall back to session-level timestamp/cwd when the turn lacks them.
        created_at = turn.get("start_ts") or session_meta.get("timestamp")
        cwd = (turn.get("context") or {}).get("cwd") or session_meta.get("cwd")
        blocks = [
            f"session_id: {session_id}",
            f"source_path: {path}",
            f"turn_index: {turn_index}",
            f"created_at: {created_at}",
            f"cwd: {cwd}",
        ]
        if context_block:
            blocks.append(context_block)
        if preview_text:
            blocks.append(preview_text)
        base_text = "\n\n".join(blocks).strip()
        subchunks = split_text_smart(base_text, max_chunk_chars, overlap_chars)
        for idx, subchunk in enumerate(subchunks):
            metadata = {
                "kind": "turn",
                "source_kind": source_kind,
                "session_id": session_id,
                "turn_index": turn_index,
                "context": turn.get("context") or {},
                "tool_calls": turn.get("tool_calls") or [],
                "tool_output_count": len(turn.get("tool_outputs") or []),
            }
            yield Chunk(
                chunk_id=make_chunk_id(str(path), f"{session_id or 'session'}-turn-{turn_index}", idx),
                session_id=session_id,
                source_kind=source_kind,
                source_path=str(path),
                chunk_type="turn",
                cwd=cwd,
                created_at=created_at,
                turn_index=turn_index,
                chunk_index=idx,
                metadata=metadata,
                text=subchunk,
            )

        # Each tool output also gets its own chunks (with larger budget),
        # so long logs don't drown the turn text.
        for tool_output_idx, tool_output in enumerate(turn.get("tool_outputs") or []):
            output_text = compact_text(str(tool_output.get("output") or ""))
            if not output_text:
                continue
            tool_label = tool_output.get("tool_name") or "tool"
            call_id = tool_output.get("call_id") or ""
            user_prompt = compact_text("\n\n".join(turn.get("user_messages") or []), 1400)
            tool_chunk_base = "\n\n".join(
                [
                    f"session_id: {session_id}",
                    f"source_path: {path}",
                    f"turn_index: {turn_index}",
                    f"created_at: {created_at}",
                    f"cwd: {cwd}",
                    f"tool_name: {tool_label}",
                    f"call_id: {call_id}",
                    "related_user_message:\n" + user_prompt,
                    "tool_output:\n" + output_text,
                ]
            ).strip()
            tool_chunks = split_text_smart(tool_chunk_base, tool_output_chunk_chars, overlap_chars)
            for idx, subchunk in enumerate(tool_chunks):
                metadata = {
                    "kind": "tool_output",
                    "source_kind": source_kind,
                    "session_id": session_id,
                    "turn_index": turn_index,
                    "tool_name": tool_label,
                    "call_id": call_id,
                    "tool_output_index": tool_output_idx,
                }
                yield Chunk(
                    chunk_id=make_chunk_id(
                        str(path),
                        f"{session_id or 'session'}-turn-{turn_index}-tool-{tool_output_idx}",
                        idx,
                    ),
                    session_id=session_id,
                    source_kind=source_kind,
                    source_path=str(path),
                    chunk_type="tool_output",
                    cwd=cwd,
                    created_at=created_at,
                    turn_index=turn_index,
                    chunk_index=idx,
                    metadata=metadata,
                    text=subchunk,
                )
|
|
|
|
|
def is_relevant_sqlite_table(table_name: str) -> bool:
    """True when the table name contains any session/history-related token."""
    name = table_name.lower()
    for token in RELEVANT_SQLITE_TABLE_TOKENS:
        if token in name:
            return True
    return False
|
|
|
|
|
def stringify_sqlite_value(value: Any) -> Optional[str]:
    """Render a SQLite cell as normalized text; None for NULL or blob values."""
    if value is None or isinstance(value, (bytes, bytearray)):
        return None
    if isinstance(value, (int, float)):
        return str(value)
    text = value if isinstance(value, str) else str(value)
    return compact_text(text)
|
|
|
|
|
def parse_sqlite_file(path: Path, source_kind: str, max_chunk_chars: int, overlap_chars: int) -> Iterator[Chunk]:
    """Yield one chunk per row (split when large) from relevant tables of a state DB.

    The database is opened read-only (URI ``mode=ro``) so a live Codex
    process is never disturbed. Only tables whose names match the session/
    history token list are read; NULL and blob cells are skipped.

    Change from the original: the dead ``chunk_counter`` local (incremented
    but never read) has been removed.
    """
    conn = sqlite3.connect(f"file:{path}?mode=ro", uri=True)
    conn.row_factory = sqlite3.Row
    try:
        table_rows = conn.execute(
            "SELECT name FROM sqlite_master WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
        ).fetchall()
        for table_row in table_rows:
            table_name = str(table_row["name"])
            if not is_relevant_sqlite_table(table_name):
                continue
            # Table/column names come from sqlite_master/PRAGMA, quoted with
            # double quotes — identifiers cannot be bound as parameters.
            column_info = conn.execute(f'PRAGMA table_info("{table_name}")').fetchall()
            columns = [str(row[1]) for row in column_info]
            select_cols = ", ".join([f'"{c}"' for c in columns]) if columns else "*"
            query = f'SELECT rowid, {select_cols} FROM "{table_name}"'
            for db_row in conn.execute(query):
                row_dict: Dict[str, Any] = {"rowid": db_row[0]}
                for idx, col in enumerate(columns, start=1):
                    row_dict[col] = db_row[idx]
                rendered_pairs: List[str] = []
                session_id = None
                cwd = None
                created_at = None
                turn_index = None
                for key, value in row_dict.items():
                    if key == "rowid":
                        continue
                    rendered = stringify_sqlite_value(value)
                    if rendered in (None, ""):
                        continue
                    rendered_pairs.append(f"{key}: {rendered}")
                    lowered_key = key.lower()
                    # Opportunistically recover provenance from common column names.
                    if session_id is None and lowered_key in {"id", "thread_id", "session_id"}:
                        session_id = rendered
                    if cwd is None and lowered_key == "cwd":
                        cwd = rendered
                    if created_at is None and lowered_key in {"created_at", "timestamp", "updated_at"}:
                        created_at = rendered
                    if turn_index is None and lowered_key == "turn_index":
                        try:
                            turn_index = int(rendered)
                        except Exception:
                            turn_index = None
                if not rendered_pairs:
                    continue
                base_text = textwrap.dedent(
                    f"""
                    sqlite_source: {path}
                    table: {table_name}
                    rowid: {row_dict.get('rowid')}
                    """
                ).strip() + "\n" + "\n".join(rendered_pairs)
                for idx, subchunk in enumerate(split_text_smart(base_text, max_chunk_chars, overlap_chars)):
                    metadata = {
                        "kind": "sqlite_row",
                        "table": table_name,
                        "rowid": row_dict.get("rowid"),
                    }
                    yield Chunk(
                        chunk_id=make_chunk_id(str(path), f"sqlite-{table_name}-{row_dict.get('rowid')}", idx),
                        session_id=session_id,
                        source_kind=source_kind,
                        source_path=str(path),
                        chunk_type="sqlite_row",
                        cwd=cwd,
                        created_at=created_at,
                        turn_index=turn_index,
                        chunk_index=idx,
                        metadata=metadata,
                        text=subchunk,
                    )
    finally:
        conn.close()
|
|
|
|
|
def parse_generic_jsonl(path: Path, source_kind: str, max_chunk_chars: int, overlap_chars: int) -> Iterator[Chunk]:
    """Chunk an arbitrary JSONL file into grouped, pretty-printed records.

    Each non-empty line is parsed as JSON where possible (pretty-printed with
    sorted keys; malformed lines are kept verbatim), prefixed with its
    zero-based line index, and accumulated until the joined text reaches
    ``max_chunk_chars``. Each accumulated group is then split with
    ``split_text_smart`` and yielded as ``jsonl_group`` chunks.

    Fixes the chunk-id template, which previously embedded the literal string
    "(unknown)" instead of the file name; the duplicated yield blocks for the
    in-loop and trailing groups are also collapsed into one helper.
    """
    filename = path.name

    def _emit_group(joined: str, group_index: int) -> Iterator[Chunk]:
        # A single group may still exceed max_chunk_chars; split_text_smart
        # produces the final overlapping sub-chunks.
        for idx, subchunk in enumerate(split_text_smart(joined, max_chunk_chars, overlap_chars)):
            yield Chunk(
                chunk_id=make_chunk_id(str(path), f"{filename}-group-{group_index}", idx),
                session_id=None,
                source_kind=source_kind,
                source_path=str(path),
                chunk_type="jsonl_group",
                cwd=None,
                created_at=None,
                turn_index=None,
                chunk_index=idx,
                metadata={"kind": "jsonl_group", "filename": filename, "group_index": group_index},
                text=subchunk,
            )

    with path.open("r", encoding="utf-8", errors="replace") as handle:
        batch: List[str] = []
        group_index = 0
        for line_index, raw_line in enumerate(handle):
            line = raw_line.strip()
            if not line:
                continue
            try:
                obj = loads_json(line)
                rendered = json.dumps(obj, indent=2, ensure_ascii=False, sort_keys=True)
            except Exception:
                # Keep unparseable lines verbatim rather than dropping them.
                rendered = line
            batch.append(f"line_index: {line_index}\n{rendered}")
            joined = "\n\n".join(batch)
            if len(joined) >= max_chunk_chars:
                yield from _emit_group(joined, group_index)
                batch = []
                group_index += 1
        # Flush the final partial group, if any.
        if batch:
            yield from _emit_group("\n\n".join(batch), group_index)
|
|
|
|
|
def iter_chunks_from_source(
    kind: str,
    path: Path,
    max_chunk_chars: int,
    overlap_chars: int,
    tool_output_preview_chars: int,
    tool_output_chunk_chars: int,
) -> Iterator[Chunk]:
    """Dispatch *path* to the parser that matches its source *kind*."""
    rollout_kinds = {"rollout_jsonl", "archived_rollout_jsonl"}
    if kind in rollout_kinds:
        yield from parse_rollout_file(
            path,
            kind,
            max_chunk_chars,
            overlap_chars,
            tool_output_preview_chars,
            tool_output_chunk_chars,
        )
        return
    if kind == "state_sqlite":
        yield from parse_sqlite_file(path, kind, max_chunk_chars, overlap_chars)
        return
    # Anything else is treated as generic JSONL.
    yield from parse_generic_jsonl(path, kind, max_chunk_chars, overlap_chars)
|
|
|
|
|
def is_mlx_model_name(model_name: str) -> bool:
    """Return True when *model_name* should be served by the MLX backend.

    The previous ``startswith("mlx-community/")`` and ``endswith("-mlx")``
    checks were fully subsumed by the ``"mlx" in lowered`` substring test, so
    only the substring test is kept. A None or empty name is non-MLX.
    """
    return "mlx" in (model_name or "").lower()
|
|
|
|
|
class MLXEmbeddingModel:
    """SentenceTransformer-compatible wrapper around an MLX embedding model.

    Exposes ``encode`` and ``get_sentence_embedding_dimension`` so callers
    can treat the MLX and sentence-transformers backends interchangeably.
    """

    def __init__(self, model_name: str, max_length: int = 8192):
        if mlx_load is None or mx is None:
            raise RuntimeError(
                "MLX embedding support is unavailable. Install mlx and mlx-embeddings on Apple Silicon."
            )
        self.model_name = model_name
        self.max_length = max_length
        self.model, self.tokenizer = mlx_load(model_name)
        # Embedding width; discovered lazily from the first encode() result.
        self._dim: Optional[int] = None

    def encode(self, texts: Sequence[str], batch_size: int = 16, **_: Any) -> np.ndarray:
        """Embed *texts* and return a unit-normalised float32 matrix.

        Extra keyword arguments (show_progress_bar, etc.) are accepted and
        ignored for SentenceTransformer API compatibility.
        """
        step = max(1, int(batch_size))
        items = list(texts)
        pieces: List[np.ndarray] = []
        for offset in range(0, len(items), step):
            window = items[offset:offset + step]
            encoded = self.tokenizer.batch_encode_plus(
                window,
                return_tensors="mlx",
                padding=True,
                truncation=True,
                max_length=self.max_length,
            )
            model_out = self.model(
                encoded["input_ids"],
                attention_mask=encoded.get("attention_mask"),
            )
            vectors = np.asarray(model_out.text_embeds)
            # L2-normalise; guard against zero vectors to avoid division by zero.
            lengths = np.linalg.norm(vectors, axis=1, keepdims=True)
            lengths[lengths == 0] = 1.0
            pieces.append((vectors / lengths).astype("float32"))
        if not pieces:
            return np.zeros((0, self.get_sentence_embedding_dimension()), dtype="float32")
        stacked = np.vstack(pieces).astype("float32")
        if self._dim is None and stacked.size:
            self._dim = int(stacked.shape[1])
        return stacked

    def get_sentence_embedding_dimension(self) -> int:
        """Return the embedding width, probing the model once if unknown."""
        if self._dim is None:
            probe = self.encode(["dimension probe"], batch_size=1)
            self._dim = int(probe.shape[1])
        return int(self._dim)
|
|
|
|
|
class MLXQwenReranker:
    """Qwen3 reranker running on Apple-Silicon MLX.

    Scores (query, document) pairs by framing relevance as a yes/no question
    to the chat model and reading the probability mass on the "yes" token at
    the final position.
    """

    def __init__(self, model_name: str, max_length: int = 8192):
        if mlx_load is None:
            raise RuntimeError(
                "MLX reranker support is unavailable. Install mlx and mlx-embeddings on Apple Silicon."
            )
        self.model_name = model_name
        self.max_length = max_length
        self.model, self.tokenizer = mlx_load(model_name)
        # Chat-template scaffolding spliced around every formatted pair in predict().
        self.prefix = (
            "<|im_start|>system\n"
            ' Judge whether the Document meets the requirements based on the Query and the Instruct provided. '
            'Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
        )
        self.suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        # Pre-tokenised once so predict() can concatenate raw token-id lists.
        self.prefix_tokens = self.tokenizer.encode(self.prefix, add_special_tokens=False)
        self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)
        # Resolve the token ids the model uses to answer; only the first id of
        # each is used for the yes/no logit comparison.
        yes_ids = self.tokenizer("yes", add_special_tokens=False).input_ids
        no_ids = self.tokenizer("no", add_special_tokens=False).input_ids
        if not yes_ids or not no_ids:
            raise RuntimeError("Could not resolve yes/no token ids for the MLX reranker model.")
        self.token_true_id = int(yes_ids[0])
        self.token_false_id = int(no_ids[0])
        self.default_instruction = "Given a local coding-memory query, retrieve relevant Codex history chunks that answer the query"

    def _format_pair(self, query: str, doc: str, instruction: Optional[str] = None) -> str:
        # Body of the user turn; prefix/suffix tokens are added later in predict().
        chosen_instruction = instruction or self.default_instruction
        return f"<Instruct>: {chosen_instruction}\n<Query>: {query}\n<Document>: {doc}"

    def predict(self, pairs: Sequence[Tuple[str, str]], batch_size: int = 8, **_: Any) -> List[float]:
        """Return one relevance score in [0, 1] per (query, document) pair.

        Extra keyword arguments (e.g. show_progress_bar) are accepted and
        ignored so callers can treat this object like a CrossEncoder.
        """
        scores: List[float] = []
        for start in range(0, len(pairs), max(1, int(batch_size))):
            batch_pairs = pairs[start:start + max(1, int(batch_size))]
            formatted = [self._format_pair(query, doc) for query, doc in batch_pairs]
            # Tokenize without padding first so the chat prefix/suffix can be
            # spliced around each sequence; max_length reserves room for them.
            inputs = self.tokenizer(
                formatted,
                padding=False,
                truncation='longest_first',
                return_attention_mask=False,
                max_length=self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens),
            )
            for i, ele in enumerate(inputs['input_ids']):
                inputs['input_ids'][i] = self.prefix_tokens + ele + self.suffix_tokens
            padded = self.tokenizer.pad(inputs, padding=True, return_tensors="mlx", max_length=self.max_length)
            outputs = self.model(
                padded["input_ids"],
                attention_mask=padded.get("attention_mask"),
            )
            # Only the final position's logits matter: the model answers yes/no there.
            logits = np.asarray(outputs.logits)[:, -1, :]
            pair_logits = np.stack(
                [logits[:, self.token_false_id], logits[:, self.token_true_id]],
                axis=1,
            )
            # Numerically stable 2-way softmax; column 1 is the "yes" probability.
            pair_logits = pair_logits - pair_logits.max(axis=1, keepdims=True)
            probs = np.exp(pair_logits)
            probs = probs / probs.sum(axis=1, keepdims=True)
            scores.extend(float(x) for x in probs[:, 1].tolist())
        return scores
|
|
|
|
|
def require_embedding_stack(model_name: Optional[str] = None) -> None:
    """Raise RuntimeError when required retrieval libraries are missing.

    faiss is always required; an MLX model name additionally requires the
    MLX stack, while any other name requires sentence-transformers.
    """
    if faiss is None:
        raise RuntimeError("Missing dependency faiss-cpu. Install it before running ingest/query.")
    wants_mlx = bool(model_name) and is_mlx_model_name(model_name)
    if wants_mlx:
        if mlx_load is None or mx is None:
            raise RuntimeError(
                "MLX model support is unavailable. Install mlx and mlx-embeddings on Apple Silicon."
            )
    elif SentenceTransformer is None:
        raise RuntimeError(
            "SentenceTransformer is unavailable. Install sentence-transformers or use an MLX model on Apple Silicon."
        )
|
|
|
|
|
def load_embedding_model(model_name: str) -> Any:
    """Instantiate the embedding backend appropriate for *model_name*."""
    require_embedding_stack(model_name)
    backend = MLXEmbeddingModel if is_mlx_model_name(model_name) else SentenceTransformer
    return backend(model_name)
|
|
|
|
|
def load_reranker(model_name: str) -> Any:
    """Instantiate a reranker for *model_name*, preferring the MLX backend."""
    if is_mlx_model_name(model_name):
        return MLXQwenReranker(model_name)
    if CrossEncoder is not None:
        return CrossEncoder(model_name)
    raise RuntimeError("CrossEncoder is unavailable. Install sentence-transformers to enable reranking.")
|
|
|
|
|
def create_faiss_index(dim: int) -> Any:
    """Build an empty inner-product FAISS index keyed by explicit int64 ids."""
    require_embedding_stack()
    return faiss.IndexIDMap2(faiss.IndexFlatIP(dim))
|
|
|
|
|
def load_or_create_faiss_index(store: Store, dim: int) -> Any:
    """Read the persisted FAISS index if one exists, else create a fresh one."""
    require_embedding_stack()
    if not store.index_path.exists():
        return create_faiss_index(dim)
    return faiss.read_index(str(store.index_path))
|
|
|
|
|
def save_faiss_index(index: Any, path: Path) -> None:
    """Persist *index* to *path* on disk."""
    require_embedding_stack()
    destination = str(path)
    faiss.write_index(index, destination)
|
|
|
|
|
def encode_texts(model: Any, texts: Sequence[str], batch_size: int) -> np.ndarray:
    """Embed *texts* with *model* and return the result as a float32 array.

    Asks the backend for normalised numpy embeddings; MLX backends accept
    and ignore the SentenceTransformer-specific keyword arguments.
    """
    raw = model.encode(
        list(texts),
        batch_size=batch_size,
        show_progress_bar=False,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    return np.asarray(raw, dtype="float32")
|
|
|
|
|
def parse_fts_query(text: str) -> Optional[str]:
    """Turn free text into an OR-of-quoted-tokens FTS5 MATCH expression.

    Tokens are lowercased runs of word/path characters at least two chars
    long; the first 16 distinct tokens (in order of appearance) are kept.
    Returns None when no token is found.
    """
    tokens = re.findall(r"[A-Za-z0-9_./:-]{2,}", text.lower())
    # dict.fromkeys de-duplicates while preserving first-seen order.
    deduped = list(dict.fromkeys(tokens))[:16]
    if not deduped:
        return None
    return " OR ".join(f'"{tok}"' for tok in deduped)
|
|
|
|
|
def reciprocal_rank_fusion(rankings: Sequence[Sequence[int]], k: int = 60) -> List[int]:
    """Fuse several ranked id lists with reciprocal rank fusion.

    Each id accumulates 1 / (k + rank) per list it appears in (ranks are
    1-based); ids are returned best-first.
    """
    fused: Dict[int, float] = {}
    for ranking in rankings:
        for position, vector_id in enumerate(ranking, start=1):
            key = int(vector_id)
            fused[key] = fused.get(key, 0.0) + 1.0 / (k + position)
    return sorted(fused, key=fused.__getitem__, reverse=True)
|
|
|
|
|
def query_sparse(store: Store, query: str, limit: int) -> List[int]:
    """BM25 keyword search over the FTS5 mirror; best-scoring ids first.

    Returns an empty list when the query yields no usable FTS tokens.
    """
    match_expr = parse_fts_query(query)
    if not match_expr:
        return []
    sql = """
        SELECT vector_id, bm25(chunk_fts) AS score
        FROM chunk_fts
        WHERE chunk_fts MATCH ?
        ORDER BY score
        LIMIT ?
        """
    hits = store.conn.execute(sql, (match_expr, int(limit))).fetchall()
    return [int(hit["vector_id"]) for hit in hits]
|
|
|
|
|
def query_dense_faiss(index: Any, query_vector: np.ndarray, limit: int) -> List[int]:
    """Dense top-*limit* FAISS search; drops the -1 ids FAISS uses as padding."""
    if index.ntotal == 0:
        return []
    _scores, id_matrix = index.search(query_vector, int(limit))
    return [int(vid) for vid in id_matrix[0].tolist() if int(vid) >= 0]
|
|
|
|
|
def rerank_rows(query: str, rows: Sequence[sqlite3.Row], reranker: Any, limit: int) -> List[sqlite3.Row]:
    """Re-score *rows* against *query* with a cross-encoder; keep the best *limit*.

    Document text is truncated to 3800 characters before scoring.
    """
    if not rows:
        return []
    candidate_pairs = [(query, str(item["text"])[:3800]) for item in rows]
    relevance = reranker.predict(candidate_pairs, batch_size=16, show_progress_bar=False)
    ranked = sorted(zip(rows, relevance), key=lambda pair: float(pair[1]), reverse=True)
    return [item for item, _score in ranked[:limit]]
|
|
|
|
|
def ingest_command(args: argparse.Namespace) -> int:
    """Build or incrementally update the local hybrid RAG store.

    Discovers Codex sources, removes stale chunks for files that changed,
    embeds new chunks in batches, and appends them to the FAISS index
    (optionally mirroring vectors to pgvector). Returns a process exit
    code: 0 on success, 1 when no sources were found.
    """
    codex_home = discover_codex_home(args.codex_home)
    out_dir = Path(args.out).expanduser()
    store = Store(out_dir)
    if args.rebuild:
        # Wipe the on-disk store and reopen it fresh.
        eprint("[ingest] rebuilding local store")
        store.reset()
        store = Store(out_dir)

    # Refuse to mix embedding models within one store unless rebuilding.
    embed_model_name = args.embedding_model
    existing_model_name = store.get_setting("embedding_model")
    if existing_model_name and existing_model_name != embed_model_name and not args.rebuild:
        raise RuntimeError(
            f"Existing store uses embedding model {existing_model_name!r}. Re-run with --rebuild or use the same model."
        )

    model = load_embedding_model(embed_model_name)
    dim = int(model.get_sentence_embedding_dimension())
    # Same guard for the embedding dimension.
    existing_dim = store.get_setting("embedding_dim")
    if existing_dim and int(existing_dim) != dim and not args.rebuild:
        raise RuntimeError(
            f"Existing store uses embedding dimension {existing_dim}. Re-run with --rebuild or use the same model."
        )

    # Persist store-level settings; created_at is only written once.
    store.set_setting("embedding_model", embed_model_name)
    store.set_setting("embedding_dim", str(dim))
    store.set_setting("created_at", store.get_setting("created_at") or utc_now_iso())

    index = load_or_create_faiss_index(store, dim)
    # Optional secondary backend: mirror vectors into Postgres/pgvector.
    pgvector_mirror = PgVectorMirror(args.pgvector_dsn, dim) if args.pgvector_dsn else None

    discovered = discover_sources(
        codex_home=codex_home,
        include_archived=args.include_archived,
        include_sqlite=args.include_sqlite,
        include_aux_jsonl=args.include_aux_jsonl,
    )
    if not discovered:
        eprint(f"[ingest] no sources found under {codex_home}")
        return 1

    total_new_chunks = 0
    total_removed_chunks = 0
    total_skipped_sources = 0
    total_processed_sources = 0

    try:
        for kind, path in discovered:
            # Incremental mode: skip sources whose fingerprint is unchanged.
            if not args.force and store.source_is_unchanged(path, kind):
                total_skipped_sources += 1
                eprint(f"[skip] {path}")
                continue

            # NOTE(review): once the skip above did not fire this condition is
            # always true, so the second source_is_unchanged call is redundant.
            if args.force or not store.source_is_unchanged(path, kind):
                # Drop the source's previous chunks before re-ingesting.
                removed = store.remove_source_chunks(str(path), index)
                total_removed_chunks += removed
                if pgvector_mirror is not None:
                    pgvector_mirror.remove_source(str(path))

            eprint(f"[parse] {path}")
            chunks_iter = iter_chunks_from_source(
                kind=kind,
                path=path,
                max_chunk_chars=args.max_chunk_chars,
                overlap_chars=args.overlap_chars,
                tool_output_preview_chars=args.tool_output_preview_chars,
                tool_output_chunk_chars=args.tool_output_chunk_chars,
            )

            # Chunks are buffered and embedded batch_size at a time.
            buffer_chunks: List[Chunk] = []
            buffer_texts: List[str] = []
            source_chunk_count = 0

            def flush() -> None:
                # Embed and persist the buffered chunks, then clear the buffers.
                nonlocal total_new_chunks, source_chunk_count
                if not buffer_chunks:
                    return
                embeddings = encode_texts(model, buffer_texts, batch_size=args.batch_size)
                vector_ids = store.insert_chunks(buffer_chunks)
                # IndexIDMap2 requires explicit int64 ids alongside the vectors.
                index.add_with_ids(embeddings, np.asarray(vector_ids, dtype=np.int64))
                if pgvector_mirror is not None:
                    pgvector_mirror.upsert_batch(vector_ids, buffer_chunks, embeddings)
                total_new_chunks += len(buffer_chunks)
                source_chunk_count += len(buffer_chunks)
                buffer_chunks.clear()
                buffer_texts.clear()

            for chunk in chunks_iter:
                buffer_chunks.append(chunk)
                buffer_texts.append(chunk.text)
                if len(buffer_chunks) >= args.batch_size:
                    flush()
            # Flush the final partial batch for this source.
            flush()
            store.mark_source(path, kind)
            total_processed_sources += 1
            eprint(f"[done] {path} -> {source_chunk_count} chunks")

        save_faiss_index(index, store.index_path)
        # Machine-readable ingest summary on stdout (logs go to stderr).
        print(
            dumps_json(
                {
                    "codex_home": str(codex_home),
                    "out_dir": str(out_dir),
                    "embedding_model": embed_model_name,
                    "embedding_dim": dim,
                    "sources_found": len(discovered),
                    "sources_processed": total_processed_sources,
                    "sources_skipped": total_skipped_sources,
                    "chunks_removed": total_removed_chunks,
                    "chunks_added": total_new_chunks,
                    "index_path": str(store.index_path),
                    "sqlite_path": str(store.db_path),
                    "pgvector_enabled": bool(args.pgvector_dsn),
                }
            )
        )
        return 0
    finally:
        # Always release external resources, even on error.
        if pgvector_mirror is not None:
            pgvector_mirror.close()
        store.close()
|
|
|
|
|
def query_command(args: argparse.Namespace) -> int:
    """Run a hybrid (dense + sparse) search against the local store.

    Dense retrieval uses FAISS or pgvector, sparse retrieval uses SQLite
    FTS5; results are fused with reciprocal rank fusion, optionally reranked
    by a cross-encoder, and printed in one of three formats (plain text,
    --json, or --context-only). Returns a process exit code.
    """
    out_dir = Path(args.out).expanduser()
    store = Store(out_dir)
    try:
        if not store.db_path.exists() or not store.conn.execute("SELECT COUNT(*) FROM chunks").fetchone()[0]:
            eprint("[query] store is empty. Run ingest first.")
            return 1

        # Embed the query with the same model the store was built with.
        embed_model_name = store.get_setting("embedding_model") or args.embedding_model
        model = load_embedding_model(embed_model_name)
        query_vector = encode_texts(model, [args.query], batch_size=1)

        dense_backend = args.dense_backend
        dense_ids: List[int] = []
        pgvector_mirror = None
        if dense_backend == "pgvector":
            if not args.pgvector_dsn:
                raise RuntimeError("--dense-backend pgvector requires --pgvector-dsn")
            dim = int(store.get_setting("embedding_dim") or query_vector.shape[1])
            pgvector_mirror = PgVectorMirror(args.pgvector_dsn, dim)
            dense_ids = [vector_id for vector_id, _score in pgvector_mirror.search(query_vector, args.dense_k)]
        else:
            if not store.index_path.exists():
                raise RuntimeError("FAISS index is missing. Run ingest first.")
            index = faiss.read_index(str(store.index_path))
            dense_ids = query_dense_faiss(index, query_vector, args.dense_k)

        # Fuse dense and sparse rankings, keeping enough rows for reranking.
        sparse_ids = query_sparse(store, args.query, args.sparse_k)
        fused_ids = reciprocal_rank_fusion([dense_ids, sparse_ids])[: max(args.rerank_top_n, args.top_k)]
        rows = store.fetch_rows_by_ids(fused_ids)

        final_rows: List[sqlite3.Row]
        if args.no_rerank:
            final_rows = rows[: args.top_k]
        else:
            reranker = load_reranker(args.rerank_model)
            final_rows = rerank_rows(args.query, rows, reranker, args.top_k)

        if args.context_only:
            # Raw chunk text with headers, suited for pasting into a prompt.
            sections: List[str] = []
            for rank, row in enumerate(final_rows, start=1):
                sections.append(
                    textwrap.dedent(
                        f"""
                        ===== RETRIEVED CHUNK {rank} =====
                        chunk_id: {row['chunk_id']}
                        session_id: {row['session_id']}
                        source_kind: {row['source_kind']}
                        source_path: {row['source_path']}
                        chunk_type: {row['chunk_type']}
                        turn_index: {row['turn_index']}
                        cwd: {row['cwd']}
                        created_at: {row['created_at']}
                        {row['text']}
                        """
                    ).strip()
                )
            print("\n\n".join(sections))
        elif args.json:
            # Full structured rows for programmatic consumers.
            payload = []
            for rank, row in enumerate(final_rows, start=1):
                payload.append(
                    {
                        "rank": rank,
                        "vector_id": int(row["vector_id"]),
                        "chunk_id": row["chunk_id"],
                        "session_id": row["session_id"],
                        "source_kind": row["source_kind"],
                        "source_path": row["source_path"],
                        "chunk_type": row["chunk_type"],
                        "turn_index": row["turn_index"],
                        "cwd": row["cwd"],
                        "created_at": row["created_at"],
                        "metadata": loads_json(row["metadata_json"]),
                        "text": row["text"],
                    }
                )
            print(dumps_json(payload))
        else:
            # Human-readable summary with a trimmed excerpt per hit.
            for rank, row in enumerate(final_rows, start=1):
                excerpt = compact_text(str(row["text"]), args.preview_chars)
                print(textwrap.dedent(
                    f"""
                    ===== Result {rank} =====
                    chunk_id: {row['chunk_id']}
                    session_id: {row['session_id']}
                    source_kind: {row['source_kind']}
                    source_path: {row['source_path']}
                    chunk_type: {row['chunk_type']}
                    turn_index: {row['turn_index']}
                    cwd: {row['cwd']}
                    created_at: {row['created_at']}
                    excerpt:
                    {excerpt}
                    """
                ).strip())
                print()
        return 0
    finally:
        # pgvector_mirror is only bound after the early empty-store return,
        # so guard the name lookup before closing it.
        if 'pgvector_mirror' in locals() and pgvector_mirror is not None:
            pgvector_mirror.close()
        store.close()
|
|
|
|
|
def inspect_command(args: argparse.Namespace) -> int:
    """Print the discoverable Codex sources (kind, path, size) as JSON."""
    codex_home = discover_codex_home(args.codex_home)
    discovered = discover_sources(
        codex_home=codex_home,
        include_archived=args.include_archived,
        include_sqlite=args.include_sqlite,
        include_aux_jsonl=args.include_aux_jsonl,
    )
    entries: List[Dict[str, Any]] = []
    for source_kind, source_path in discovered:
        try:
            file_size = int(source_path.stat().st_size)
        except FileNotFoundError:
            # Source vanished between discovery and stat; report it as empty.
            file_size = 0
        entries.append({"kind": source_kind, "path": str(source_path), "size": file_size})
    report = {
        "codex_home": str(codex_home),
        "source_count": len(entries),
        "sources": entries,
    }
    print(dumps_json(report))
    return 0
|
|
|
|
|
def build_parser() -> argparse.ArgumentParser:
    """Construct the CLI parser with inspect/ingest/query subcommands.

    Each subcommand stores its handler in ``args.func`` via set_defaults so
    main() can dispatch without inspecting the command name.
    """
    parser = argparse.ArgumentParser(
        description="Turn local Codex session history into a hybrid RAG index with dense retrieval, FTS, and reranking."
    )
    subparsers = parser.add_subparsers(dest="command", required=True)

    # inspect: list discoverable sources without touching the store.
    inspect_parser = subparsers.add_parser("inspect", help="Discover Codex sources on disk")
    inspect_parser.add_argument("--codex-home", default=None)
    inspect_parser.add_argument("--include-archived", action="store_true")
    inspect_parser.add_argument("--include-sqlite", action="store_true")
    inspect_parser.add_argument("--include-aux-jsonl", action="store_true")
    inspect_parser.set_defaults(func=inspect_command)

    # ingest: parse history, embed chunks, and build/update the index.
    ingest_parser = subparsers.add_parser("ingest", help="Parse Codex history and build/update the index")
    ingest_parser.add_argument("--codex-home", default=None)
    ingest_parser.add_argument("--out", default=DEFAULT_OUT_DIR)
    ingest_parser.add_argument("--embedding-model", default=DEFAULT_EMBED_MODEL)
    ingest_parser.add_argument("--batch-size", type=int, default=DEFAULT_BATCH_SIZE)
    ingest_parser.add_argument("--max-chunk-chars", type=int, default=DEFAULT_MAX_CHUNK_CHARS)
    ingest_parser.add_argument("--overlap-chars", type=int, default=DEFAULT_OVERLAP_CHARS)
    ingest_parser.add_argument("--tool-output-preview-chars", type=int, default=DEFAULT_TOOL_OUTPUT_PREVIEW)
    ingest_parser.add_argument("--tool-output-chunk-chars", type=int, default=DEFAULT_TOOL_OUTPUT_CHUNK_CHARS)
    ingest_parser.add_argument("--include-archived", action="store_true")
    ingest_parser.add_argument("--include-sqlite", action="store_true")
    ingest_parser.add_argument("--include-aux-jsonl", action="store_true")
    ingest_parser.add_argument("--pgvector-dsn", default=None)
    ingest_parser.add_argument("--rebuild", action="store_true")
    ingest_parser.add_argument("--force", action="store_true")
    ingest_parser.set_defaults(func=ingest_command)

    # query: hybrid search over the built store.
    query_parser = subparsers.add_parser("query", help="Search the local RAG index")
    query_parser.add_argument("query")
    query_parser.add_argument("--out", default=DEFAULT_OUT_DIR)
    query_parser.add_argument("--embedding-model", default=DEFAULT_EMBED_MODEL)
    query_parser.add_argument("--rerank-model", default=DEFAULT_RERANK_MODEL)
    query_parser.add_argument("--top-k", type=int, default=8)
    query_parser.add_argument("--dense-k", type=int, default=30)
    query_parser.add_argument("--sparse-k", type=int, default=30)
    query_parser.add_argument("--rerank-top-n", type=int, default=24)
    query_parser.add_argument("--dense-backend", choices=["faiss", "pgvector"], default="faiss")
    query_parser.add_argument("--pgvector-dsn", default=None)
    query_parser.add_argument("--no-rerank", action="store_true")
    query_parser.add_argument("--json", action="store_true")
    query_parser.add_argument("--context-only", action="store_true")
    query_parser.add_argument("--preview-chars", type=int, default=1200)
    query_parser.set_defaults(func=query_command)

    return parser
|
|
|
|
|
def main(argv: Optional[Sequence[str]] = None) -> int:
    """CLI entry point: parse arguments and run the selected subcommand."""
    namespace = build_parser().parse_args(argv)
    return int(namespace.func(namespace))
|
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s exit code to the shell.
    sys.exit(main())