-
-
Save ovsds/bad260d9fa92a11959f06648dda78aee to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import asyncio | |
| import datetime | |
| import functools | |
| import time | |
| import tracemalloc | |
| import typing | |
| import databases | |
| import faker | |
| import faker.providers as faker_providers | |
| import sqlalchemy.ext.asyncio as sqlalchemy_asyncio | |
| import sqlalchemy.orm as sqlalchemy_orm | |
| from sqlalchemy import Column, DateTime, Integer, String, func | |
| from sqlalchemy.ext.declarative import declarative_base | |
| from sqlalchemy.future import select | |
# Type checkers see the parameterized ``sessionmaker[AsyncSession]`` so that
# sessions produced by the factory are typed as ``AsyncSession``; at runtime the
# plain ``sessionmaker`` class is bound instead (presumably because subscripting
# it is not supported at runtime by the installed SQLAlchemy version — confirm).
# The pyright suppression must stay on the same line as the subscript.
if typing.TYPE_CHECKING:
    AsyncSessionMaker = sqlalchemy_orm.sessionmaker[sqlalchemy_asyncio.AsyncSession]  # pyright: ignore[reportGeneralTypeIssues]
else:
    AsyncSessionMaker = sqlalchemy_orm.sessionmaker
# Connection string used both for the SQLAlchemy async engine and for the
# ``databases`` client.
DB_URL = "postgresql+asyncpg://test_user:test_password@localhost:5432/test_db"
# Number of fake rows inserted before the benchmarks run.
TEST_SAMPLE_SIZE = 10_000
# Number of times each benchmark repeats its query.
TEST_RETRIES_COUNT = 10_000
# ``declarative_base`` lives in ``sqlalchemy.orm`` since SQLAlchemy 1.4; the
# ``sqlalchemy.ext.declarative`` location is deprecated and warns under 2.0.
Base = sqlalchemy_orm.declarative_base()
class Test(Base):
    """ORM mapping for ``test_table``: an integer primary key, a name and a timestamp.

    ``date`` carries a server-side ``DEFAULT now()`` that applies when an insert
    does not supply a value.
    """

    __tablename__ = "test_table"

    id = Column(Integer(), primary_key=True)
    name = Column(String())
    # Both benchmark queries filter on this column; it has no index.
    date = Column(DateTime(), server_default=func.now())
def profile(func: typing.Callable) -> typing.Callable:
    """Wrap a coroutine function so each call reports peak memory and wall time.

    On every call the wrapped coroutine is awaited once. Afterwards the wrapper
    prints, under the function's name:

    - the peak memory traced by ``tracemalloc``, in decimal megabytes;
    - the elapsed time, in milliseconds.

    The coroutine's result is returned unchanged, and any exception it raises
    propagates to the caller. Tracing is active while the timer runs, so the
    reported time includes tracemalloc's overhead.

    :param func: the coroutine function to measure.
    :return: an async wrapper with the same name and docstring as ``func``.
    """

    @functools.wraps(func)
    async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
        tracemalloc.start()
        try:
            # perf_counter is monotonic and high-resolution; time.time() can
            # jump when the system clock is adjusted.
            start = time.perf_counter()
            result = await func(*args, **kwargs)
            elapsed_time = (time.perf_counter() - start) * 1000
            _, mem_peak = tracemalloc.get_traced_memory()
        finally:
            # Stop tracing even if the coroutine raised, so the rest of the
            # process does not keep running under tracemalloc.
            tracemalloc.stop()
        print(f"{func.__name__}:\n max_memory: {mem_peak/1000_000:,.3f} mb\n exec time: {elapsed_time:,.3f} ms")
        return result

    return wrapper
async def create_tables(engine: sqlalchemy_asyncio.AsyncEngine):
    """Drop, then recreate, every table registered on ``Base.metadata``.

    Both DDL steps run on the connection yielded by a single ``engine.begin()``
    block.

    :param engine: the async engine to run the DDL against.
    """
    metadata = Base.metadata
    async with engine.begin() as connection:
        # Order matters: drop the old schema before creating the fresh one.
        for ddl_step in (metadata.drop_all, metadata.create_all):
            await connection.run_sync(ddl_step)
async def add_fake_data(session_maker: AsyncSessionMaker):
    """Insert ``TEST_SAMPLE_SIZE`` rows of fake names and datetimes into ``test_table``.

    :param session_maker: factory producing the async sessions used for the insert.
    """
    fake = faker.Faker()
    fake.add_provider(faker_providers.date_time)
    test_data = [Test(name=fake.name(), date=fake.date_time()) for _ in range(TEST_SAMPLE_SIZE)]
    async with session_maker() as session:
        session.add_all(test_data)
        # Leaving the ``async with`` block closes the session, and closing
        # discards uncommitted work. Without this commit no rows were ever
        # written, and the benchmarks queried an empty table.
        await session.commit()
@profile
async def test_sqlalchemy(session_maker: AsyncSessionMaker):
    """Benchmark the ORM path.

    Runs a date-filtered ``select(Test)`` ``TEST_RETRIES_COUNT`` times inside one
    session and discards each result.

    :param session_maker: factory producing the async session used for the queries.
    """
    cutoff = datetime.datetime(2008, 1, 1)
    statement = select(Test).where(Test.date > cutoff)
    async with session_maker() as session:
        for _attempt in range(TEST_RETRIES_COUNT):
            await session.execute(statement)
@profile
async def test_databases(database: databases.Database):
    """Benchmark the ``databases`` path.

    Runs a date-filtered Core ``SELECT`` on ``test_table`` ``TEST_RETRIES_COUNT``
    times and discards each fetched row list.

    :param database: an already-connected ``databases.Database`` instance.
    """
    table = Test.__table__
    date_filter = table.c.date > datetime.datetime(2008, 1, 1)
    statement = table.select().where(date_filter)
    for _attempt in range(TEST_RETRIES_COUNT):
        await database.fetch_all(statement)
async def main():
    """Seed the test table, then benchmark the ``databases`` and SQLAlchemy ORM queries.

    The ``databases`` connection and the SQLAlchemy engine's pool are released
    even if a step raises.
    """
    # ``convert_unicode`` was deprecated in SQLAlchemy 1.3 and is rejected by
    # 2.0. Per SQLAlchemy's docs, modern DBAPIs handle Python ``str`` directly.
    engine = sqlalchemy_asyncio.create_async_engine(DB_URL)
    session_maker = sqlalchemy_orm.sessionmaker(
        autocommit=False, autoflush=False, bind=engine, class_=sqlalchemy_asyncio.AsyncSession
    )
    try:
        await create_tables(engine)
        await add_fake_data(session_maker)

        database = databases.Database(DB_URL)
        await database.connect()
        try:
            await test_databases(database)
        finally:
            await database.disconnect()

        await test_sqlalchemy(session_maker)
    finally:
        # Close pooled connections while the event loop is still running,
        # instead of leaving them for garbage collection after the loop closes.
        await engine.dispose()
if __name__ == "__main__":
    # Run the whole benchmark on a fresh event loop when executed as a script.
    entry_coroutine = main()
    asyncio.run(entry_coroutine)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment