# Source: GitHub gist steverice/7cd34fe65d6bb50587a0a4360973e249
# Author: @steverice -- created November 18, 2025 22:04
"""
Demonstration of the TypeSpecModel protocol
When generating Python code for TypeSpec models, we allow for the choice of model implementation.
The most compatible implementation is the dataclass implementation, which requires nothing outside of the stdlib.
Other implementations are possible, such as an `attrs` implementation or a `pydantic` implementation.
The only requirement is that the model implements the TypeSpecModel protocol.
This allows surrounding code to deal with objects that implement the protocol, regardless of the underlying implementation.
"""
from __future__ import annotations
from typing import Any, Callable, Generic, Literal, Mapping, Protocol, Self, TypeVar, overload
from contextlib import suppress
from dataclasses import KW_ONLY, dataclass, field, fields, is_dataclass
T = TypeVar("T")
class TypeSpecModel(Protocol):
    """
    Structural interface shared by every generated TypeSpec model.

    On top of what a plain dataclass offers, a model must provide:

    - ``model_dump``: dump the model as a dict
    - ``model_copy``: copy the model
    - ``model_fields_set``: report the set of fields that have been set

    All three exist to support TypeSpec's notion of property "requiredness",
    which is separate from optionality/nullability. A model has to remember
    which fields were explicitly set (as opposed to merely holding their
    default value), and its copy and dump operations have to honour that
    bookkeeping.

    Anything beyond this (validation, serialization, ...) is left to the
    concrete model implementation and is out of scope for this protocol.
    """

    def model_dump(self, *, exclude_unset: bool = True, exclude_none: bool = False) -> dict:
        """Dump the model as a dict, optionally skipping unset and/or ``None`` fields."""
        ...

    def model_copy(self: T, *, update: Mapping[str, Any] | None = None) -> T:
        """Copy the model, optionally overriding fields with the entries of ``update``."""
        ...

    @property
    def model_fields_set(self) -> set[str]:
        """Names of the fields that have been set."""
        ...
class Omittable(Generic[T]):
    """
    Descriptor providing TypeSpec's "not required" functionality.

    Declaring a class attribute as ``Omittable(...)`` marks that field as one
    that may be left out. This is independent of whether the field's value may
    be ``None``: the descriptor records, per instance, whether a value was
    explicitly assigned, and falls back to the configured default otherwise.

    Per-instance state is stored in the instance ``__dict__`` under
    ``__<name>_val__`` (the value) and ``__<name>_is_set__`` (the flag), so
    owning classes must give their instances a ``__dict__``.

    Args:
        default: Value returned while the field is unset.
        default_factory: Zero-argument callable producing that value; it is
            invoked on every read of an unset field.

    Raises:
        TypeError: If both ``default`` and ``default_factory`` are given.
    """

    class _OMITTED:
        """
        Internal sentinel type used to track omitted values.

        Not meant to be used externally.
        """

        def __repr__(self):
            return "<omitted>"

    OMITTED = _OMITTED()

    def __init__(self, *, default: T | _OMITTED = OMITTED, default_factory: Callable[[], T] | None = None):
        if default is not self.OMITTED and default_factory is not None:
            raise TypeError("Cannot specify both default and default_factory")
        self._default = default
        self._default_factory = default_factory
        # Filled in by __set_name__ once the descriptor is bound to a class attribute.
        self.name: str | None = None
        self._val_key = self._set_key = ""

    def __set_name__(self, owner: type[object], name: str):
        """Record the attribute name and flag ``owner`` as using Omittable fields."""
        self.name = name
        self._val_key = f"__{name}_val__"
        self._set_key = f"__{name}_is_set__"
        setattr(owner, "__has_omittable__", True)

    @overload
    def get_default_value(self, strict: Literal[True]) -> T: ...
    @overload
    def get_default_value(self, strict: Literal[False] = False) -> T | _OMITTED: ...
    def get_default_value(self, strict: bool = False) -> T | _OMITTED:
        """
        Return the value an unset field reads as.

        An explicit ``default`` wins, then ``default_factory()``. With neither
        configured, the ``OMITTED`` sentinel is returned unless ``strict`` is
        true, in which case ``AttributeError`` is raised instead.
        """
        if self._default is not self.OMITTED:
            return self._default
        if self._default_factory is not None:
            return self._default_factory()
        if strict:
            raise AttributeError(f"Omittable {self.name!r} has no default value")
        return self.OMITTED

    @property
    def default(self) -> T | _OMITTED:
        """Non-strict default value (may be the ``OMITTED`` sentinel)."""
        return self.get_default_value()

    @overload
    def __get__(self, instance: None, owner: type[object]) -> Omittable[T]: ...
    @overload
    def __get__(self, instance: object, owner: type[object]) -> T | _OMITTED: ...
    def __get__(self, instance: object | None, owner: type[object]) -> T | Omittable[T] | _OMITTED:
        """Return the descriptor on class access, else the stored value or the default."""
        if instance is None:
            return self
        if self.was_set(instance):
            return getattr(instance, self._val_key)
        return self.default

    def __set__(self, instance: object, value: T | Omittable[T] | _OMITTED) -> None:
        """
        Store ``value`` and mark the field as set, or reset it to "unset".

        Two values reset the field instead of being stored:

        - this descriptor itself, which is what arrives when the descriptor is
          used as the field's declared default and no argument is supplied;
        - the ``OMITTED`` sentinel, which ``__get__`` hands out for unset
          fields without a default. Treating it as "unset" keeps round-trips
          such as ``b.x = a.x`` from turning an omitted field into one that is
          explicitly set to ``<omitted>``.

        Raises:
            TypeError: If ``value`` is a different ``Omittable`` descriptor.
        """
        if value is self or value is self.OMITTED:
            self._mark_unset(instance)
            return
        if isinstance(value, Omittable):
            raise TypeError(f"Cannot assign descriptor {value!r} to field {self.name!r}")
        setattr(instance, self._val_key, value)
        setattr(instance, self._set_key, True)

    def _mark_unset(self, instance: object) -> None:
        """Clear the set flag and drop any previously stored value."""
        setattr(instance, self._set_key, False)
        if hasattr(instance, self._val_key):
            delattr(instance, self._val_key)

    # Provenance helpers
    def was_set(self, instance) -> bool:
        """Return True if this field was explicitly assigned on ``instance``."""
        return bool(instance.__dict__.get(self._set_key, False))
class TypeSpecDataclass:
"""
A dataclass implementation of the TypeSpecModel protocol
This implementation uses Omittable descriptors to track which fields have been set.
"""
def __new__(cls, *args, **kwargs):
if not is_dataclass(cls):
raise TypeError(f"TypeSpecDataclass should be used with dataclasses, not {cls!r}")
if omittable := getattr(cls, "__has_omittable__", False) and getattr(
object.__getattribute__(cls, "__dataclass_params__"), "frozen", False
):
raise TypeError("Dataclasses using Omittable cannot be frozen")
new = super().__new__(cls)
with suppress(AttributeError):
if omittable and object.__getattribute__(new, "__slots__"):
raise TypeError("Dataclasses using Omittable cannot use slots")
return new
@property
def model_fields_set(self) -> set[str]:
assert is_dataclass(self)
was_set = {
name for name, attr in self.__class__.__dict__.items() if isinstance(attr, Omittable) and attr.was_set(self)
}
required = {f.name for f in fields(self) if not isinstance(getattr(self.__class__, f.name, None), Omittable)}
return was_set | required
def _was_set(self, name) -> bool:
attr = getattr(self.__class__, name, None)
return not isinstance(attr, Omittable) or attr.was_set(self)
def model_dump(self, *, exclude_unset: bool = True, exclude_none: bool = False) -> dict[str, Any]:
assert is_dataclass(self)
return {
f.name: getattr(self, f.name)
for f in fields(self)
if (not exclude_unset or self._was_set(f.name)) and (not exclude_none or getattr(self, f.name) is not None)
}
def model_copy(self, *, update: Mapping[str, Any] | None = None) -> Self:
"""
Create a copy of a dataclass instance while preserving Omittable 'unset' fields.
Any keys in `update` are applied and considered 'set'.
"""
kwargs = self.model_dump(exclude_unset=True)
if update:
kwargs.update(update)
return type(self)(**kwargs)
## Example generated code
@dataclass
class PersonWithNullableAddress(TypeSpecDataclass):
    """
    Example model with a required, nullable property.

    Generated from::

        model Person {
            age: int32;
            address: string | null;
        }

    Neither field has a default, so both must be passed to the constructor;
    ``address`` may be ``None``.
    """
    age: int
    address: str | None  # required, nullable, no default
@dataclass
class PersonWithOptionalAddress(TypeSpecDataclass):
    """
    Example model with an omittable property that has no default.

    Generated from::

        model Person {
            age: int32;
            address?: string;
        }

    The ``Omittable`` descriptor is the field's declared default, so
    ``address`` may be left out of the constructor call.
    """
    age: int
    _: KW_ONLY  # all Omittable fields should be keyword-only
    # omittable, no default.
    # NOTE(review): the annotation admits None although the TypeSpec snippet
    # above declares a plain (non-nullable) `string` -- confirm which is intended.
    address: Omittable[str | None] = Omittable[str | None]()
@dataclass
class PersonWithOptionalDefaultAddress(TypeSpecDataclass):
    """
    Example model with an omittable property that carries a default value.

    Generated from::

        model Person {
            age: int32;
            address?: string = "N/A";
        }

    When ``address`` is left out it reads back as ``"N/A"`` while still being
    tracked as unset.
    """
    age: int
    _: KW_ONLY  # all Omittable fields should be keyword-only
    address: Omittable[str] = Omittable[str](default="N/A")  # omittable, non-nullable, default "N/A"
@dataclass
class PersonWithFieldParams(TypeSpecDataclass):
    """
    Example model combining an Omittable property with ``dataclasses.field`` options.

    Generated from::

        model Person {
            age: int32;
            address?: string;
        }

    The descriptor is supplied as the ``default`` of ``field(...)`` so that
    extra field parameters (here ``repr=False``) can be passed alongside it.
    """
    age: int
    _: KW_ONLY  # all Omittable fields should be keyword-only
    # NOTE(review): the annotation admits None although the TypeSpec snippet
    # above declares a plain (non-nullable) `string` -- confirm which is intended.
    address: Omittable[str | None] = field(default=Omittable[str | None](), repr=False)
def examine_tsp_model(model: TypeSpecModel):
    """
    Print a short report exercising every member of the TypeSpecModel protocol.

    The model itself is printed first, followed (for dataclass instances only)
    by its dataclass fields, then the set fields, three dump variants and a
    copy. A blank separator is printed at the end.
    """
    print(model)
    if is_dataclass(model):
        print("fields:", fields(model))

    # Each value is computed lazily, right before its line is printed, so the
    # call/print interleaving matches a plain sequence of print statements.
    reports = (
        ("set fields:", lambda: model.model_fields_set),
        ("dump:", lambda: model.model_dump()),
        ("dump without None:", lambda: model.model_dump(exclude_none=True)),
        ("dump with unset:", lambda: model.model_dump(exclude_unset=False)),
        ("copy:", lambda: model.model_copy()),
    )
    for label, compute in reports:
        print(label, compute())
    print("\n")
# Demo run: each entry pairs a model class with the keyword arguments passed
# in addition to age=11. Instances are built one at a time inside the loop so
# construction and reporting stay interleaved in the listed order.
_EXAMPLES = (
    (PersonWithNullableAddress, {"address": None}),
    (PersonWithOptionalAddress, {}),
    (PersonWithOptionalAddress, {"address": None}),
    (PersonWithOptionalDefaultAddress, {}),
    (PersonWithFieldParams, {}),
)
for _model_cls, _extra_kwargs in _EXAMPLES:
    examine_tsp_model(_model_cls(11, **_extra_kwargs))
# End of gist (GitHub page footer omitted).