Skip to content

Instantly share code, notes, and snippets.

@eladn
Created October 18, 2023 10:56
Show Gist options
  • Select an option

  • Save eladn/bfe28766561dc2862706a5339f604744 to your computer and use it in GitHub Desktop.

Select an option

Save eladn/bfe28766561dc2862706a5339f604744 to your computer and use it in GitHub Desktop.
Python function that recursively traverses structured objects (dicts, dataclasses, primitives) to merge one object into another.
# Module metadata.
__author__ = "Elad Nachmias"
# NOTE(review): this value looks like an email-obfuscation placeholder left by the web page the code was copied
# from, not a real address — confirm and restore the intended email.
__email__ = "[email protected]"
__date__ = "2023-10-18"
import copy
from typing import Any, Tuple
import dataclasses
import numpy as np
# Scalar merge policies, keyed by the `primitives_merge_operator` name. Each callable takes (src, dst): the order
# matters for ties in min/max and for non-commutative `+` (e.g. string concatenation).
_PRIMITIVE_MERGE_OPERATORS = {
    'take_from_merge_src': lambda src, dst: src,
    'take_from_merge_dst': lambda src, dst: dst,
    'min': min,
    'max': max,
    'add': lambda src, dst: src + dst,
    'bool_or': lambda src, dst: src or dst,
    'bool_and': lambda src, dst: src and dst,
    'bitwise_or': lambda src, dst: src | dst,
    'bitwise_and': lambda src, dst: src & dst,
}


def _values_differ(a: Any, b: Any) -> bool:
    """Return whether two non-mergeable leaf values differ; safe for multi-element numpy arrays."""
    if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
        # `a != b` on arrays is element-wise, and `bool()` of a multi-element result raises ValueError.
        return not np.array_equal(a, b)
    return a != b


def recursively_structured_merge(
        from_obj: Any, into_obj: Any,
        allow_calling_merge_shallowly: bool = False,
        primitives_merge_operator: str = 'add',
        primitives_merge_ignore_none_src: bool = True,
        contained_in_key: Tuple[Tuple[str, Any], ...] = ()) -> Any:
    """
    Recursively merge (in-place) a structured source object `from_obj` into a target one `into_obj`.

    The cases are checked in this order:
      * `from_obj` is a dict: each of its keys is merged into `into_obj`; keys missing from `into_obj` receive a deep
        copy of the source value.
      * `into_obj` has a `merge()` method and `allow_calling_merge_shallowly` is True: `into_obj.merge(from_obj)`.
      * `from_obj` is a dataclass: each field is merged recursively. A field's `metadata` may override
        `primitives_merge_operator` / `primitives_merge_ignore_none_src` for that field's subtree.
      * `from_obj` is None and `primitives_merge_ignore_none_src` is True: `into_obj` is kept as-is.
      * `from_obj` is a scalar (per `is_scalar`): the two values are combined with `primitives_merge_operator`.
      * otherwise: the two values must be equal, and `into_obj` is kept.

    Note: `is_scalar` relies on `np.isscalar`, which also treats `str` as a scalar. With the default 'add' operator,
    two strings are therefore concatenated rather than compared.

    :param from_obj: The source object to be merged into the target `into_obj`.
    :param into_obj: The target object to be updated in-place to contain information from `from_obj`.
    :param allow_calling_merge_shallowly: Whether `into_obj.merge(from_obj)` may be called when `into_obj` has a
        `merge()` method. Keep it False (the default) when calling this function from inside that object's own
        `merge()` method, to avoid infinite recursion. Nested values are always traversed with it set to True.
    :param primitives_merge_operator: Policy for merging scalar values. Supported operators: add, min, max,
        take_from_merge_src, take_from_merge_dst, bool_or, bool_and, bitwise_or, bitwise_and.
    :param primitives_merge_ignore_none_src: If True, a None source value leaves the target as-is instead of being
        combined by `primitives_merge_operator`.
    :param contained_in_key: Path of (kind, key) pairs leading to the current value. Used for informative error
        messages in recursive calls.
    :return: The merged object. For dicts, dataclasses and objects merged via their own `merge()`, this is the same
        `into_obj`, updated in-place. For merged scalars, it is a new value.
    :raises TypeError: If `from_obj` is a dict (resp. a dataclass) but `into_obj` is not.
    :raises ValueError: If a scalar is reached and `primitives_merge_operator` is unsupported, or if two values that
        cannot be merged are not equal.
    """
    if isinstance(from_obj, dict):
        if not isinstance(into_obj, dict):
            raise TypeError(f'Cannot merge a dict into an object of type `{type(into_obj).__name__}` '
                            f'under key: `{contained_in_key}`.')
        for from_item_key, from_item_value in from_obj.items():
            if from_item_key not in into_obj:
                # Deep copy so the target never aliases (possibly mutable) parts of the source.
                merged_sub_item = copy.deepcopy(from_item_value)
            else:
                merged_sub_item = recursively_structured_merge(
                    from_obj=from_item_value, into_obj=into_obj[from_item_key],
                    allow_calling_merge_shallowly=True,
                    primitives_merge_operator=primitives_merge_operator,
                    primitives_merge_ignore_none_src=primitives_merge_ignore_none_src,
                    contained_in_key=contained_in_key + (('dict_key', from_item_key),))
            into_obj[from_item_key] = merged_sub_item
        return into_obj

    # This check must come before the dataclass one, so dataclasses defining their own `merge()` use it.
    if hasattr(into_obj, 'merge') and allow_calling_merge_shallowly:
        # An object with an ad-hoc `merge()` handles the merge itself instead of being traversed here. Such
        # `merge()` implementations may call this function on their own instance; they must pass
        # `allow_calling_merge_shallowly=False` there, or this branch would call `merge()` again forever.
        into_obj.merge(from_obj)
        return into_obj

    if dataclasses.is_dataclass(from_obj):
        if not dataclasses.is_dataclass(into_obj):
            raise TypeError(f'Cannot merge dataclass `{from_obj.__class__.__name__}` into an object of type '
                            f'`{type(into_obj).__name__}` under key: `{contained_in_key}`.')
        for field in dataclasses.fields(from_obj):
            # Per-field policy overrides, e.g. `dataclasses.field(metadata={'primitives_merge_operator': 'max'})`.
            # `Field.metadata` is always a (possibly empty) read-only mapping.
            field_merge_operator = field.metadata.get('primitives_merge_operator', primitives_merge_operator)
            field_ignore_none_src = field.metadata.get(
                'primitives_merge_ignore_none_src', primitives_merge_ignore_none_src)
            merged_field_item = recursively_structured_merge(
                from_obj=getattr(from_obj, field.name),
                into_obj=getattr(into_obj, field.name),
                allow_calling_merge_shallowly=True,
                primitives_merge_operator=field_merge_operator,
                primitives_merge_ignore_none_src=field_ignore_none_src,
                contained_in_key=contained_in_key + (('dataclass_field', f'{from_obj.__class__.__name__}.{field.name}'),))
            setattr(into_obj, field.name, merged_field_item)
        return into_obj

    if primitives_merge_ignore_none_src and from_obj is None:
        return into_obj

    if is_scalar(from_obj):
        merge_operator = _PRIMITIVE_MERGE_OPERATORS.get(primitives_merge_operator)
        if merge_operator is None:
            raise ValueError(f'Unsupported primitive merge operator `{primitives_merge_operator}`.')
        return merge_operator(from_obj, into_obj)

    # Values that are neither containers nor scalars (e.g. lists, arrays) cannot be combined; they must agree.
    if _values_differ(from_obj, into_obj):
        raise ValueError(f'Encountered 2 different non-merge-able values under key: `{contained_in_key}`: '
                         f'from_obj={from_obj}, into_obj={into_obj}.')
    return into_obj
def is_scalar(val: Any) -> bool:
    """
    Return whether `val` should be treated as a primitive scalar when merging.

    A numpy array holding exactly one element (of any shape, including 0-d) is unwrapped with `.item()` first.
    The result is then delegated to `np.isscalar`, which also returns True for `str` values.

    :param val: The value to classify.
    :return: True if `val` (or the single element of a one-element numpy array) is a scalar.
    """
    if isinstance(val, np.ndarray) and val.size == 1:
        # The former `len(val) == 1` guard had two failure modes. It raised TypeError for 0-d arrays. It also let
        # shapes like (1, 3) through to `.item()`, which raises ValueError. `.size == 1` is exactly what `.item()`
        # requires.
        val = val.item()
    return np.isscalar(val)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment