Created September 11, 2025 19:53
-
-
Save jkeifer/e5eb6ed5b5b7f2de539353b2c65b6a7b to your computer and use it in GitHub Desktop.
pgTAP with pytest
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from __future__ import annotations | |
| import os | |
| import shutil | |
| import tempfile | |
| from collections.abc import Iterator | |
| from contextlib import AsyncExitStack, suppress | |
| from pathlib import Path | |
| import asyncpg | |
| import pytest | |
| from dbami.db import DB | |
| from dbami.util import random_name, syncrun | |
# Name of the throwaway database the pgTap tests run against. It is computed
# once, when this module is imported, and shared by the configure/unconfigure
# hooks, the `pgtap_db` fixture and `run_test`. random_name presumably derives
# a unique name from the "myapp_test" prefix — TODO confirm against dbami.
PGTAP_DB_NAME = random_name("myapp_test")
@pytest.fixture(scope="session")
def pg_dump(pytestconfig) -> Iterator[str]:
    """Yield the name or path of a ``pg_dump`` executable for the session.

    The command named by the ``MYAPP_PG_DUMP`` environment variable (default
    ``pg_dump``) is yielded as-is when ``shutil.which`` can resolve it.
    Otherwise a temporary ``/bin/sh`` shim is written that changes into the
    pytest root directory and runs ``pg_dump`` inside the ``postgres``
    service via ``docker compose exec``. The shim lives in a temporary
    directory that is removed when the fixture is torn down.
    """
    import shlex

    env_pgd = os.getenv("MYAPP_PG_DUMP", "pg_dump")
    if shutil.which(env_pgd):
        yield env_pgd
        return

    # pg_dump is not on PATH: fall back to a temporary wrapper script.
    with tempfile.TemporaryDirectory() as d:
        pgd = Path(d).joinpath("pg_dump")
        # shlex.quote keeps a root path containing spaces, quotes or `$`
        # from breaking (or being expanded by) the shell script, and
        # `|| exit 1` stops the shim if that directory cannot be entered.
        # `-T` disables pseudo-TTY allocation so pg_dump's stdout reaches the
        # caller unmodified (a TTY may translate line endings).
        pgd.write_text(
            "#!/bin/sh\n"
            f"cd {shlex.quote(str(pytestconfig.rootpath))} || exit 1\n"
            'exec docker compose exec -T postgres pg_dump "$@"\n',
        )
        pgd.chmod(0o755)
        yield str(pgd)
@pytest.fixture(scope="session")
def pgtap_db() -> str:
    """Session fixture giving the name of the database pgTap tests run in."""
    db_name = PGTAP_DB_NAME
    return db_name
def pytest_configure(config) -> str:
    """Create the pgTap test database and load the project schema into it.

    The database named ``PGTAP_DB_NAME`` is created first. Then, over a
    single connection to it, a ``tap`` schema is created, the ``pgtap``
    extension is installed into that schema, and ``DB.load_schema`` is run.

    NOTE: this really should use a project-specific ``DB`` subclass with the
    project path and schema version table overridden, e.g.::

        from pathlib import Path

        from dbami.db import DB


        class MyappDatabase(DB):
            def __init__(self):
                super().__init__(
                    project=Path(__file__).parent,
                    schema_version_table="myapp.schema_version",
                )

    Returns:
        The name of the database that was created.
    """
    database = DB()

    async def install_pgtap_and_schema():
        async with database.get_db_connection(database=PGTAP_DB_NAME) as conn:
            for ddl in ("CREATE SCHEMA tap;", "CREATE EXTENSION pgtap SCHEMA tap;"):
                await conn.execute(ddl)
            await database.load_schema(conn=conn)

    syncrun(database.create_database(PGTAP_DB_NAME))
    syncrun(install_pgtap_and_schema())
    return PGTAP_DB_NAME
def pytest_unconfigure(config) -> None:
    """Drop the pgTap test database when pytest shuts down.

    An ``asyncpg.InvalidCatalogNameError`` raised while dropping is ignored;
    any other error propagates.
    """
    # As in pytest_configure, a project-specific DB subclass belongs here.
    database = DB()
    try:
        syncrun(database.drop_database(PGTAP_DB_NAME))
    except asyncpg.InvalidCatalogNameError:
        pass
def pytest_collect_file(parent, file_path) -> SqlFile | None:
    """Collect ``.sql`` files whose name starts with ``test`` as pgTap tests.

    Every other file is left to other collectors (``None`` is returned).
    """
    is_sql_test = file_path.suffix == ".sql" and file_path.name.startswith("test")
    if not is_sql_test:
        return None
    return SqlFile.from_parent(parent, path=file_path)
async def run_test(sql: str) -> None:
    """Run the pgTap statements in *sql* and raise on any reported failure.

    *sql* is split into statements with ``sqlparse.split``, and every
    statement runs inside one transaction that is always rolled back. For
    each statement, the text in the first column of every result row is
    joined and parsed as TAP output. Rows whose first column is not text
    (e.g. ``INSERT ... RETURNING id`` in test setup) carry no TAP and are
    skipped.

    Raises:
        SqlError: a TAP result reported ``not ok`` (TODO results excepted).
        PgTapDiagnosticError: a diagnostic line appeared that does not
            belong to a failing or TODO result.
        TypeError: the output contained a TAP line of an unhandled kind.
    """
    import sqlparse

    async with AsyncExitStack() as stack:
        conn = await stack.enter_async_context(
            # custom subclass
            DB().get_db_connection(database=PGTAP_DB_NAME),
        )
        transaction = conn.transaction()
        await transaction.start()
        try:
            for statement in sqlparse.split(sql):
                # fetch() rather than fetchval(): set-returning pgTap
                # functions (runtests(), check_test(), finish()) emit one
                # TAP line per row, and every row must be checked.
                rows = await conn.fetch(statement)
                output = "\n".join(
                    row[0] for row in rows if len(row) and isinstance(row[0], str)
                )
                _raise_for_tap_failures(output, statement)
        finally:
            await transaction.rollback()


def _raise_for_tap_failures(output: str, statement: str) -> None:
    """Raise if the TAP text *output* produced by *statement* reports a problem.

    - A ``not ok`` result raises SqlError. The diagnostic lines directly
      following it (pgTap's have/want details) are appended to its
      description.
    - A ``not ok ... # TODO`` result is an expected failure. Neither it nor
      the diagnostic lines directly following it are errors.
    - Any other diagnostic line raises PgTapDiagnosticError.
    - Plan lines are ignored; any other TAP line kind raises TypeError.
    """
    import tap.line
    from tap.parser import Parser

    failed = None  # first failing, non-TODO result
    failure_details: list[str] = []  # diagnostics printed right after it
    after_todo_failure = False  # reading a TODO failure's diagnostics?

    for parsed in Parser().parse_text(output):
        if isinstance(parsed, tap.line.Diagnostic):
            if failed is not None:
                failure_details.append(parsed.text)
            elif not after_todo_failure:
                raise PgTapDiagnosticError(parsed.text)
            continue
        if failed is not None:
            # Every diagnostic belonging to the failed result is collected.
            break
        after_todo_failure = False
        if isinstance(parsed, tap.line.Result):
            if not parsed.ok:
                if parsed.todo:
                    after_todo_failure = True
                else:
                    failed = parsed
        elif not isinstance(parsed, tap.line.Plan):
            raise TypeError(
                f"Unhandled tap type '{parsed.category}': {output}",
            )

    if failed is not None:
        raise SqlError(
            "\n".join([failed.description, *failure_details]),
            statement,
        )
class SqlFile(pytest.File):
    """A collected ``.sql`` test file, run as a single test item."""

    def collect(self) -> Iterator[SqlItem]:
        """Yield exactly one SqlItem, named after the file's stem."""
        item = SqlItem.from_parent(self, name=self.path.stem)
        yield item
class SqlItem(pytest.Item):
    """A pytest item that executes a whole pgTap ``.sql`` file as one test."""

    def runtest(self) -> None:
        """Run every statement in the file via ``run_test``."""
        # Read explicitly as UTF-8 so results don't depend on the locale.
        syncrun(run_test(self.path.read_text(encoding="utf-8")))

    def repr_failure(self, excinfo, style=None):
        """Called when self.runtest() raises an exception.

        For a PgTapError, the message is attached as a report section
        (``when="output"``, ``key="pgTap"``) and an empty representation is
        returned in place of a traceback. Any other exception falls back to
        pytest's default representation, forwarding *style*.
        """
        if isinstance(excinfo.value, PgTapError):
            self.add_report_section("output", "pgTap", str(excinfo.value))
            return ""
        return super().repr_failure(excinfo, style=style)

    def reportinfo(self) -> tuple[Path, None, str]:
        """Report location: the file path, no line number, and the item name."""
        return self.path, None, self.name
class PgTapError(Exception):
    """Base class for problems reported by pgTap test output."""
class SqlError(PgTapError):
    """A pgTap assertion reported ``not ok``.

    Args:
        description: description of the failing test.
        statement: the SQL statement whose output reported the failure;
            it appears in the message indented by four spaces.
    """

    def __init__(self, description, statement):
        import textwrap

        indented_statement = textwrap.indent(statement, "    ")
        message = (
            "pgTap test failure!\n"
            f"{description}\n"
            "from test statement\n"
            f"{indented_statement}\n"
        )
        super().__init__(message)
class PgTapDiagnosticError(PgTapError):
    """Raised for a pgTap diagnostic (``#``) line.

    Args:
        description: the diagnostic text, placed on the line after the
            ``pgTap error!`` header.
    """

    def __init__(self, description):
        message = "pgTap error!\n{}\n".format(description)
        super().__init__(message)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.