-
-
Save imankulov/4051b7805ad737ace7d8de3d3f934d6b to your computer and use it in GitHub Desktop.
| #!/usr/bin/env ipython -i | |
| import datetime | |
| import json | |
| from typing import Optional | |
| import sqlalchemy as sa | |
| from sqlalchemy.orm import declarative_base, sessionmaker | |
| from sqlalchemy.dialects.postgresql import JSONB | |
| from pydantic import BaseModel, Field, parse_obj_as | |
| from pydantic.json import pydantic_encoder | |
| # -------------------------------------------------------------------------------------- | |
| # Define pydantic-alchemy specific types (once per application) | |
| # -------------------------------------------------------------------------------------- | |
class PydanticType(sa.types.TypeDecorator):
    """SQLAlchemy column type that stores a pydantic model as JSON.

    SAVING:
    - Uses SQLAlchemy JSON type under the hood.
    - Accepts the pydantic model and converts it to a dict on save.
    - SQLAlchemy engine JSON-encodes the dict to a string.

    RETRIEVING:
    - Pulls the string from the database.
    - SQLAlchemy engine JSON-decodes the string to a dict.
    - Uses the dict to create a pydantic model.
    """

    # If you work with PostgreSQL, you can consider using
    # sqlalchemy.dialects.postgresql.JSONB instead of a
    # generic sa.types.JSON
    #
    # Ref: https://www.postgresql.org/docs/13/datatype-json.html
    impl = sa.types.JSON

    def __init__(self, pydantic_type):
        """`pydantic_type` is the BaseModel subclass stored in this column."""
        super().__init__()
        self.pydantic_type = pydantic_type

    def load_dialect_impl(self, dialect):
        # Use JSONB for PostgreSQL and JSON for other databases.
        if dialect.name == "postgresql":
            return dialect.type_descriptor(JSONB())
        return dialect.type_descriptor(sa.JSON())

    def process_bind_param(self, value, dialect):
        """Convert the pydantic model to a plain dict before the engine JSON-encodes it."""
        # Compare against None explicitly: only a true NULL should be stored as NULL.
        if value is None:
            return None
        # If you use FastAPI, you can replace the line below with their
        # jsonable_encoder(). E.g.,
        #   from fastapi.encoders import jsonable_encoder
        #   return jsonable_encoder(value) if value is not None else None
        return value.dict()

    def process_result_value(self, value, dialect):
        """Rebuild the pydantic model from the JSON-decoded dict."""
        # BUG FIX: the original used `if value`, so an empty JSON object
        # ({} is falsy) came back as None instead of a model with all
        # default field values. Only SQL NULL should map to None.
        if value is None:
            return None
        return parse_obj_as(self.pydantic_type, value)
def json_serializer(*args, **kwargs) -> str:
    """JSON-encode for the SQLAlchemy engine, delegating pydantic objects
    (e.g. datetimes inside models) to pydantic's own encoder."""
    return json.dumps(*args, default=pydantic_encoder, **kwargs)
# --------------------------------------------------------------------------------------
# Configure SQLAlchemy engine, session and declarative base (once per application)
# The key is to define json_serializer while creating the engine.
# --------------------------------------------------------------------------------------
engine = sa.create_engine(
    "sqlite:///:memory:",
    json_serializer=json_serializer,  # make the engine pydantic-aware
)
Session = sessionmaker(bind=engine, expire_on_commit=False, future=True)
Base = declarative_base()
| # -------------------------------------------------------------------------------------- | |
| # Define your Pydantic and SQLAlchemy models (as many as needed) | |
| # -------------------------------------------------------------------------------------- | |
class UserSettings(BaseModel):
    """Per-user settings, persisted as JSON in the `users.settings` column."""

    # default_factory runs at instance-creation time, not at class-definition time.
    notify_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
class User(Base):
    """ORM model for the `users` table."""

    __tablename__ = "users"

    id: int = sa.Column(sa.Integer, primary_key=True)
    name: str = sa.Column(sa.String, doc="User name", comment="User name")
    # Stored as JSON/JSONB in the DB; surfaced to Python as a UserSettings model.
    settings: Optional[UserSettings] = sa.Column(
        PydanticType(UserSettings), nullable=True
    )
# --------------------------------------------------------------------------------------
# Create tables (once per application)
# --------------------------------------------------------------------------------------
Base.metadata.create_all(engine)

# --------------------------------------------------------------------------------------
# Usage example (we use 2.0 querying style with selects)
# Ref: https://docs.sqlalchemy.org/en/14/orm/session_basics.html#querying-2-0-style
# --------------------------------------------------------------------------------------
session = Session()

# Persist a user whose settings are a pydantic model.
user = User(name="user", settings=UserSettings())
session.add(user)
session.commit()

# Read it back: the JSON column is rehydrated into a UserSettings instance.
same_user = session.execute(sa.select(User)).scalars().first()
Is it possible for SQLAlchemy to detect changes in PydanticType field? When I change the field of the pydantic model, I have to manually call function flag_modified to make SQLAlchemy flush the change.
@a1d4r I have the same problem on my end. Did you end up finding a solution ?
@a1d4r I have the same problem on my end. Did you end up finding a solution ?
Nope, I call flag_modified every time I change the model. For example:
from datetime import UTC, datetime
from sqlalchemy.orm.attributes import flag_modified
user.settings.notify_at = datetime.now(UTC)
flag_modified(user.settings, "notify_at")

@a1d4r Alright, thank you for the quick response. I am working on something that might automate this if I am able to make it so. I want to avoid doing what you did.
By the way, what implementation did you use for the PydanticJSONB ?
I did this:
class PydanticJSONB(TypeDecorator):
    """JSONB column type that (de)serializes pydantic v2 models.

    Supports a bare model class, list[Model] and dict[str, Model] — the
    container shape is detected once in __init__ via typing.get_origin.
    """

    impl = JSONB

    def __init__(self, model_type: Any, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model_type = model_type
        # Distinguish list[...] / dict[...] containers from a bare model.
        self._type = get_origin(model_type)
        if self._type is list:
            self._item_type = get_args(model_type)[0]
        elif self._type is dict:
            self._item_type = get_args(model_type)[1]
        else:
            self._item_type = model_type
        # One TypeAdapter handles validation for all three shapes.
        self._adapter = TypeAdapter(self.model_type)

    def process_bind_param(self, value: Any, dialect: Any) -> Any:
        """Dump model(s) to plain JSON-compatible data before saving."""
        if value is None:
            return None
        if self._type is list:
            if not isinstance(value, list):
                raise TypeError(f"Expected list of {self._item_type}")
            return [
                item.model_dump() if isinstance(item, self._item_type) else item
                for item in value
            ]
        elif self._type is dict:
            if not isinstance(value, dict):
                raise TypeError(f"Expected dict of {self._item_type}")
            return {
                k: item.model_dump() if isinstance(item, self._item_type) else item
                for k, item in value.items()
            }
        else:
            if isinstance(value, self.model_type):
                return value.model_dump()
            return value

    def process_result_value(self, value: Any, dialect: Any) -> Any:
        """Validate loaded JSON back into model(s); empty container for NULL."""
        if value is not None:
            return self._adapter.validate_python(value)
        if self._type is list:
            return []
        elif self._type is dict:
            return {}
        else:
            return None
# Might release this and more stuff in some kind of sqlmodel-utils package on
# pypi one day (Because I use SQLModel, but it works with sqlalchemy as well)
@Seluj78 Here is my implementation:
https://gist.github.com/a1d4r/100b06239925a414446305c81433cc88
Basically, the same as the original one, but with typing.
Example:
data: Mapped[ReportData] = mapped_column(PydanticType(ReportData))

@a1d4r I DID IT!
Here's the code. More changes might be needed to support more complex types (Like List[PydanticModel] or Dict[str, MyModel] are untested right now).
But it works ! It flags correctly the pydanticjsonb columns as changed !!
import pydantic
from sqlalchemy import event
from sqlalchemy import inspect
from sqlalchemy.orm import ColumnProperty
from sqlalchemy.orm.attributes import flag_modified

# Where `PydanticJSONB` is my implementation, see previous comments


def flag_pydantic_changes(target):
    """Mark every PydanticJSONB attribute of `target` dirty when its
    dumped content differs from the value SQLAlchemy last loaded."""
    inspector = inspect(target)
    mapper = inspector.mapper
    for attr in inspector.attrs:
        key = attr.key
        prop = mapper.attrs.get(key)
        # Skip non-ColumnProperty attributes (relationships, etc.)
        if not isinstance(prop, ColumnProperty):
            continue
        # Check if any column in this property is PydanticJSONB
        is_pydantic_jsonb = any(
            isinstance(col.type, PydanticJSONB) for col in prop.columns
        )
        if is_pydantic_jsonb:
            hist = attr.history
            original_dict = hist.unchanged[0] if hist.unchanged else None
            if issubclass(attr.value.__class__, pydantic.BaseModel):
                current_dict = attr.value.model_dump()
            else:
                current_dict = attr.value
            if original_dict != current_dict:
                flag_modified(target, key)


@event.listens_for(_BaseModel, "before_update")
def auto_flag_modified(mapper, connection, target):
    flag_pydantic_changes(target)


# NOTE(review): the decorator above already registers the hook on the
# declarative base; listening per model below looks redundant — confirm
# whether both registrations are intended (the hook may fire twice).
MODELS = [
    Users,
    # Add your models here
]
for model in MODELS:
    event.listen(model, "before_update", auto_flag_modified)
# I'm quite happy and I will keep working on this and I will probably end up
# publishing this to github at some point.
@Seluj78 Awesome work! I will check it later. By the way, feel free to contact me if you need any help.
To avoid a deprecation warning as of Pydantic 2.5 you will need to use the following
def process_result_value(self, value, dialect):
    return self.pydantic_type(**value) if value else None
Building on top of the comment above, you will need to use the following to avoid a deprecation warning as of Pydantic V2:
@override
def process_bind_param(
self,
value: "BaseModel | None",
dialect: "Dialect",
) -> "dict[str, Any] | None":
if value is None:
return None
if not isinstance(value, BaseModel):
raise TypeError(f'Value "{value!r}" is not a pydantic model.')
return value.model_dump(mode="json", exclude_unset=True)
@override
def process_result_value(
self,
value: "dict[str, Any] | None",
dialect: "Dialect",
) -> "BaseModel | None":
return self.pydantic_type(**value) if value else NoneI also added type hints and I'm checking if the value really is an instance of a BaseModel subclass because Python.
Model.dict() is deprecated and you can now use Model.model_dump(mode="json") to avoid setting up a custom json serializer when creating the engine.
Edit: I forked the gist and updated the code to use the current syntax and to add typehints.
Got a problem: if I do
select(table.c.settings['a'])
it still tries to parse the `a` value as the whole pydantic model :(