Created
October 18, 2023 10:56
-
-
Save eladn/bfe28766561dc2862706a5339f604744 to your computer and use it in GitHub Desktop.
Python function that recursively traverses structured objects to merge one given object into another
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Module metadata (PEP 8 places module-level dunders before the imports).
__author__ = "Elad Nachmias"
# NOTE(review): this value looks like an e-mail obfuscation placeholder left by the web scrape; confirm the real address.
__email__ = "[email protected]"
__date__ = "2023-10-18"

# Standard library.
import copy
import dataclasses
from typing import Any, Tuple

# Third party.
import numpy as np
# Dispatch table for `primitives_merge_operator`: maps an operator name to a function (src, dst) -> merged value.
_PRIMITIVE_MERGE_OPERATORS = {
    'take_from_merge_src': lambda src, dst: src,
    'take_from_merge_dst': lambda src, dst: dst,
    'min': lambda src, dst: min(src, dst),
    'max': lambda src, dst: max(src, dst),
    'add': lambda src, dst: src + dst,
    'bool_or': lambda src, dst: src or dst,
    'bool_and': lambda src, dst: src and dst,
    'bitwise_or': lambda src, dst: src | dst,
    'bitwise_and': lambda src, dst: src & dst,
}


def _merge_primitives(from_obj: Any, into_obj: Any, primitives_merge_operator: str) -> Any:
    """Combine two primitive values with the named merge operator; raise ValueError for an unknown operator."""
    merge_fn = _PRIMITIVE_MERGE_OPERATORS.get(primitives_merge_operator)
    if merge_fn is None:
        raise ValueError(f'Unsupported primitive merge operator `{primitives_merge_operator}`.')
    return merge_fn(from_obj, into_obj)


def _non_mergeable_values_differ(from_obj: Any, into_obj: Any) -> bool:
    """Tell whether two values that cannot be merged structurally are different."""
    # numpy arrays compare element-wise, so `from_obj != into_obj` yields an array whose truth value is ambiguous
    # (numpy raises ValueError) as soon as it holds more than one element; compare them as whole arrays instead.
    if isinstance(from_obj, np.ndarray) or isinstance(into_obj, np.ndarray):
        return not np.array_equal(from_obj, into_obj)
    return bool(from_obj != into_obj)


def recursively_structured_merge(
        from_obj: Any, into_obj: Any,
        allow_calling_merge_shallowly: bool = False,
        primitives_merge_operator: str = 'add',
        primitives_merge_ignore_none_src: bool = True,
        contained_in_key: Tuple[Tuple[str, Any], ...] = ()) -> Any:
    """
    Recursively merge (in-place) a structured source object `from_obj` into a target one `into_obj`.

    The cases are checked in this order:
      * `from_obj` is a dict: each of its keys is merged into `into_obj` (which must be a dict as well). Keys that
        are missing from `into_obj` receive a deep copy of the source value, so the target never aliases the source.
      * `into_obj` has a `merge()` method and `allow_calling_merge_shallowly` is True: `into_obj.merge(from_obj)`.
      * `from_obj` is a dataclass: merged field by field (`into_obj` must be a dataclass as well). A field's
        `metadata` may contain 'primitives_merge_operator' / 'primitives_merge_ignore_none_src' entries that
        override these policies for that field and for everything nested inside it.
      * `from_obj` is None and `primitives_merge_ignore_none_src` is True: `into_obj` is kept as-is.
      * `from_obj` is a scalar according to `is_scalar()` (note that `np.isscalar` also accepts strings): the two
        values are combined with `primitives_merge_operator`.
      * Anything else (lists, tuples, multi-element numpy arrays, ...): both values must be equal, else ValueError.

    :param from_obj: The source object to be merged into the target `into_obj`.
    :param into_obj: The target object to be updated in-place to contain information from `from_obj`.
    :param allow_calling_merge_shallowly: Set to False to avoid calling the object's merge() method, in case where
        this aux function is already called from the merge() method of that object. Recursive calls on inner
        objects always pass True.
    :param primitives_merge_operator: Policy for merging primitive values. Supported operations: add, min, max,
        take_from_merge_src, take_from_merge_dst, bool_or, bool_and, bitwise_or, bitwise_and.
    :param primitives_merge_ignore_none_src: If True, leave the dst as-is whenever src is None, without applying the
        primitives merging policy or any further checks.
    :param contained_in_key: Path of (kind, key) pairs leading to the current inner structure; used only for
        informative error messages in recursive calls.
    :return: The merged object. Same object `into_obj` for non-primitive mutable types that are modified in-place.
        New object for primitive scalars like int/float.
    :raises AssertionError: If `from_obj` is a dict (resp. dataclass) but `into_obj` is not.
    :raises ValueError: On an unsupported primitive merge operator, or on two different non-mergeable values.
    """
    if isinstance(from_obj, dict):
        assert isinstance(into_obj, dict), \
            f'Cannot merge a dict into a non-dict object of type `{type(into_obj).__name__}` ' \
            f'under key: `{contained_in_key}`.'
        for from_item_key, from_item_value in from_obj.items():
            if from_item_key in into_obj:
                merged_sub_item = recursively_structured_merge(
                    from_obj=from_item_value, into_obj=into_obj[from_item_key],
                    allow_calling_merge_shallowly=True,
                    primitives_merge_operator=primitives_merge_operator,
                    primitives_merge_ignore_none_src=primitives_merge_ignore_none_src,
                    contained_in_key=contained_in_key + (('dict_key', from_item_key),))
            else:
                merged_sub_item = copy.deepcopy(from_item_value)
            into_obj[from_item_key] = merged_sub_item
        return into_obj

    if allow_calling_merge_shallowly and hasattr(into_obj, 'merge'):
        # If an inner object encountered during the traversal has its own adhoc `merge()` method, it is called rather
        # than recursing further here. Such `merge()` implementations may themselves call this function on `self`;
        # they must pass `allow_calling_merge_shallowly=False` to avoid infinite recursion. The idea is to have one
        # common implementation that the adhoc `merge()` methods of various classes can share.
        into_obj.merge(from_obj)
        return into_obj

    # This check must stay after the `merge()` one, so that a dataclass defining its own `merge()` gets it called.
    if dataclasses.is_dataclass(from_obj):
        assert dataclasses.is_dataclass(into_obj), \
            f'Cannot merge dataclass `{type(from_obj).__name__}` into a non-dataclass object of type ' \
            f'`{type(into_obj).__name__}` under key: `{contained_in_key}`.'
        for field in dataclasses.fields(from_obj):
            metadata = field.metadata or {}
            # Per-field policy overrides; they are forwarded to, and thus also apply to, nested structures.
            field_overridden_primitives_merge_operator = \
                metadata.get('primitives_merge_operator', primitives_merge_operator)
            field_overridden_primitives_merge_ignore_none_src = \
                metadata.get('primitives_merge_ignore_none_src', primitives_merge_ignore_none_src)
            merged_field_item = recursively_structured_merge(
                from_obj=getattr(from_obj, field.name),
                into_obj=getattr(into_obj, field.name),
                allow_calling_merge_shallowly=True,
                primitives_merge_operator=field_overridden_primitives_merge_operator,
                primitives_merge_ignore_none_src=field_overridden_primitives_merge_ignore_none_src,
                contained_in_key=contained_in_key + (('dataclass_field', f'{from_obj.__class__.__name__}.{field.name}'),))
            setattr(into_obj, field.name, merged_field_item)
        return into_obj

    if primitives_merge_ignore_none_src and from_obj is None:
        return into_obj

    if is_scalar(from_obj):
        return _merge_primitives(from_obj, into_obj, primitives_merge_operator)

    if _non_mergeable_values_differ(from_obj, into_obj):
        raise ValueError(f'Encountered 2 different non-merge-able values under key: `{contained_in_key}`: '
                         f'from_obj={from_obj}, into_obj={into_obj}.')
    return into_obj
def is_scalar(val: Any) -> bool:
    """
    Tell whether `val` should be merged as a primitive scalar value.

    Anything `np.isscalar()` accepts counts (Python/numpy numbers and bools, but also `str`/`bytes`), as do numpy
    arrays holding exactly one element, of any shape including 0-d. Such arrays are unwrapped via `.item()` first.

    :param val: The value to classify.
    :return: True if `val` is a scalar or a single-element numpy array, False otherwise.
    """
    # Use `.size` rather than `len()`: `len()` raises TypeError on 0-d arrays, and `len(val) == 1` also matched
    # arrays such as shape (1, 3), whose `.item()` then raised ValueError.
    if isinstance(val, np.ndarray) and val.size == 1:
        val = val.item()
    return np.isscalar(val)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment