Skip to content

Instantly share code, notes, and snippets.

@jkeifer
Created September 11, 2025 19:53
Show Gist options
  • Select an option

  • Save jkeifer/e5eb6ed5b5b7f2de539353b2c65b6a7b to your computer and use it in GitHub Desktop.

Select an option

Save jkeifer/e5eb6ed5b5b7f2de539353b2c65b6a7b to your computer and use it in GitHub Desktop.
pgTap with pytest
from __future__ import annotations
import os
import shutil
import tempfile
from collections.abc import Iterator
from contextlib import AsyncExitStack, suppress
from pathlib import Path
import asyncpg
import pytest
from dbami.db import DB
from dbami.util import random_name, syncrun
# Name of the throwaway database the pgTap tests run in. It is created in
# pytest_configure and dropped in pytest_unconfigure. random_name presumably
# adds a random suffix to "myapp_test" so concurrent runs do not collide —
# TODO confirm against dbami.util.
PGTAP_DB_NAME = random_name("myapp_test")
@pytest.fixture(scope="session")
def pg_dump(pytestconfig) -> Iterator[str]:
    """Yield the path of a ``pg_dump`` executable for the test session.

    The command named by the ``MYAPP_PG_DUMP`` environment variable (default
    ``pg_dump``) is used when ``shutil.which`` can resolve it. Otherwise a
    temporary ``/bin/sh`` shim is written that runs ``pg_dump`` inside the
    ``postgres`` service via ``docker compose exec``, from the pytest root
    directory. The shim lives in a temporary directory that is removed when
    the session-scoped fixture is torn down.
    """
    import shlex

    env_pgd = os.getenv("MYAPP_PG_DUMP", "pg_dump")
    if shutil.which(env_pgd):
        yield env_pgd
        return

    # pg_dump is not on PATH: fall back to a throwaway wrapper script.
    with tempfile.TemporaryDirectory() as tmpdir:
        shim = Path(tmpdir) / "pg_dump"
        # shlex.quote keeps the `cd` correct (and injection-free) when the
        # root path contains quotes, `$` or backticks; `|| exit 1` stops the
        # script instead of running compose from the wrong directory.
        # -T disables the pseudo-TTY that `docker compose exec` allocates by
        # default; a TTY rewrites "\n" as "\r\n" in the dump and can fail
        # when no terminal is attached (e.g. in CI).
        shim.write_text(
            "#!/bin/sh\n"
            f"cd {shlex.quote(str(pytestconfig.rootpath))} || exit 1\n"
            'docker compose exec -T postgres pg_dump "$@"\n',
        )
        shim.chmod(0o755)
        yield str(shim)
@pytest.fixture(scope="session")
def pgtap_db() -> str:
    """Return the name of the session's pgTap database (``PGTAP_DB_NAME``)."""
    return PGTAP_DB_NAME
def pytest_configure(config) -> str:
    """Create and prepare the pgTap database for this session.

    Creates ``PGTAP_DB_NAME`` and connects to it. On that connection it
    creates a ``tap`` schema, installs the ``pgtap`` extension into that
    schema, and then calls ``DB.load_schema``. Returns the database name.
    """
    # NOTE: this really should be a custom DB subclass with the project and
    # schema-version table overridden, e.g.:
    #
    #     from pathlib import Path
    #
    #     from dbami.db import DB
    #
    #     class MyappDatabase(DB):
    #         def __init__(self):
    #             super().__init__(
    #                 project=Path(__file__).parent,
    #                 schema_version_table="myapp.schema_version",
    #             )
    database = DB()

    async def install_pgtap_and_schema() -> None:
        async with database.get_db_connection(database=PGTAP_DB_NAME) as conn:
            for ddl in ("CREATE SCHEMA tap;", "CREATE EXTENSION pgtap SCHEMA tap;"):
                await conn.execute(ddl)
            await database.load_schema(conn=conn)

    syncrun(database.create_database(PGTAP_DB_NAME))
    syncrun(install_pgtap_and_schema())
    return PGTAP_DB_NAME
def pytest_unconfigure(config) -> None:
    """Drop the session's pgTap database.

    ``asyncpg.InvalidCatalogNameError`` from the drop (e.g. the database was
    never created) is ignored; any other error propagates.
    """
    # NOTE: this should also use the project-specific DB subclass.
    database = DB()
    try:
        syncrun(database.drop_database(PGTAP_DB_NAME))
    except asyncpg.InvalidCatalogNameError:
        pass
def pytest_collect_file(parent, file_path) -> SqlFile | None:
    """Collect files whose name starts with ``test`` and ends in ``.sql``.

    Matching paths become an ``SqlFile`` collector. For every other path the
    hook returns ``None``, i.e. it contributes no collector.
    """
    is_sql = file_path.suffix == ".sql"
    if not (is_sql and file_path.name.startswith("test")):
        return None
    return SqlFile.from_parent(parent, path=file_path)
def _check_tap_output(output, statement: str) -> None:
    """Raise if one value returned by a test statement reports a failure.

    ``output`` is the first-column value of one result row: TAP text, or
    ``None`` (nothing to check). Behaviour per parsed TAP line:

    * ``not ok`` result -> ``SqlError``. The ``#`` diagnostic lines directly
      following it are appended to the description, so details emitted
      alongside the failure are not lost.
    * any other diagnostic line -> ``PgTapDiagnosticError``.
    * plan lines and passing results are accepted.
    * any other line type -> ``TypeError``.
    """
    if output is None:
        return

    import tap.line
    from tap.parser import Parser

    parsed_lines = list(Parser().parse_text(output))
    for index, parsed in enumerate(parsed_lines):
        if isinstance(parsed, tap.line.Result):
            if not parsed.ok:
                details = [parsed.description]
                for follower in parsed_lines[index + 1 :]:
                    if not isinstance(follower, tap.line.Diagnostic):
                        break
                    details.append(follower.text)
                raise SqlError("\n".join(filter(None, details)), statement)
        elif isinstance(parsed, tap.line.Diagnostic):
            raise PgTapDiagnosticError(parsed.text)
        elif not isinstance(parsed, tap.line.Plan):
            raise TypeError(
                f"Unhandled tap type '{parsed.category}': {output}",
            )


async def run_test(sql: str) -> None:
    """Run each statement of ``sql`` against the pgTap database and check it.

    ``sql`` is split into statements with ``sqlparse.split``. They all run in a
    single transaction that is always rolled back, so a test leaves nothing
    behind. Every row a statement returns is checked, not just the first.
    Multi-row results (e.g. ``SELECT ok(...) FROM some_table``) therefore
    cannot hide failures on later rows.

    Raises:
        SqlError: a pgTap assertion reported ``not ok``.
        PgTapDiagnosticError: pgTap emitted a ``#`` diagnostic line.
        TypeError: a statement returned a TAP line of an unhandled type.
    """
    import sqlparse

    # NOTE: should use the project-specific DB subclass.
    async with DB().get_db_connection(database=PGTAP_DB_NAME) as conn:
        transaction = conn.transaction()
        await transaction.start()
        try:
            for statement in sqlparse.split(sql):
                for record in await conn.fetch(statement):
                    _check_tap_output(record[0], statement)
        finally:
            await transaction.rollback()
class SqlFile(pytest.File):
    """Collector for one ``test*.sql`` file; it yields exactly one test item."""

    def collect(self) -> Iterator[SqlItem]:
        # The whole file is a single test, named after the file's stem.
        test_name = self.path.stem
        yield SqlItem.from_parent(self, name=test_name)
class SqlItem(pytest.Item):
    """One pgTap test: the SQL text of a collected file, executed by ``run_test``."""

    def runtest(self) -> None:
        # Decode explicitly as UTF-8 so the SQL does not depend on the
        # platform's default (locale) encoding.
        syncrun(run_test(self.path.read_text(encoding="utf-8")))

    def repr_failure(self, excinfo, style=None):
        """Called when self.runtest() raises an exception.

        pgTap failures (``PgTapError``) are shown as their formatted message,
        without a Python traceback. Anything else falls back to pytest's
        default representation, honouring the requested ``style``.
        """
        if isinstance(excinfo.value, PgTapError):
            # Return the message itself rather than an empty string, so it
            # appears in every report format (e.g. JUnit XML failure text).
            return str(excinfo.value)
        return super().repr_failure(excinfo, style=style)

    def reportinfo(self) -> tuple[Path, None, str]:
        return self.path, None, self.name
class PgTapError(Exception):
    """Base class for failures reported while running pgTap SQL tests."""
class SqlError(PgTapError):
    """A pgTap assertion reported ``not ok``.

    The message contains the assertion's description, followed by the SQL
    statement that produced it, indented by four spaces.
    """

    def __init__(self, description, statement):
        import textwrap

        indented_statement = textwrap.indent(statement, "    ")
        message = (
            "pgTap test failure!\n"
            f"{description}\n"
            "from test statement\n"
            f"{indented_statement}\n"
        )
        super().__init__(message)
class PgTapDiagnosticError(PgTapError):
    """pgTap emitted a ``#`` diagnostic line; the message carries its text."""

    def __init__(self, description):
        message = "pgTap error!\n{}\n".format(description)
        super().__init__(message)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment