-
-
Save imankulov/4051b7805ad737ace7d8de3d3f934d6b to your computer and use it in GitHub Desktop.
| #!/usr/bin/env ipython -i | |
| import datetime | |
| import json | |
| from typing import Optional | |
| import sqlalchemy as sa | |
| from sqlalchemy.orm import declarative_base, sessionmaker | |
| from sqlalchemy.dialects.postgresql import JSONB | |
| from pydantic import BaseModel, Field, parse_obj_as | |
| from pydantic.json import pydantic_encoder | |
| # -------------------------------------------------------------------------------------- | |
| # Define pydantic-alchemy specific types (once per application) | |
| # -------------------------------------------------------------------------------------- | |
class PydanticType(sa.types.TypeDecorator):
    """Pydantic type.

    SAVING:
    - Uses SQLAlchemy JSON type under the hood.
    - Accepts the pydantic model and converts it to a dict on save.
    - SQLAlchemy engine JSON-encodes the dict to a string.

    RETRIEVING:
    - Pulls the string from the database.
    - SQLAlchemy engine JSON-decodes the string to a dict.
    - Uses the dict to create a pydantic model.
    """

    # Generic JSON is the default implementation; ``load_dialect_impl`` below
    # swaps in sqlalchemy.dialects.postgresql.JSONB when running on PostgreSQL.
    #
    # Ref: https://www.postgresql.org/docs/13/datatype-json.html
    impl = sa.types.JSON

    # The only state of this type is ``pydantic_type`` (stored under the same
    # name as the ``__init__`` argument), so it is safe to use as part of
    # SQLAlchemy's statement cache key. Without this flag SQLAlchemy 1.4 emits
    # a warning and does not cache statements using the type.
    cache_ok = True

    def __init__(self, pydantic_type):
        """Remember which pydantic type the stored JSON is parsed into."""
        super().__init__()
        self.pydantic_type = pydantic_type

    def load_dialect_impl(self, dialect):
        """Return JSONB for PostgreSQL and generic JSON for other databases."""
        if dialect.name == "postgresql":
            return dialect.type_descriptor(JSONB())
        return dialect.type_descriptor(sa.JSON())

    def process_bind_param(self, value, dialect):
        """Convert the pydantic model to a dict before it is sent to the DB.

        Only ``None`` is stored as NULL; the check is explicit so that it does
        not depend on the truthiness of the model object.
        """
        if value is None:
            return None
        # If you use FastAPI, you can replace the return below with their
        # jsonable_encoder(). E.g.,
        # from fastapi.encoders import jsonable_encoder
        # return jsonable_encoder(value)
        return value.dict()

    def process_result_value(self, value, dialect):
        """Parse the decoded JSON value into ``self.pydantic_type``.

        Only ``None`` (SQL NULL) maps to ``None``. Falsy but valid payloads
        such as ``{}`` or ``[]`` are parsed like any other value instead of
        being silently dropped.
        """
        if value is None:
            return None
        return parse_obj_as(self.pydantic_type, value)
def json_serializer(*args, **kwargs) -> str:
    """Serialize to a JSON string, encoding non-JSON types via pydantic.

    All positional and keyword arguments are forwarded to ``json.dumps``;
    ``pydantic_encoder`` is supplied as its ``default`` hook.
    """
    encoded = json.dumps(*args, default=pydantic_encoder, **kwargs)
    return encoded
| # -------------------------------------------------------------------------------------- | |
| # Configure SQLAlchemy engine, session and declarative base (once per application) | |
| # The key is to define json_serializer while creating the engine. | |
| # -------------------------------------------------------------------------------------- | |
# Declarative base shared by all ORM models.
Base = declarative_base()

# In-memory SQLite engine; JSON columns are encoded with json_serializer.
engine = sa.create_engine(
    "sqlite:///:memory:",
    json_serializer=json_serializer,
)

# Session factory bound to the engine above.
Session = sessionmaker(
    bind=engine,
    future=True,
    expire_on_commit=False,
)
| # -------------------------------------------------------------------------------------- | |
| # Define your Pydantic and SQLAlchemy models (as many as needed) | |
| # -------------------------------------------------------------------------------------- | |
# Pydantic model holding per-user settings; ``notify_at`` defaults to the
# time the instance is created.
class UserSettings(BaseModel):
    notify_at: datetime.datetime = Field(
        default_factory=datetime.datetime.now,
    )
class User(Base):
    """ORM model for the ``users`` table."""

    __tablename__ = "users"

    id: int = sa.Column(sa.Integer, primary_key=True)
    name: str = sa.Column(
        sa.String,
        doc="User name",
        comment="User name",
    )
    # Stored as JSON through PydanticType; NULL is allowed.
    settings: Optional[UserSettings] = sa.Column(
        PydanticType(UserSettings),
        nullable=True,
    )
| # -------------------------------------------------------------------------------------- | |
| # Create tables (once per application) | |
| # -------------------------------------------------------------------------------------- | |
Base.metadata.create_all(bind=engine)

# --------------------------------------------------------------------------------------
# Usage example (we use 2.0 querying style with selects)
# Ref: https://docs.sqlalchemy.org/en/14/orm/session_basics.html#querying-2-0-style
# --------------------------------------------------------------------------------------
session = Session()

# Persist a user whose settings are a pydantic model.
user = User(name="user", settings=UserSettings())
session.add(user)
session.commit()

# Read the first user back.
same_user = (
    session.execute(sa.select(User))
    .scalars()
    .first()
)
@a1d4r I DID IT !
Here's the code. More changes might be needed to support more complex types (for example, List[PydanticModel] and Dict[str, MyModel] are untested right now).
But it works! It correctly flags the PydanticJSONB columns as changed!
import pydantic
from sqlalchemy import event
from sqlalchemy import inspect
from sqlalchemy.orm import ColumnProperty
from sqlalchemy.orm.attributes import flag_modified
# Where `PydanticJSONB` is my implementation, see previous comments
def flag_pydantic_changes(target):
    """Mark PydanticJSONB columns of ``target`` as modified when they differ.

    Walks every attribute of the mapped instance, keeps only column
    properties backed by at least one ``PydanticJSONB`` column, and calls
    ``flag_modified`` when the attribute's unchanged history value is not
    equal to its current value (dumped to a dict if it is a pydantic model).
    """
    state = inspect(target)
    mapped_props = state.mapper.attrs

    for attr_state in state.attrs:
        name = attr_state.key

        # Skip non-ColumnProperty attributes
        column_prop = mapped_props.get(name)
        if not isinstance(column_prop, ColumnProperty):
            continue

        # Skip properties where no column is PydanticJSONB
        for column in column_prop.columns:
            if isinstance(column.type, PydanticJSONB):
                break
        else:
            continue

        history = attr_state.history
        previous = history.unchanged[0] if history.unchanged else None

        current = attr_state.value
        if isinstance(current, pydantic.BaseModel):
            current = current.model_dump()

        # NOTE(review): ``previous`` is compared as-is against the dumped
        # dict; if it is a model instance rather than a dict the two sides
        # have different types - confirm this is the intended check.
        if previous != current:
            flag_modified(target, name)
@event.listens_for(_BaseModel, "before_update")
def auto_flag_modified(mapper, connection, target):
    """``before_update`` listener that runs ``flag_pydantic_changes`` on the instance.

    ``mapper`` and ``connection`` are part of the event signature and are
    not used here.
    """
    flag_pydantic_changes(target)
MODELS = [
Users,
# Add your models here
]
for model in MODELS:
event.listen(model, "before_update", auto_flag_modified)
I'm quite happy with this. I will keep working on it and will probably end up publishing it on GitHub at some point.
@Seluj78 Awesome work! I will check it later. By the way, feel free to contact me if you need any help.
To avoid a deprecation warning as of Pydantic 2.5 you will need to use the following
def process_result_value(self, value, dialect):
    """Build ``self.pydantic_type`` from the decoded JSON mapping.

    Only ``None`` (SQL NULL) maps to ``None``. An empty mapping is still
    passed to the constructor so a model whose fields all have defaults is
    restored instead of being turned into ``None``.
    """
    if value is None:
        return None
    return self.pydantic_type(**value)
Building on top of the comment above, you will need to use the following to avoid a deprecation warning as of Pydantic V2:
@override
def process_bind_param(
    self,
    value: "BaseModel | None",
    dialect: "Dialect",
) -> "dict[str, Any] | None":
    """Dump a pydantic model to a JSON-compatible dict for storage.

    Returns ``None`` for ``None``. Raises ``TypeError`` when ``value`` is
    neither ``None`` nor a ``BaseModel`` instance.
    """
    if value is not None:
        if isinstance(value, BaseModel):
            # NOTE(review): exclude_unset=True omits fields that were never
            # set explicitly, so a model built only from defaults dumps to
            # an empty dict - confirm the loading side handles ``{}``.
            return value.model_dump(mode="json", exclude_unset=True)
        raise TypeError(f'Value "{value!r}" is not a pydantic model.')
    return None
@override
def process_result_value(
self,
value: "dict[str, Any] | None",
dialect: "Dialect",
) -> "BaseModel | None":
return self.pydantic_type(**value) if value else None
I also added type hints, and I'm checking whether the value really is an instance of a BaseModel subclass because Python.
Model.dict() is deprecated, and you can now use Model.model_dump(mode="json") to avoid setting up a custom JSON serializer when creating the engine.
Edit: I forked the gist and updated the code to use the current syntax and to add typehints.
@Seluj78 Here is my implementation:
https://gist.github.com/a1d4r/100b06239925a414446305c81433cc88
Basically, the same as the original one, but with typing.
Example: