Last active
December 22, 2024 01:36
-
-
Save nathanjmcdougall/cb00eac78fbf9d090c0b7605c8eaf409 to your computer and use it in GitHub Desktop.
Mergeable pydantic models
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from __future__ import annotations | |
| from typing import Self, TypeVar, cast | |
| import pandas as pd | |
| from pydantic import BaseModel | |
class MergeIncompatibleError(Exception):
    """Signal that two models cannot be merged because their attributes conflict."""
# Generic placeholder so `_merge_objects` can declare that it returns the same type
# as the two values it is given.
_T = TypeVar("_T")
class MergeableBaseModel(BaseModel):
    """Base class for pydantic BaseModel that can be deep-merged with the `|` operator.

    For example:
    ```combined_model = model1 | model2```
    """

    def __or__(self, other: Self) -> Self:
        """Return a new model whose fields are the merge of `self` and `other`.

        Each field declared on this model's class is merged with `_merge_objects`,
        so nested mergeable models and DataFrames are merged recursively. Neither
        operand is modified.

        Raises:
            MergeIncompatibleError: If a field has conflicting values in the two
                models.
        """
        if not isinstance(other, BaseModel):
            # Let Python try the reflected operation or raise its usual TypeError
            # rather than failing with an AttributeError on the first field lookup.
            return NotImplemented

        # Read the field definitions from the class: accessing `model_fields` on an
        # instance is deprecated in recent pydantic releases.
        merged = {
            key: _merge_objects(
                getattr(self, key),
                getattr(other, key),
                name=f"{type(self).__name__}.{key}",
            )
            for key in type(self).model_fields
        }
        # No need for a deep copy since merging occurs recursively. Passing the
        # merged values through `update` (rather than assigning attributes one by
        # one) also works for frozen models.
        return self.model_copy(update=merged)
def _merge_objects(option1: _T, option2: _T, *, name: str) -> _T:
    """Merge two objects, raising on conflicts.

    Two MergeableBaseModel instances are merged recursively with `|`, and two
    pandas DataFrames are merged column-wise. For any other pair, `None` means
    "no value" and gives way to the other operand; two non-`None` values must
    compare equal, in which case the first one is returned.

    Args:
        option1: The first candidate value.
        option2: The second candidate value.
        name: Label of the attribute being merged, used in error messages.

    Returns:
        The merged value.

    Raises:
        MergeIncompatibleError: If both values are set and they are not equal, or
            their equality cannot be determined.
    """
    if isinstance(option1, MergeableBaseModel) and isinstance(
        option2, MergeableBaseModel
    ):
        return cast(_T, option1 | option2)
    if isinstance(option1, pd.DataFrame) and isinstance(option2, pd.DataFrame):
        return cast(_T, _merge_dataframes(option1, option2, name=name))

    if option1 is None:
        return option2
    if option2 is None:
        return option1

    try:
        eq_cmp = option1 == option2
        if not isinstance(eq_cmp, bool):
            # Element-wise comparisons (numpy arrays, pandas objects, ...) have no
            # single truth value. Reduce over every element, whatever the number
            # of dimensions, instead of iterating the top level only.
            import numpy as np

            eq_cmp = bool(np.all(np.asarray(eq_cmp)))
    except (TypeError, ValueError) as err:
        # The comparison itself failed (e.g. mismatched labels or shapes) or its
        # result could not be reduced to a boolean; the values cannot be confirmed
        # to be equal, so report them as a merge conflict.
        msg = f"'{name}' values are not consistent: {option1} != {option2}"
        raise MergeIncompatibleError(msg) from err

    if not eq_cmp:
        msg = f"'{name}' values are not consistent: {option1} != {option2}"
        raise MergeIncompatibleError(msg)
    return option1
| def _merge_dataframes( | |
| option1: pd.DataFrame, option2: pd.DataFrame, *, name: str | None = None | |
| ) -> pd.DataFrame: | |
| """Select columns from both dataframes, raising on column conflicts. | |
| Both dataframes must have the same number of rows or they are incompatible. | |
| Carefully note whether the index of the dataframes is important, as this function | |
| does not attempt to merge based on the index - it simply concatenates columns. | |
| If you need to merge based on the index, consider making a column non-option in | |
| the DataFrame's `pandera.DataFrameModel` defintion, e.g. for this schema definition: | |
| ``` | |
| class MyModel(pa.DataFrameModel): | |
| a: pa.typing.pandas.Series[int] | |
| b: pa.typing.pandas.Series[float] | None = None | |
| ``` | |
| Then a is required and b is optional, which would be sensible if a were an ID | |
| column. | |
| Raises: | |
| MergeIncompatibleError: If the dataframes have conflicting columns, or sizes. | |
| """ | |
| if len(option1) != len(option2): | |
| ref = f"DataFrames '{name}'" if name else "DataFrames" | |
| msg = ( | |
| f"{ref} to be merged have a different number of rows: " | |
| f"{len(option1)} != {len(option2)}" | |
| ) | |
| raise MergeIncompatibleError(msg) | |
| df = option1.copy() | |
| for col in option2.columns: | |
| if col in df.columns and not df[col].equals(option2[col]): | |
| msg = f"Columns '{col}' in DataFrames '{name}' have conflicting values" | |
| raise MergeIncompatibleError(msg) | |
| # In the future, we might use a .merge() method instead of simple column | |
| # concatenation. But for now we will impose a stricter requirement which should | |
| # hopefully be less error prone. | |
| df[col] = option2[col] | |
| return df |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment