Skip to content

Instantly share code, notes, and snippets.

@schwehr
Created October 30, 2025 15:47
Show Gist options
  • Select an option

  • Save schwehr/4f6ffac7a9a4dbdea900c24fb87917e1 to your computer and use it in GitHub Desktop.

Select an option

Save schwehr/4f6ffac7a9a4dbdea900c24fb87917e1 to your computer and use it in GitHub Desktop.
A first attempt at text/rich mode for eerepr - does not work - https://github.com/aazuspan/eerepr/issues/68
# SPDX-License-Identifier: MIT
from __future__ import annotations
import html
from functools import _lru_cache_wrapper, lru_cache
from typing import Any, Literal, Union
from warnings import warn
import ee
try:
import rich
from rich.tree import Tree
RICH_INSTALLED = True
except ImportError:
RICH_INSTALLED = False
Tree = None
from eerepr.config import Config
from eerepr.html import convert_to_html, escape_object
from eerepr.style import CSS
if RICH_INSTALLED:
from eerepr.rich import convert_to_rich, escape_object_rich
# Attribute names of the repr hooks that eerepr attaches to EE classes.
REPR_HTML = "_repr_html_"  # HTML repr hook (see `_attach_html_repr`)
RICH_CONSOLE = "__rich_console__"  # rich console hook (see `_attach_rich_repr`)
# The EE object types that receive a repr.
EEObject = Union[ee.Element, ee.ComputedObject]
# Track which repr methods have been set so we can overwrite them if needed.
# Shared by the HTML and rich attach helpers and emptied by `reset`.
reprs_set: set[type] = set()
# Module-wide settings; `initialize` fills them in via `options.update(...)`.
options = Config()
def _attach_html_repr(cls: type, repr_func: Any) -> None:
    """Install `repr_func` as the HTML repr hook of `cls`.

    A hook that is already present is only replaced when eerepr installed it
    itself (i.e. `cls` is recorded in `reprs_set`); otherwise it is left alone.
    """
    if hasattr(cls, REPR_HTML) and cls not in reprs_set:
        # Someone else's repr is already there; do not clobber it.
        return
    reprs_set.add(cls)
    setattr(cls, REPR_HTML, repr_func)
def _attach_rich_repr(cls: type, repr_func: Any) -> None:
    """Add a rich console repr method to an EE class.

    An existing `__rich_console__` is only overwritten when it is `repr_func`
    itself, i.e. when it was installed by this function.

    Unlike `_attach_html_repr`, membership in `reprs_set` cannot serve as the
    ownership test here. `initialize` calls `_attach_html_repr` on the same class
    first, which adds it to that shared set. A rich method that was already
    present would then look like ours and be overwritten.
    """
    current = getattr(cls, RICH_CONSOLE, None)
    if current is None or current is repr_func:
        # Record the class so `reset` knows to clean it up.
        reprs_set.add(cls)
        setattr(cls, RICH_CONSOLE, repr_func)
def _is_nondeterministic(obj: EEObject) -> bool:
    """Return True if `obj` may give different results on each evaluation.

    Such objects must not go through the cached repr path. The only case detected
    so far is `ee.List.shuffle(seed=False)`, recognised by searching the
    serialized expression text.
    """
    serialized = obj.serialize()
    unseeded_marker = '"seed": {"constantValue": false}'
    if "List.shuffle" not in serialized:
        return False
    return unseeded_marker in serialized
@lru_cache(maxsize=None)
def _repr_html_(obj: EEObject) -> str:
    """Render `obj` as a self-contained HTML snippet (memoised with lru_cache).

    Every string in the fetched info is escaped first, so server-side content
    cannot inject markup into the page.
    """
    safe_info = escape_object(obj.getInfo())
    list_items = convert_to_html(safe_info)
    pieces = [
        "<div>",
        f"<style>{CSS}</style>",
        "<div class='eerepr'>",
        f"<ul>{list_items}</ul>",
        "</div>",
        "</div>",
    ]
    return "".join(pieces)
def _uncached_repr_html_(obj: EEObject) -> str:
    """Render `obj` as HTML without reading from or writing to the repr cache."""
    renderer = _repr_html_
    if isinstance(renderer, _lru_cache_wrapper):
        # Call the undecorated function underneath the lru_cache wrapper.
        renderer = renderer.__wrapped__
    return renderer(obj)
def _ee_repr(obj: EEObject) -> str:
    """HTML repr entry point attached to EE classes.

    Picks the uncached renderer for nondeterministic objects and the cached one
    otherwise. It falls back to an escaped plain-string repr, with a warning,
    when fetching info raises `ee.EEException` or the rendered HTML exceeds
    `options.max_repr_mbs`. When `options.on_error == "raise"`, the Earth Engine
    error is re-raised instead.
    """

    def plain_fallback() -> str:
        return f"<pre>{html.escape(repr(obj))}</pre>"

    if _is_nondeterministic(obj):
        renderer = _uncached_repr_html_
    else:
        renderer = _repr_html_

    try:
        rendered = renderer(obj)
    except ee.EEException as err:
        if options.on_error == "raise":
            raise err from None
        warn(
            f"Getting info failed with: '{err}'. Falling back to string repr.",
            stacklevel=2,
        )
        return plain_fallback()

    size_mb = len(rendered) / 1e6
    if size_mb > options.max_repr_mbs:
        # Very large HTML can freeze the client, so fall back to the plain repr.
        warn(
            message=(
                f"HTML repr size ({size_mb:.0f}mB) exceeds maximum"
                f" ({options.max_repr_mbs:.0f}mB), falling back to string repr. You"
                " can set `eerepr.options.max_repr_mbs` to print larger objects,"
                " but this may cause performance issues."
            ),
            stacklevel=2,
        )
        return plain_fallback()
    return rendered
if RICH_INSTALLED:
    from rich.text import Text

    @lru_cache(maxsize=None)
    def _rich_repr_(obj: EEObject) -> Tree:
        """Generate a rich Tree representation of an EE object (lru-cached).

        Strings in the fetched info are markup-escaped first, so server-side
        content cannot inject rich markup into the rendered tree.
        """
        info = escape_object_rich(obj.getInfo())
        return convert_to_rich(info)

    def _uncached_rich_repr_(obj: EEObject) -> Tree:
        """Generate a rich representation of an EE object without caching."""
        if isinstance(_rich_repr_, _lru_cache_wrapper):
            return _rich_repr_.__wrapped__(obj)
        return _rich_repr_(obj)

    def _ee_rich_repr(obj: EEObject) -> Tree | Text:
        """Handle errors and conditional caching for `_rich_repr_`.

        Returns the rendered Tree. If fetching info raises `ee.EEException` and
        `options.on_error` is not "raise", it warns and returns the plain string
        repr wrapped in `Text`.
        """
        repr_func = _uncached_rich_repr_ if _is_nondeterministic(obj) else _rich_repr_
        try:
            return repr_func(obj)
        except ee.EEException as e:
            if options.on_error == "raise":
                raise e from None
            warn(
                f"Getting info failed with: '{e}'. Falling back to string repr.",
                stacklevel=2,
            )
            # An HTML `<pre>` wrapper means nothing in a terminal. A bare str
            # would also be parsed for rich markup, so wrap the repr in Text to
            # display it literally.
            return Text(repr(obj))

    def __rich_console__(self, console, options):
        """rich render hook: yield the renderable for this EE object."""
        # `console` and `options` are supplied by rich and unused here. This
        # `options` only shadows eerepr's module-level `options` locally.
        yield _ee_rich_repr(self)

else:
    _rich_repr_ = None
    _uncached_rich_repr_ = None
    _ee_rich_repr = None
    __rich_console__ = None
def initialize(
    max_cache_size: int | None = None,
    max_repr_mbs: int = 100,
    on_error: Literal["warn", "raise"] = "warn",
) -> None:
    """Attach eerepr's repr methods to EE objects and initialize a cache.

    The HTML repr is always attached. When rich is installed, the rich console
    repr is attached as well. Re-running this function will reset the cache.

    Parameters
    ----------
    max_cache_size : int, optional
        The maximum number of EE objects to cache. If None, the cache size is unlimited.
        Set to 0 to disable caching.
    max_repr_mbs : int, default 100
        The maximum HTML repr size to display, in MBs. Setting this too high may freeze
        the client when printing very large objects. When a repr exceeds this size, the
        string repr will be displayed instead along with a warning.
    on_error : {'warn', 'raise'}, default 'warn'
        Whether to raise an error or display a warning when an error occurs fetching
        Earth Engine data.
    """
    # Declared unconditionally; `_rich_repr_` is only rebound when rich is present.
    global _repr_html_, _rich_repr_

    options.update(
        max_cache_size=max_cache_size,
        max_repr_mbs=max_repr_mbs,
        on_error=on_error,
    )

    rich_enabled = RICH_INSTALLED and bool(_rich_repr_)

    # Strip any cache wrapper left over from a previous call.
    if isinstance(_repr_html_, _lru_cache_wrapper):
        _repr_html_ = _repr_html_.__wrapped__  # type: ignore
    if rich_enabled and isinstance(_rich_repr_, _lru_cache_wrapper):
        _rich_repr_ = _rich_repr_.__wrapped__

    # A cache size of 0 means "no caching": leave the bare functions in place.
    if max_cache_size != 0:
        make_cached = lru_cache(maxsize=options.max_cache_size)
        _repr_html_ = make_cached(_repr_html_)
        if rich_enabled:
            _rich_repr_ = make_cached(_rich_repr_)

    for cls in (ee.Element, ee.ComputedObject):
        _attach_html_repr(cls, _ee_repr)
        if RICH_INSTALLED:
            _attach_rich_repr(cls, __rich_console__)
def reset():
    """Remove repr methods added by eerepr to EE objects and reset the cache.

    Only attributes defined directly on each tracked class are removed. An
    attribute the class merely inherits is never targeted, since `delattr` would
    raise for it. The rich console method is removed only when it is eerepr's
    own `__rich_console__`: `reprs_set` is shared with the HTML repr, so
    membership alone does not prove that eerepr installed the rich method.
    """
    for cls in reprs_set:
        own_attrs = vars(cls)
        if REPR_HTML in own_attrs:
            delattr(cls, REPR_HTML)
        if RICH_INSTALLED and own_attrs.get(RICH_CONSOLE) is __rich_console__:
            delattr(cls, RICH_CONSOLE)
    reprs_set.clear()

    if isinstance(_repr_html_, _lru_cache_wrapper):
        _repr_html_.cache_clear()
    if RICH_INSTALLED and isinstance(_rich_repr_, _lru_cache_wrapper):
        _rich_repr_.cache_clear()
# SPDX-License-Identifier: MIT
from __future__ import annotations
from datetime import datetime, timezone
from itertools import chain
from typing import Any, Hashable
from rich.tree import Tree
from rich.markup import escape
# strftime format used by the Date and DateRange labels.
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
# Keys shown first, in this order, when a dict is rendered as a tree. All
# remaining keys follow in sorted order (see `dict_to_rich_tree`).
PROPERTY_PRIORITY = [
    "type",
    "id",
    "version",
    "bands",
    "columns",
    "geometry",
    "properties",
]
def escape_object_rich(obj: Any) -> Any:
    """Return a copy of `obj` with rich markup escaped in every string.

    Strings nested anywhere inside lists and dicts (dict keys included) are
    escaped recursively. Any other value is returned unchanged.
    """
    if isinstance(obj, dict):
        return {escape_object_rich(k): escape_object_rich(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return list(map(escape_object_rich, obj))
    if isinstance(obj, str):
        return escape(obj)
    return obj
def get_rich_renderable(obj: Any, key: str | int | None = None):
    """Return a rich renderable for `obj`, optionally labelled with `key`.

    Dicts and lists become Trees. Any other value becomes a markup string
    "<key>: <value>" with the key highlighted, or just `str(obj)` without a key.
    """
    if not isinstance(obj, (dict, list)):
        # Leaf node.
        if key is None:
            return str(obj)
        return f"[deep_sky_blue1]{key}[/deep_sky_blue1]: {obj}"
    builder = dict_to_rich_tree if isinstance(obj, dict) else list_to_rich_tree
    return builder(obj, key)
def dict_to_rich_tree(obj: dict, key: str | int | None = None) -> Tree:
    """Build a rich Tree for an info dict.

    The header is the dict's `_build_label` label, prefixed with the highlighted
    `key` when one is given. There is one child per entry: keys listed in
    PROPERTY_PRIORITY come first (in that order), then the rest in sorted order.
    """
    label = _build_label(obj)
    header = f"[deep_sky_blue1]{key}[/deep_sky_blue1]: {label}" if key is not None else label
    tree = Tree(header)
    priority_keys = [k for k in PROPERTY_PRIORITY if k in obj]
    other_keys = sorted(k for k in obj if k not in PROPERTY_PRIORITY)
    for k in priority_keys + other_keys:
        child = get_rich_renderable(obj[k], key=k)
        if isinstance(child, Tree):
            # Attach nested trees directly. `Tree.add` would instead make the
            # subtree the *label* of a new childless node (an extra level).
            tree.children.append(child)
        else:
            tree.add(child)
    return tree
def list_to_rich_tree(obj: list, key: str | int | None = None) -> Tree:
    """Build a rich Tree for a list.

    The header reads "List (<n> elements)", prefixed with the highlighted `key`
    when given. There is one child per item, keyed by its index.
    """
    n = len(obj)
    label = f"List ({n} {'element' if n == 1 else 'elements'})"
    header = f"[deep_sky_blue1]{key}[/deep_sky_blue1]: {label}" if key is not None else label
    tree = Tree(header)
    for i, item in enumerate(obj):
        child = get_rich_renderable(item, key=i)
        if isinstance(child, Tree):
            # Attach nested trees directly. `Tree.add` would instead make the
            # subtree the *label* of a new childless node (an extra level).
            tree.children.append(child)
        else:
            tree.add(child)
    return tree
def convert_to_rich(obj: Any) -> Tree:
    """Convert an info object to a rich Tree.

    `get_rich_renderable` returns a plain string for scalars. That string is
    wrapped in a single-node Tree, so callers always receive a Tree.
    """
    result = get_rich_renderable(obj)
    return result if isinstance(result, Tree) else Tree(str(result))
### All _build_*_label functions from html.py needed below ###
def _build_image_label(obj: dict) -> str:
    """Label an Image info dict, e.g. "Image foo (3 bands)"; the id part is optional."""
    image_id = obj.get("id")
    id_part = "" if not image_id else f" {image_id}"
    band_count = len(obj.get("bands", []))
    band_word = "band" if band_count == 1 else "bands"
    return f"Image{id_part} ({band_count} {band_word})"
def _build_imagecollection_label(obj: dict) -> str:
    """Label an ImageCollection info dict, e.g. "ImageCollection foo (3 elements)".

    The element count is the length of the "features" list.
    """
    obj_id = obj.get("id")
    # Leading space only: the format below already puts a space before "(".
    id_label = f" {obj_id}" if obj_id else ""
    n = len(obj.get("features", []))
    noun = "element" if n == 1 else "elements"
    return f"ImageCollection{id_label} ({n} {noun})"
def _build_date_label(obj: dict) -> str:
    """Label a Date info dict. `value` is read as epoch milliseconds and shown in UTC."""
    millis = obj.get("value", 0)
    moment = datetime.fromtimestamp(millis / 1000, tz=timezone.utc)
    return "Date ({})".format(moment.strftime(DATE_FORMAT))
def _build_feature_label(obj: dict) -> str:
    """Label a Feature info dict, e.g. "Feature (Point, 2 properties)".

    The geometry type is omitted when "geometry" is missing or cannot be indexed
    by "type" (for example a null geometry).
    """
    count = len(obj.get("properties", []))
    noun = "property" if count == 1 else "properties"
    try:
        geometry_type = obj["geometry"]["type"]
    except (TypeError, KeyError):
        geometry_type = None
    prefix = "" if geometry_type is None else f"{geometry_type}, "
    return f"Feature ({prefix}{count} {noun})"
def _build_featurecollection_label(obj: dict) -> str:
    """Label a FeatureCollection info dict.

    Example output: "FeatureCollection foo (3 elements, 2 columns)". Element and
    column counts are the lengths of "features" and "columns".
    """
    obj_id = obj.get("id")
    # Leading space only: the format below already puts a space before "(".
    id_label = f" {obj_id}" if obj_id else ""
    ncols = len(obj.get("columns", []))
    nfeats = len(obj.get("features", []))
    col_noun = "column" if ncols == 1 else "columns"
    feat_noun = "element" if nfeats == 1 else "elements"
    return f"FeatureCollection{id_label} ({nfeats} {feat_noun}, {ncols} {col_noun})"
def _build_point_label(obj: dict) -> str:
    """Label a Point info dict with its first two coordinates, e.g. "Point (1.23, 4.56)".

    Each coordinate is formatted to two decimals. A missing or non-numeric value
    is shown as "NaN". Each axis is checked on its own value: y was previously
    checked via x, so a numeric x with a None y raised TypeError.
    """
    coords = list(obj.get("coordinates") or [])
    # Pad so empty or short coordinate lists still unpack; extra dims are ignored.
    x, y = (coords + [None, None])[:2]

    def fmt(value: object) -> str:
        return f"{value:.2f}" if isinstance(value, (int, float)) else "NaN"

    return f"Point ({fmt(x)}, {fmt(y)})"
def _build_polygon_label(obj: dict) -> str:
    """Label a Polygon info dict with the vertex count of its first ring.

    Missing or empty coordinates give "Polygon (0 vertices)" instead of raising
    IndexError.
    """
    rings = obj.get("coordinates") or []
    n = len(rings[0]) if rings else 0
    noun = "vertex" if n == 1 else "vertices"
    return f"Polygon ({n} {noun})"
def _build_multipolygon_label(obj: dict) -> str:
    """Label a MultiPolygon info dict with a vertex count.

    The count covers every ring of the first polygon in "coordinates". Missing
    or empty coordinates give "MultiPolygon (0 vertices)". Previously the `[]`
    default was indexed unconditionally and raised IndexError.
    """
    polygons = obj.get("coordinates") or []
    # NOTE(review): only the first polygon contributes to the count. Confirm this
    # matches the intended Code Editor label for multi-part geometries.
    first_polygon = polygons[0] if polygons else []
    n = len(list(chain.from_iterable(first_polygon)))
    noun = "vertex" if n == 1 else "vertices"
    return f"MultiPolygon ({n} {noun})"
def _build_multipoint_label(obj: dict) -> str:
    """Label a geometry as "<type> (<n> vertices)".

    Each element of "coordinates" counts as one vertex. This also works for
    LineString and LinearRing.
    """
    geometry_type = obj.get("type")
    vertex_count = len(obj.get("coordinates", []))
    return "{} ({} {})".format(
        geometry_type,
        vertex_count,
        "vertex" if vertex_count == 1 else "vertices",
    )
def _build_pixeltype_label(obj: dict) -> str:
    """Describe a PixelType info dict.

    "double" and "float" precisions are returned as-is. Otherwise, a
    "[min, max]" range matching a standard integer width is named (e.g.
    "unsigned int8"), and anything else is shown as "<precision> ∈ [min, max]".
    """
    precision = obj.get("precision", "")
    low = str(obj.get("min", ""))
    high = str(obj.get("max", ""))
    bounds = f"[{low}, {high}]"
    # Keyed by the stringified range, so float-valued int64 bounds match exactly.
    named_ranges = {
        "[-128, 127]": "signed int8",
        "[0, 255]": "unsigned int8",
        "[-32768, 32767]": "signed int16",
        "[0, 65535]": "unsigned int16",
        "[-2147483648, 2147483647]": "signed int32",
        "[0, 4294967295]": "unsigned int32",
        "[-9.223372036854776e+18, 9.223372036854776e+18]": "signed int64",
    }
    if precision in ("double", "float"):
        return precision
    return named_ranges.get(bounds, f"{precision} ∈ {bounds}")
def _build_band_label(obj: dict) -> str:
    """Summarize a band info dict as "<"id">, <pixel type>, <crs>, <W>x<H> px".

    Empty parts are dropped; the remaining parts are joined with ", ".
    """
    raw_id = obj.get("id", "")
    quoted_id = f'"{raw_id}"' if raw_id else raw_id
    pixel_type = _build_pixeltype_label(obj.get("data_type", {}))
    dims = obj.get("dimensions")
    size = f"{dims[0]}x{dims[1]} px" if dims else ""
    crs = obj.get("crs", "")
    parts = (quoted_id, pixel_type, crs, size)
    return ", ".join(part for part in parts if part)
def _build_daterange_label(obj: dict) -> str:
    """Label a DateRange info dict as "DateRange [start, end]".

    Both entries of "dates" are read as epoch milliseconds and shown in UTC.
    """
    start_ms, end_ms = obj.get("dates", [0, 0])
    start, end = (
        datetime.fromtimestamp(ms / 1000, tz=timezone.utc).strftime(DATE_FORMAT)
        for ms in (start_ms, end_ms)
    )
    return f"DateRange [{start}, {end}]"
def _build_object_label(obj: dict) -> str:
    """Label a plain, untyped dict as "Object (<n> properties)"."""
    size = len(obj)
    return f"Object ({size} {'property' if size == 1 else 'properties'})"
def _build_typed_label(obj: dict) -> str:
    """Build a label for an object with an unrecognized type: "<type>" or "<type> <id>".

    The id gets a single leading space and no trailing one, so the label has no
    trailing whitespace (matching `_build_image_label`).
    """
    obj_type = obj.get("type")
    obj_id = obj.get("id", "")
    id_label = f" {obj_id}" if obj_id else ""
    return f"{obj_type}{id_label}"
def _build_label(obj: dict) -> str:
    """Take an info dictionary from Earth Engine and return a header label.

    These labels attempt to be consistent with outputs from the Code Editor.
    Dispatch is on the "type" key:
    - known types use a dedicated labeler;
    - unknown types get `_build_typed_label`;
    - untyped dicts are labelled as bands when they have both "data_type" and
      "id", and as generic objects otherwise.
    """
    labelers = {
        "Image": _build_image_label,
        "ImageCollection": _build_imagecollection_label,
        "Date": _build_date_label,
        "Feature": _build_feature_label,
        "FeatureCollection": _build_featurecollection_label,
        "Point": _build_point_label,
        "MultiPoint": _build_multipoint_label,
        "LineString": _build_multipoint_label,
        "LinearRing": _build_multipoint_label,
        "Polygon": _build_polygon_label,
        "MultiPolygon": _build_multipolygon_label,
        "PixelType": _build_pixeltype_label,
        "DateRange": _build_daterange_label,
    }
    obj_type = obj.get("type", "")
    if not obj_type:
        if "data_type" in obj and "id" in obj:
            return _build_band_label(obj)
        return _build_object_label(obj)
    # Look the labeler up separately from calling it. Wrapping the call in
    # `try/except KeyError` would silently turn a KeyError raised *inside* a
    # labeler into a misleading generic typed label.
    labeler = labelers.get(obj_type, _build_typed_label)
    return labeler(obj)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment