Skip to content

Instantly share code, notes, and snippets.

@nathanjmcdougall
Last active December 22, 2024 01:36
Show Gist options
  • Select an option

  • Save nathanjmcdougall/cb00eac78fbf9d090c0b7605c8eaf409 to your computer and use it in GitHub Desktop.

Select an option

Save nathanjmcdougall/cb00eac78fbf9d090c0b7605c8eaf409 to your computer and use it in GitHub Desktop.
Mergeable pydantic models
from __future__ import annotations
from typing import Self, TypeVar, cast
import pandas as pd
from pydantic import BaseModel
class MergeIncompatibleError(Exception):
    """Signals that two models cannot be merged because an attribute conflicts."""


# Generic type variable: `_merge_objects` hands back the same type it is given.
_T = TypeVar("_T")
class MergeableBaseModel(BaseModel):
    """Base class for pydantic BaseModel that can be deep-merged with the `|` operator.

    For example:
    ```combined_model = model1 | model2```

    Every field declared on the left operand's class is combined with
    `_merge_objects`, which raises `MergeIncompatibleError` on conflicting values.
    """

    def __or__(self, other: Self) -> Self:
        """Return a new model whose fields merge those of ``self`` and ``other``.

        Args:
            other: The right-hand model. It must expose every field declared on
                ``type(self)``.

        Returns:
            A copy of ``self`` with each field replaced by its merged value, or
            ``NotImplemented`` if ``other`` is not a pydantic model (so that Python
            raises the usual ``TypeError`` for unsupported operands).

        Raises:
            MergeIncompatibleError: If any field has conflicting values.
        """
        if not isinstance(other, BaseModel):
            return NotImplemented
        cls_name = type(self).__name__
        # Read model_fields from the class: instance access is deprecated in
        # newer pydantic releases.
        merged = {
            key: _merge_objects(
                getattr(self, key),
                getattr(other, key),
                name=f"{cls_name}.{key}",
            )
            for key in type(self).model_fields
        }
        # model_copy(update=...) stores the merged values directly rather than via
        # per-field __setattr__. Frozen models therefore merge too, and
        # validate_assignment validators never see a half-merged model. A shallow
        # copy suffices because nested models are merged (and copied) recursively.
        return self.model_copy(update=merged)
def _merge_objects(option1: _T, option2: _T, *, name: str) -> _T:
    """Merge two objects, raising on conflicts.

    Supports merging of MergeableBaseModel and pandas DataFrames. ``None`` means
    "no value": if either object is ``None`` the other one is returned. Any other
    pair is considered compatible if the values are equal (see ``_values_equal``).

    Args:
        option1: Value taken from the left-hand operand.
        option2: Value taken from the right-hand operand.
        name: Label such as ``"ModelName.field"`` used in error messages.

    Returns:
        The merged value.

    Raises:
        MergeIncompatibleError: If the two values conflict.
    """
    if isinstance(option1, MergeableBaseModel) and isinstance(
        option2, MergeableBaseModel
    ):
        return cast(_T, option1 | option2)
    if isinstance(option1, pd.DataFrame) and isinstance(option2, pd.DataFrame):
        return cast(_T, _merge_dataframes(option1, option2, name=name))
    if option1 is None:
        return option2
    if option2 is None:
        return option1
    if not _values_equal(option1, option2):
        msg = f"'{name}' values are not consistent: {option1} != {option2}"
        raise MergeIncompatibleError(msg)
    return option1


def _values_equal(left: object, right: object) -> bool:
    """Return whether two non-None values count as equal for merging.

    ``==`` may return an element-wise result (numpy arrays, pandas Series). Such
    results are reduced with ``.all()``, so every element must match. A comparison
    that raises ``TypeError``/``ValueError`` (e.g. Series with different labels, or
    arrays of incompatible shapes on recent numpy) counts as unequal rather than
    escaping as an unrelated error.
    """
    # Identity first: a NaN or pd.NA shared by reference would otherwise compare
    # unequal to itself.
    if left is right:
        return True
    try:
        eq_cmp = left == right
    except (TypeError, ValueError):
        return False
    if isinstance(eq_cmp, bool):
        return eq_cmp
    # Collapse element-wise results to a scalar. Arrays and Series need one
    # pass. A DataFrame's .all() yields a per-column Series, hence two passes.
    for _ in range(2):
        reduce_all = getattr(eq_cmp, "all", None)
        if not callable(reduce_all):
            break
        eq_cmp = reduce_all()
    try:
        return bool(eq_cmp)
    except (TypeError, ValueError):
        # Still ambiguous (e.g. pd.NA): report a conflict instead of crashing.
        return False
def _merge_dataframes(
option1: pd.DataFrame, option2: pd.DataFrame, *, name: str | None = None
) -> pd.DataFrame:
"""Select columns from both dataframes, raising on column conflicts.
Both dataframes must have the same number of rows or they are incompatible.
Carefully note whether the index of the dataframes is important, as this function
does not attempt to merge based on the index - it simply concatenates columns.
If you need to merge based on the index, consider making a column non-option in
the DataFrame's `pandera.DataFrameModel` defintion, e.g. for this schema definition:
```
class MyModel(pa.DataFrameModel):
a: pa.typing.pandas.Series[int]
b: pa.typing.pandas.Series[float] | None = None
```
Then a is required and b is optional, which would be sensible if a were an ID
column.
Raises:
MergeIncompatibleError: If the dataframes have conflicting columns, or sizes.
"""
if len(option1) != len(option2):
ref = f"DataFrames '{name}'" if name else "DataFrames"
msg = (
f"{ref} to be merged have a different number of rows: "
f"{len(option1)} != {len(option2)}"
)
raise MergeIncompatibleError(msg)
df = option1.copy()
for col in option2.columns:
if col in df.columns and not df[col].equals(option2[col]):
msg = f"Columns '{col}' in DataFrames '{name}' have conflicting values"
raise MergeIncompatibleError(msg)
# In the future, we might use a .merge() method instead of simple column
# concatenation. But for now we will impose a stricter requirement which should
# hopefully be less error prone.
df[col] = option2[col]
return df
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment