Created
November 18, 2025 22:04
-
-
Save steverice/7cd34fe65d6bb50587a0a4360973e249 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Demonstration of the TypeSpecModel protocol | |
| When generating Python code for TypeSpec models, we allow for the choice of model implementation. | |
| The most compatible implementation is the dataclass implementation, which requires nothing outside of the stdlib. | |
| Other implementations are possible, such as an `attrs` implementation or a `pydantic` implementation. | |
| The only requirement is that the model implements the TypeSpecModel protocol. | |
| This allows surrounding code to deal with objects that implement the protocol, regardless of the underlying implementation. | |
| """ | |
| from __future__ import annotations | |
| from typing import Any, Callable, Generic, Literal, Mapping, Protocol, Self, TypeVar, overload | |
| from contextlib import suppress | |
| from dataclasses import KW_ONLY, dataclass, field, fields, is_dataclass | |
# Type variable parameterizing Omittable's value type and the self-type of TypeSpecModel.model_copy.
T = TypeVar("T")
class TypeSpecModel(Protocol):
    """
    Structural interface that every generated TypeSpec model implementation satisfies.

    On top of what a plain dataclass offers, a conforming model provides:

    - ``model_fields_set``: the names of the fields that were explicitly assigned,
    - ``model_dump``: a dict rendering of the model,
    - ``model_copy``: a copy of the model, optionally with overrides.

    Together these let a model express TypeSpec property "requiredness" separately
    from nullability. The model has to remember which fields were explicitly provided
    rather than left at their default, and its copy/dump operations must honour that
    record.

    Anything further (validation, serialization, ...) belongs to the concrete
    implementation and is outside the scope of this protocol.
    """

    @property
    def model_fields_set(self) -> set[str]:
        """Names of the fields that count as explicitly set on this instance."""
        ...

    def model_dump(self, *, exclude_unset: bool = True, exclude_none: bool = False) -> dict:
        """Render the model as a dict, optionally dropping unset and/or None-valued fields."""
        ...

    def model_copy(self: T, *, update: Mapping[str, Any] | None = None) -> T:
        """Return a copy of the model, with any ``update`` values applied."""
        ...
class Omittable(Generic[T]):
    """
    Descriptor implementing TypeSpec's "not required" (optional-property) semantics.

    Declaring a dataclass field with an ``Omittable`` records, per instance, whether
    that field was explicitly assigned. This is independent of nullability: whether
    ``None`` is an acceptable value is expressed by the type parameter, not by this
    descriptor.

    A ``default`` or a ``default_factory`` (but not both) may be supplied. Reads of a
    field that was never assigned return that default. Without either, such reads
    return the ``OMITTED`` sentinel.
    """

    class _OMITTED:
        """Private sentinel type standing in for "no value"; not meant for external use."""

        def __repr__(self):
            return "<omitted>"

    OMITTED = _OMITTED()

    def __init__(self, *, default: T | _OMITTED = OMITTED, default_factory: Callable[[], T] | None = None):
        """
        Args:
            default: Value returned when reading a field that was never assigned.
            default_factory: Zero-argument callable invoked on every read of a field
                that was never assigned.

        Raises:
            TypeError: if both ``default`` and ``default_factory`` are given.
        """
        has_default = default is not self.OMITTED
        has_factory = default_factory is not None
        if has_default and has_factory:
            raise TypeError("Cannot specify both default and default_factory")
        self._default = default
        self._default_factory = default_factory
        self.name: str | None = None
        # Per-instance storage keys; they are filled in once __set_name__ learns the attribute name.
        self._val_key = ""
        self._set_key = ""

    def __set_name__(self, owner: type[object], name):
        """Bind the descriptor to attribute ``name`` of ``owner`` and flag the owner as using Omittable."""
        self.name = name
        self._val_key, self._set_key = f"__{name}_val__", f"__{name}_is_set__"
        # TypeSpecDataclass.__new__ reads this flag to reject frozen/slotted configurations.
        setattr(owner, "__has_omittable__", True)

    @overload
    def get_default_value(self, strict: Literal[True]) -> T: ...

    @overload
    def get_default_value(self, strict: Literal[False] = False) -> T | _OMITTED: ...

    def get_default_value(self, strict: bool = False) -> T | _OMITTED:
        """
        Return the configured default, calling ``default_factory`` if that was given instead.

        With neither configured, return ``OMITTED``, or raise AttributeError when ``strict`` is true.
        """
        if self._default is self.OMITTED:
            if self._default_factory is None:
                if not strict:
                    return self.OMITTED
                raise AttributeError(f"Omittable {self.name!r} has no default value")
            return self._default_factory()
        return self._default

    @property
    def default(self) -> T | _OMITTED:
        """Non-strict shorthand for ``get_default_value()``."""
        return self.get_default_value(strict=False)

    @overload
    def __get__(self, instance: None, owner: type[object]) -> Omittable[T]: ...

    @overload
    def __get__(self, instance: object, owner: type[object]) -> T | _OMITTED: ...

    def __get__(self, instance: object | None, owner: type[object]) -> T | Omittable[T] | _OMITTED:
        """Class access yields the descriptor; instance access yields the stored value or the default."""
        if instance is None:
            return self
        return getattr(instance, self._val_key) if self.was_set(instance) else self.default

    def __set__(self, instance: object, value: T | Omittable[T]) -> None:
        """
        Store ``value`` and mark the field as set.

        Assigning this very descriptor resets the field to "unset". Assigning any
        other Omittable raises TypeError.
        """
        if not isinstance(value, Omittable):
            setattr(instance, self._val_key, value)
            setattr(instance, self._set_key, True)
            return
        if value is not self:
            raise TypeError(f"Cannot assign descriptor {value!r} to field {self.name!r}")
        # The descriptor doubles as the dataclass field default, so an __init__ that
        # received no argument for this field assigns the descriptor itself.
        setattr(instance, self._set_key, False)
        if hasattr(instance, self._val_key):
            delattr(instance, self._val_key)

    # Provenance helpers
    def was_set(self, instance) -> bool:
        """Return True if this field was explicitly assigned on ``instance``."""
        flag = instance.__dict__.get(self._set_key, False)
        return bool(flag)
| class TypeSpecDataclass: | |
| """ | |
| A dataclass implementation of the TypeSpecModel protocol | |
| This implementation uses Omittable descriptors to track which fields have been set. | |
| """ | |
| def __new__(cls, *args, **kwargs): | |
| if not is_dataclass(cls): | |
| raise TypeError(f"TypeSpecDataclass should be used with dataclasses, not {cls!r}") | |
| if omittable := getattr(cls, "__has_omittable__", False) and getattr( | |
| object.__getattribute__(cls, "__dataclass_params__"), "frozen", False | |
| ): | |
| raise TypeError("Dataclasses using Omittable cannot be frozen") | |
| new = super().__new__(cls) | |
| with suppress(AttributeError): | |
| if omittable and object.__getattribute__(new, "__slots__"): | |
| raise TypeError("Dataclasses using Omittable cannot use slots") | |
| return new | |
| @property | |
| def model_fields_set(self) -> set[str]: | |
| assert is_dataclass(self) | |
| was_set = { | |
| name for name, attr in self.__class__.__dict__.items() if isinstance(attr, Omittable) and attr.was_set(self) | |
| } | |
| required = {f.name for f in fields(self) if not isinstance(getattr(self.__class__, f.name, None), Omittable)} | |
| return was_set | required | |
| def _was_set(self, name) -> bool: | |
| attr = getattr(self.__class__, name, None) | |
| return not isinstance(attr, Omittable) or attr.was_set(self) | |
| def model_dump(self, *, exclude_unset: bool = True, exclude_none: bool = False) -> dict[str, Any]: | |
| assert is_dataclass(self) | |
| return { | |
| f.name: getattr(self, f.name) | |
| for f in fields(self) | |
| if (not exclude_unset or self._was_set(f.name)) and (not exclude_none or getattr(self, f.name) is not None) | |
| } | |
| def model_copy(self, *, update: Mapping[str, Any] | None = None) -> Self: | |
| """ | |
| Create a copy of a dataclass instance while preserving Omittable 'unset' fields. | |
| Any keys in `update` are applied and considered 'set'. | |
| """ | |
| kwargs = self.model_dump(exclude_unset=True) | |
| if update: | |
| kwargs.update(update) | |
| return type(self)(**kwargs) | |
## Example generated code
# Each class below is a sample of what a generator might emit for the TypeSpec model
# quoted in its docstring.
@dataclass
class PersonWithNullableAddress(TypeSpecDataclass):
    """
    Example model with a required, nullable ``address``.

    Generated from::

        model Person {
          age: int32;
          address: string | null;
        }

    Neither field uses Omittable, so both must be passed to the constructor and both
    always appear in ``model_fields_set``.
    """
    age: int
    address: str | None  # required, nullable, no default
@dataclass
class PersonWithOptionalAddress(TypeSpecDataclass):
    """
    Example model with an optional (omittable) ``address`` that has no default.

    Generated from::

        model Person {
          age: int32;
          address?: string;
        }

    ``address`` is tracked by an Omittable descriptor. If it is not passed, it is left
    out of ``model_fields_set`` and of ``model_dump()``, and reading it yields
    ``Omittable.OMITTED``.
    """
    age: int
    _: KW_ONLY  # all Omittable fields should be keyword-only
    # NOTE(review): the TypeSpec source (`address?: string`) is non-nullable, but this
    # annotation/default admit None (and the demo passes address=None) — confirm which is intended.
    address: Omittable[str | None] = Omittable[str | None]()  # omittable, no default; annotation allows None
@dataclass
class PersonWithOptionalDefaultAddress(TypeSpecDataclass):
    """
    Example model with an optional (omittable) ``address`` that has a default.

    Generated from::

        model Person {
          age: int32;
          address?: string = "N/A";
        }

    If ``address`` is not passed, reading it returns "N/A". It still counts as unset,
    so it is absent from ``model_fields_set`` and from ``model_dump()``, which excludes
    unset fields by default.
    """
    age: int
    _: KW_ONLY  # all Omittable fields should be keyword-only
    address: Omittable[str] = Omittable[str](default="N/A")  # omittable, non-nullable, default "N/A"
@dataclass
class PersonWithFieldParams(TypeSpecDataclass):
    """
    Example of wrapping an Omittable default in ``dataclasses.field(...)``.

    Generated from::

        model Person {
          age: int32;
          address?: string;
        }

    Wrapping the descriptor in ``field()`` lets other field options be set alongside
    it; here ``repr=False`` hides ``address`` from the dataclass repr.
    """
    age: int
    _: KW_ONLY  # all Omittable fields should be keyword-only
    # NOTE(review): relies on dataclasses forwarding __set_name__ from the Field to its
    # default so the Omittable learns its attribute name; the annotation also admits None
    # although the TypeSpec property is non-nullable — confirm both are intended.
    address: Omittable[str | None] = field(default=Omittable[str | None](), repr=False)
def examine_tsp_model(model: TypeSpecModel):
    """
    Print a walkthrough of how ``model`` behaves under the TypeSpecModel protocol.

    The output shows, in order:

    - the model's repr,
    - its dataclass fields (only when it is a dataclass),
    - the explicitly-set fields,
    - three ``model_dump`` variants,
    - a ``model_copy``,

    followed by a blank separator.
    """
    print(model)
    if is_dataclass(model):
        print('fields:', fields(model))
    print('set fields:', model.model_fields_set)
    dump_variants = (
        ('dump:', {}),
        ('dump without None:', {'exclude_none': True}),
        ('dump with unset:', {'exclude_unset': False}),
    )
    for label, options in dump_variants:
        print(label, model.model_dump(**options))
    print('copy:', model.model_copy())
    print("\n")
# Walk through each example model: required + nullable, omitted vs. explicit None,
# a defaulted omittable field, and an Omittable wrapped in field().
for _example in (
    PersonWithNullableAddress(11, address=None),
    PersonWithOptionalAddress(11),
    PersonWithOptionalAddress(11, address=None),
    PersonWithOptionalDefaultAddress(11),
    PersonWithFieldParams(11),
):
    examine_tsp_model(_example)
# Drop the loop variable so the module namespace matches the straight-line version.
del _example
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment