Skip to content

Instantly share code, notes, and snippets.

@ovsds
Forked from KenKi0/test_alchemy
Last active September 11, 2022 15:32
Show Gist options
  • Select an option

  • Save ovsds/bad260d9fa92a11959f06648dda78aee to your computer and use it in GitHub Desktop.

Select an option

Save ovsds/bad260d9fa92a11959f06648dda78aee to your computer and use it in GitHub Desktop.
import asyncio
import datetime
import functools
import os
import time
import tracemalloc
import typing

import databases
import faker
import faker.providers as faker_providers
import sqlalchemy.ext.asyncio as sqlalchemy_asyncio
import sqlalchemy.orm as sqlalchemy_orm
from sqlalchemy import Column, DateTime, Integer, String, func
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.future import select
# Alias for the async session factory type. Static type checkers see the
# subscripted ``sessionmaker[AsyncSession]`` form. At runtime the plain class is
# bound, presumably because the installed SQLAlchemy's sessionmaker is not
# subscriptable at runtime (TODO confirm against the pinned version).
if not typing.TYPE_CHECKING:
    AsyncSessionMaker = sqlalchemy_orm.sessionmaker
else:
    AsyncSessionMaker = sqlalchemy_orm.sessionmaker[
        sqlalchemy_asyncio.AsyncSession
    ]  # pyright: ignore[reportGeneralTypeIssues]
# Connection URL shared by the SQLAlchemy engine and the `databases` client.
# Set ALCHEMY_BENCHMARK_DB_URL to point the benchmark at another server without
# editing the credentials baked into the default below.
DB_URL = os.environ.get(
    "ALCHEMY_BENCHMARK_DB_URL",
    "postgresql+asyncpg://test_user:test_password@localhost:5432/test_db",
)
# Number of fake rows inserted by add_fake_data().
TEST_SAMPLE_SIZE = 10_000
# Number of times each benchmark re-executes its query.
TEST_RETRIES_COUNT = 10_000
# sqlalchemy.orm.declarative_base is the supported location since SQLAlchemy 1.4;
# the sqlalchemy.ext.declarative variant is deprecated.
Base = sqlalchemy_orm.declarative_base()
class Test(Base):
    """Declarative ORM model for ``test_table``, the table both benchmarks query."""

    __tablename__ = "test_table"
    # Integer primary key.
    id = Column(Integer, primary_key=True)
    name = Column(String)
    # server_default lets the database fill this column with now() when an INSERT
    # omits it; add_fake_data() always passes an explicit value instead.
    date = Column(DateTime, server_default=func.now())
def profile(func: typing.Callable) -> typing.Callable:
    """Wrap a coroutine function so each call reports peak traced memory and time.

    The wrapper starts ``tracemalloc``, awaits ``func``, and then prints two
    values: the peak traced allocation in decimal megabytes and the elapsed
    time in milliseconds. It then returns ``func``'s result unchanged. If
    ``func`` raises, nothing is printed, tracing is still stopped, and the
    exception propagates.

    Args:
        func: An ``async def`` function; the wrapper awaits what it returns.

    Returns:
        An async wrapper carrying ``func``'s metadata (via ``functools.wraps``).
    """

    @functools.wraps(func)
    async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
        tracemalloc.start()
        # perf_counter is monotonic and high-resolution; time.time() follows the
        # system wall clock, which can be adjusted mid-measurement.
        start = time.perf_counter()
        try:
            result = await func(*args, **kwargs)
        finally:
            elapsed_time = (time.perf_counter() - start) * 1000
            _, mem_peak = tracemalloc.get_traced_memory()
            # Stop tracing even when func raised; otherwise tracemalloc would stay
            # active (and slow every allocation) for the rest of the process.
            tracemalloc.stop()
        print(f"{func.__name__}:\n max_memory: {mem_peak/1000_000:,.3f} mb\n exec time: {elapsed_time:,.3f} ms")
        return result

    return wrapper
async def create_tables(engine: sqlalchemy_asyncio.AsyncEngine):
    """Drop and re-create every table registered on ``Base.metadata``.

    Both DDL steps run through the same connection obtained from a single
    ``engine.begin()`` block, dropping first and then creating.
    """
    schema = Base.metadata
    async with engine.begin() as connection:
        for ddl_step in (schema.drop_all, schema.create_all):
            await connection.run_sync(ddl_step)
async def add_fake_data(session_maker: AsyncSessionMaker):
    """Insert ``TEST_SAMPLE_SIZE`` ``Test`` rows with fake names and timestamps.

    Args:
        session_maker: Factory producing ``AsyncSession`` objects bound to the
            benchmark database.
    """
    fake = faker.Faker()
    # NOTE(review): date_time appears to be one of Faker's default providers, and
    # add_provider normally takes a provider class rather than a module; this
    # line looks redundant — confirm before removing.
    fake.add_provider(faker_providers.date_time)
    test_data = [Test(name=fake.name(), date=fake.date_time()) for _ in range(TEST_SAMPLE_SIZE)]
    async with session_maker() as session:
        session.add_all(test_data)
        # Closing a session rolls back any uncommitted transaction, so without this
        # commit the rows would never reach the table and both benchmarks would
        # query an empty table.
        await session.commit()
@profile
async def test_sqlalchemy(session_maker: AsyncSessionMaker):
    """Benchmark the ORM path: run ``select(Test)`` repeatedly in one AsyncSession.

    The same date-filtered statement is executed ``TEST_RETRIES_COUNT`` times.
    Results are discarded, and ``@profile`` measures the whole call.
    """
    cutoff = datetime.datetime(2008, 1, 1)
    statement = select(Test).where(Test.date > cutoff)
    async with session_maker() as session:
        for _attempt in range(TEST_RETRIES_COUNT):
            await session.execute(statement)
@profile
async def test_databases(database: databases.Database):
    """Benchmark the ``databases`` path: run a Core select on ``test_table`` repeatedly.

    It uses the same date filter as ``test_sqlalchemy`` and runs it
    ``TEST_RETRIES_COUNT`` times via ``fetch_all``. Results are discarded, and
    ``@profile`` measures the whole call.
    """
    table = Test.__table__
    cutoff = datetime.datetime(2008, 1, 1)
    statement = table.select().where(table.c.date > cutoff)
    for _attempt in range(TEST_RETRIES_COUNT):
        await database.fetch_all(statement)
async def main():
    """Prepare the schema and fixture rows, then run both benchmarks.

    Order: recreate tables, insert fake rows, benchmark the ``databases`` client,
    then benchmark the SQLAlchemy ORM session. Connections are released even if
    a step fails.
    """
    # convert_unicode is deprecated (and removed in SQLAlchemy 2.0); modern
    # DBAPIs such as asyncpg handle unicode natively, so it is not passed.
    engine = sqlalchemy_asyncio.create_async_engine(DB_URL)
    try:
        session_maker = sqlalchemy_orm.sessionmaker(
            autocommit=False, autoflush=False, bind=engine, class_=sqlalchemy_asyncio.AsyncSession
        )
        await create_tables(engine)
        await add_fake_data(session_maker)

        database = databases.Database(DB_URL)
        await database.connect()
        try:
            await test_databases(database)
        finally:
            await database.disconnect()

        await test_sqlalchemy(session_maker)
    finally:
        # Close pooled connections while the event loop is still running;
        # otherwise they are garbage-collected after asyncio.run() has closed
        # the loop.
        await engine.dispose()
if __name__ == "__main__":
    # Script entry point: asyncio.run() creates an event loop, runs main() to
    # completion, and closes the loop.
    asyncio.run(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment