I love types… maybe too much. It’s definitely my favorite thing about Julia.
I recently made the type checker much stricter on my Python projects to get more clarity about my code. But that means repeating this over and over:
```python
y = f(x)
if not isinstance(y, ExpectedType):
    raise TypeError("We shouldn't be here.")
```
Or sometimes I want to coerce to a type, so it looks like this:
```python
y = f(x)
if not isinstance(y, ExpectedType):
    y = some_fix_func(y)
if not isinstance(y, ExpectedType):
    raise TypeError("We shouldn't be here.")
```
I set out to write a function that would capture this behavior while staying transparent to the type checker, and it is… it's not simple. With no type-level computation, no `issubtype` on generics at type-check time, and `UnionType` NOT BEING A TYPE (grrr…), there were a bunch of hurdles.
But now, if you write any Python and would like to insert nice type checks that work at check time and run time, have I got a hacky class for you!
```python
from __future__ import annotations

from types import GenericAlias
from typing import Callable, Generic, TypeVar, overload

T = TypeVar("T")
X = TypeVar("X")
T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")


class _Typed(Generic[T]):
    """Construct a callable type checking function.

    Instead of following every call with an `isinstance` check and a `raise`,
    write `typed[T](x)`: it returns `x` if it is a `T` and raises `TypeError`
    otherwise.
    """

    def __init__(self, ts: tuple = ()) -> None:
        # Runtime classes that `x` is checked against.
        self.ts = ts

    @staticmethod
    def _strip_generics(t: GenericAlias | type[T1]) -> type[T1]:
        """Get origin type of generic (e.g. `list[int]` -> `list`).

        `isinstance` rejects parameterized generics, so only the origin is checked.
        """
        if isinstance(t, GenericAlias):
            return t.__origin__  # pyright: ignore[reportReturnType]
        return t

    @overload
    def __getitem__(self, t: type[T1]) -> _Typed[T1]: ...
    @overload
    def __getitem__(self, t: tuple[type[T1], type[T2]]) -> _Typed[T1 | T2]: ...
    @overload
    def __getitem__(
        self,
        t: tuple[type[T1], type[T2], type[T3]],
    ) -> _Typed[T1 | T2 | T3]: ...
    def __getitem__(self, t):
        # `typed[A]` passes a single type; `typed[A, B]` passes a tuple.
        if not isinstance(t, tuple):
            t = (t,)
        ts = tuple(self._strip_generics(_t) for _t in t)
        return _Typed(ts)

    @overload
    def __call__(
        self,
        x: T | X,
        /,
        *,
        recovery: Callable[[X], T] | None = None,
        err_msg: str = "",
    ) -> T: ...
    @overload
    def __call__(
        self,
        x: X,
        /,
        *,
        recovery: Callable[[X], T] | None = None,
        err_msg: str = "",
    ) -> T: ...
    def __call__(
        self,
        x,
        /,
        *,
        recovery=None,
        err_msg="",
    ) -> T:
        """Enforce type constraints on `x` with parameters of `typed`.

        `typed[T](x)` will ensure that x is a T.

        Kwargs:
        - recovery (Optional[callable]): a function to call on `x` if it is not
          the correct type; its result is checked again.
        - err_msg (Optional[str]): a message to prepend to the type error message.
        """
        for t in self.ts:
            if isinstance(x, t):
                return x
        if recovery is not None:
            # Re-check the recovered value without `recovery`, so a bad fix
            # raises instead of recursing forever.
            return self(
                recovery(x),
                err_msg=err_msg,
            )
        if len(err_msg) > 0:
            err_msg += "\n"
        # `str.join` needs strings, so join the type names, not the types.
        type_str = " | ".join(t.__name__ for t in self.ts)
        err_msg += f"Expected type {type_str}, but got type {type(x).__name__}"
        raise TypeError(err_msg)


typed = _Typed()
```
With this class, we can do:
```python
y = typed[ExpectedType](f(x), recovery=some_fix_func)
```
And now the type checker is convinced that `y` is an `ExpectedType`.