Initial commit

vortex 2024-02-26 00:17:10 +05:30
parent bc49523c99
commit 44d606cce7
1929 changed files with 612166 additions and 0 deletions

7
dist/ba_data/python/efro/__init__.py vendored Normal file

@@ -0,0 +1,7 @@
# Released under the MIT License. See LICENSE for details.
#
"""Common bits of functionality shared between all efro projects.
Things in here should be hardened, highly type-safe, and well-covered by unit
tests since they are widely used in live client and server code.
"""

Binary files not shown.

366
dist/ba_data/python/efro/call.py vendored Normal file

@@ -0,0 +1,366 @@
# Released under the MIT License. See LICENSE for details.
#
"""Call related functionality shared between all efro components."""
from __future__ import annotations
from typing import TYPE_CHECKING, TypeVar, Generic, Callable, cast
import functools
if TYPE_CHECKING:
from typing import Any, overload
CT = TypeVar('CT', bound=Callable)
class _CallbackCall(Generic[CT]):
"""Descriptor for exposing a call with a type defined by a TypeVar."""
def __get__(self, obj: Any, type_in: Any = None) -> CT:
return cast(CT, None)
class CallbackSet(Generic[CT]):
"""Wrangles callbacks for a particular event in a type-safe manner."""
# In the type-checker's eyes, our 'run' attr is a CallbackCall which
# returns a callable with the type we were created with. This lets us
# type-check our run calls. (Is there another way to expose a function
# with a signature defined by a generic?..)
# At runtime, run() simply passes its args verbatim to its registered
# callbacks; no types are checked.
if TYPE_CHECKING:
run: _CallbackCall[CT] = _CallbackCall()
else:
def run(self, *args, **keywds):
"""Run all callbacks."""
print('HELLO FROM RUN', *args, **keywds)
def __init__(self) -> None:
print('CallbackSet()')
def __del__(self) -> None:
print('~CallbackSet()')
def add(self, call: CT) -> None:
"""Add a callback to be run."""
print('Would add call', call)
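# Example usage sketch for CallbackSet (illustrative; assumes the class
# above and a made-up 'on_connect' event):
#
#   on_connect: CallbackSet[Callable[[str], None]] = CallbackSet()
#
#   def _greet(name: str) -> None:
#       print(f'hello {name}')
#
#   on_connect.add(_greet)
#   on_connect.run('flora')  # The type-checker verifies the str arg here.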
# Define Call() which can be used in type-checking call-wrappers that behave
# similarly to functools.partial (in that they take a callable and some
# positional arguments to be passed to it).
# In type-checking land, we define several different _CallXArg classes
# corresponding to different argument counts and define Call() as an
# overloaded function which returns one of them based on how many args are
# passed.
# To use this, simply assign your call type to this Call for type checking:
# Example:
# class _MyCallWrapper:
# <runtime class defined here>
# if TYPE_CHECKING:
# MyCallWrapper = efro.call.Call
# else:
# MyCallWrapper = _MyCallWrapper
# Note that this setup currently only works with positional arguments; if you
# would like to pass args via keyword you can wrap a lambda or local function
# which takes keyword args and converts to a call containing keywords.
if TYPE_CHECKING:
In1T = TypeVar('In1T')
In2T = TypeVar('In2T')
In3T = TypeVar('In3T')
In4T = TypeVar('In4T')
In5T = TypeVar('In5T')
In6T = TypeVar('In6T')
In7T = TypeVar('In7T')
OutT = TypeVar('OutT')
class _CallNoArgs(Generic[OutT]):
"""Single argument variant of call wrapper."""
def __init__(self, _call: Callable[[], OutT]):
...
def __call__(self) -> OutT:
...
class _Call1Arg(Generic[In1T, OutT]):
"""Single argument variant of call wrapper."""
def __init__(self, _call: Callable[[In1T], OutT]):
...
def __call__(self, _arg1: In1T) -> OutT:
...
class _Call2Args(Generic[In1T, In2T, OutT]):
"""Two argument variant of call wrapper"""
def __init__(self, _call: Callable[[In1T, In2T], OutT]):
...
def __call__(self, _arg1: In1T, _arg2: In2T) -> OutT:
...
class _Call3Args(Generic[In1T, In2T, In3T, OutT]):
"""Three argument variant of call wrapper"""
def __init__(self, _call: Callable[[In1T, In2T, In3T], OutT]):
...
def __call__(self, _arg1: In1T, _arg2: In2T, _arg3: In3T) -> OutT:
...
class _Call4Args(Generic[In1T, In2T, In3T, In4T, OutT]):
"""Four argument variant of call wrapper"""
def __init__(self, _call: Callable[[In1T, In2T, In3T, In4T], OutT]):
...
def __call__(
self, _arg1: In1T, _arg2: In2T, _arg3: In3T, _arg4: In4T
) -> OutT:
...
class _Call5Args(Generic[In1T, In2T, In3T, In4T, In5T, OutT]):
"""Five argument variant of call wrapper"""
def __init__(
self, _call: Callable[[In1T, In2T, In3T, In4T, In5T], OutT]
):
...
def __call__(
self,
_arg1: In1T,
_arg2: In2T,
_arg3: In3T,
_arg4: In4T,
_arg5: In5T,
) -> OutT:
...
class _Call6Args(Generic[In1T, In2T, In3T, In4T, In5T, In6T, OutT]):
"""Six argument variant of call wrapper"""
def __init__(
self, _call: Callable[[In1T, In2T, In3T, In4T, In5T, In6T], OutT]
):
...
def __call__(
self,
_arg1: In1T,
_arg2: In2T,
_arg3: In3T,
_arg4: In4T,
_arg5: In5T,
_arg6: In6T,
) -> OutT:
...
class _Call7Args(Generic[In1T, In2T, In3T, In4T, In5T, In6T, In7T, OutT]):
"""Seven argument variant of call wrapper"""
def __init__(
self,
_call: Callable[[In1T, In2T, In3T, In4T, In5T, In6T, In7T], OutT],
):
...
def __call__(
self,
_arg1: In1T,
_arg2: In2T,
_arg3: In3T,
_arg4: In4T,
_arg5: In5T,
_arg6: In6T,
_arg7: In7T,
) -> OutT:
...
# No arg call; no args bundled.
# noinspection PyPep8Naming
@overload
def Call(call: Callable[[], OutT]) -> _CallNoArgs[OutT]:
...
# 1 arg call; 1 arg bundled.
# noinspection PyPep8Naming
@overload
def Call(call: Callable[[In1T], OutT], arg1: In1T) -> _CallNoArgs[OutT]:
...
# 1 arg call; no args bundled.
# noinspection PyPep8Naming
@overload
def Call(call: Callable[[In1T], OutT]) -> _Call1Arg[In1T, OutT]:
...
# 2 arg call; 2 args bundled.
# noinspection PyPep8Naming
@overload
def Call(
call: Callable[[In1T, In2T], OutT], arg1: In1T, arg2: In2T
) -> _CallNoArgs[OutT]:
...
# 2 arg call; 1 arg bundled.
# noinspection PyPep8Naming
@overload
def Call(
call: Callable[[In1T, In2T], OutT], arg1: In1T
) -> _Call1Arg[In2T, OutT]:
...
# 2 arg call; no args bundled.
# noinspection PyPep8Naming
@overload
def Call(
call: Callable[[In1T, In2T], OutT]
) -> _Call2Args[In1T, In2T, OutT]:
...
# 3 arg call; 3 args bundled.
# noinspection PyPep8Naming
@overload
def Call(
call: Callable[[In1T, In2T, In3T], OutT],
arg1: In1T,
arg2: In2T,
arg3: In3T,
) -> _CallNoArgs[OutT]:
...
# 3 arg call; 2 args bundled.
# noinspection PyPep8Naming
@overload
def Call(
call: Callable[[In1T, In2T, In3T], OutT], arg1: In1T, arg2: In2T
) -> _Call1Arg[In3T, OutT]:
...
# 3 arg call; 1 arg bundled.
# noinspection PyPep8Naming
@overload
def Call(
call: Callable[[In1T, In2T, In3T], OutT], arg1: In1T
) -> _Call2Args[In2T, In3T, OutT]:
...
# 3 arg call; no args bundled.
# noinspection PyPep8Naming
@overload
def Call(
call: Callable[[In1T, In2T, In3T], OutT]
) -> _Call3Args[In1T, In2T, In3T, OutT]:
...
# 4 arg call; 4 args bundled.
# noinspection PyPep8Naming
@overload
def Call(
call: Callable[[In1T, In2T, In3T, In4T], OutT],
arg1: In1T,
arg2: In2T,
arg3: In3T,
arg4: In4T,
) -> _CallNoArgs[OutT]:
...
# 4 arg call; 3 args bundled.
# noinspection PyPep8Naming
@overload
def Call(
call: Callable[[In1T, In2T, In3T, In4T], OutT],
arg1: In1T,
arg2: In2T,
arg3: In3T,
) -> _Call1Arg[In4T, OutT]:
...
# 4 arg call; 2 args bundled.
# noinspection PyPep8Naming
@overload
def Call(
call: Callable[[In1T, In2T, In3T, In4T], OutT],
arg1: In1T,
arg2: In2T,
) -> _Call2Args[In3T, In4T, OutT]:
...
# 4 arg call; 1 arg bundled.
# noinspection PyPep8Naming
@overload
def Call(
call: Callable[[In1T, In2T, In3T, In4T], OutT],
arg1: In1T,
) -> _Call3Args[In2T, In3T, In4T, OutT]:
...
# 4 arg call; no args bundled.
# noinspection PyPep8Naming
@overload
def Call(
call: Callable[[In1T, In2T, In3T, In4T], OutT],
) -> _Call4Args[In1T, In2T, In3T, In4T, OutT]:
...
# 5 arg call; 5 args bundled.
# noinspection PyPep8Naming
@overload
def Call(
call: Callable[[In1T, In2T, In3T, In4T, In5T], OutT],
arg1: In1T,
arg2: In2T,
arg3: In3T,
arg4: In4T,
arg5: In5T,
) -> _CallNoArgs[OutT]:
...
# 6 arg call; 6 args bundled.
# noinspection PyPep8Naming
@overload
def Call(
call: Callable[[In1T, In2T, In3T, In4T, In5T, In6T], OutT],
arg1: In1T,
arg2: In2T,
arg3: In3T,
arg4: In4T,
arg5: In5T,
arg6: In6T,
) -> _CallNoArgs[OutT]:
...
# 7 arg call; 7 args bundled.
# noinspection PyPep8Naming
@overload
def Call(
call: Callable[[In1T, In2T, In3T, In4T, In5T, In6T, In7T], OutT],
arg1: In1T,
arg2: In2T,
arg3: In3T,
arg4: In4T,
arg5: In5T,
arg6: In6T,
arg7: In7T,
) -> _CallNoArgs[OutT]:
...
# noinspection PyPep8Naming
def Call(*_args: Any, **_keywds: Any) -> Any:
...
# (Type-safe Partial)
# A convenient wrapper around functools.partial which adds type-safety
# (though it does not support keyword arguments).
tpartial = Call
else:
tpartial = functools.partial
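# Usage sketch for tpartial (illustrative): the type-checker sees the Call
# overloads above while the runtime object is plain functools.partial.
#
#   def add(a: int, b: int) -> int:
#       return a + b
#
#   inc = tpartial(add, 1)  # Type-checks as _Call1Arg[int, int].
#   print(inc(2))           # -> 3
#   inc('nope')             # Flagged by the type-checker; int expected.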

49
dist/ba_data/python/efro/cloudshell.py vendored Normal file

@@ -0,0 +1,49 @@
# Released under the MIT License. See LICENSE for details.
#
"""My nifty ssh/mosh/rsync mishmash."""
from __future__ import annotations
from enum import Enum
from dataclasses import dataclass
from efro.dataclassio import ioprepped
class LockType(Enum):
"""Types of locks that can be acquired on a host."""
HOST = 'host'
WORKSPACE = 'workspace'
PYCHARM = 'pycharm'
CLION = 'clion'
@ioprepped
@dataclass
class HostConfig:
"""Config for a cloud machine to run commands on.
precommand, if set, will be run before the passed commands.
Note that it is not run in interactive mode (when no command is given).
"""
address: str | None = None
user: str = 'ubuntu'
port: int = 22
mosh_port: int | None = None
mosh_server_path: str | None = None
mosh_shell: str = 'sh'
workspaces_root: str = '/home/${USER}/cloudshell_workspaces'
sync_perms: bool = True
precommand: str | None = None
managed: bool = False
idle_minutes: int = 5
can_sudo_reboot: bool = False
max_sessions: int = 3
reboot_wait_seconds: int = 20
reboot_attempts: int = 1
def resolved_workspaces_root(self) -> str:
"""Returns workspaces_root with standard substitutions."""
return self.workspaces_root.replace('${USER}', self.user)
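# Usage sketch (illustrative; the host address and user are made up):
#
#   cfg = HostConfig(address='build01.example.com', user='deploy',
#                    precommand='cd ~/work')
#   print(cfg.resolved_workspaces_root())
#   # -> /home/deploy/cloudshell_workspaces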

295
dist/ba_data/python/efro/dataclasses.py vendored Normal file

@@ -0,0 +1,295 @@
# Released under the MIT License. See LICENSE for details.
#
"""Custom functionality for dealing with dataclasses."""
# Note: We do lots of comparing of exact types here which is normally
# frowned upon (stuff like isinstance() is usually encouraged).
# pylint: disable=unidiomatic-typecheck
from __future__ import annotations
import dataclasses
import inspect
from enum import Enum
from typing import TYPE_CHECKING, TypeVar, Generic
from efro.util import enum_by_value
if TYPE_CHECKING:
from typing import Any, Dict, Type, Tuple, Optional
T = TypeVar('T')
SIMPLE_NAMES_TO_TYPES: Dict[str, Type] = {
'int': int,
'bool': bool,
'str': str,
'float': float,
}
SIMPLE_TYPES_TO_NAMES = {tp: nm for nm, tp in SIMPLE_NAMES_TO_TYPES.items()}
def dataclass_to_dict(obj: Any, coerce_to_float: bool = True) -> dict:
"""Given a dataclass object, emit a json-friendly dict.
All values will be checked to ensure they match the types specified
on fields. Note that only a limited set of types is supported.
If coerce_to_float is True, integer values present on float typed fields
will be converted to floats in the dict output. If False, a TypeError
will be triggered.
"""
out = _Outputter(obj, create=True, coerce_to_float=coerce_to_float).run()
assert isinstance(out, dict)
return out
def dataclass_from_dict(cls: Type[T],
values: dict,
coerce_to_float: bool = True) -> T:
"""Given a dict, instantiates a dataclass of the given type.
The dict must be in the json-friendly format as emitted from
dataclass_to_dict. This means that sequence values such as tuples or
sets should be passed as lists, enums should be passed as their
associated values, and nested dataclasses should be passed as dicts.
If coerce_to_float is True, int values passed for float typed fields
will be converted to float values. Otherwise a TypeError is raised.
"""
return _Inputter(cls, coerce_to_float=coerce_to_float).run(values)
def dataclass_validate(obj: Any, coerce_to_float: bool = True) -> None:
"""Ensure that current values in a dataclass are the correct types."""
_Outputter(obj, create=False, coerce_to_float=coerce_to_float).run()
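# Round-trip sketch (illustrative; _Pos is a hypothetical dataclass defined
# in a module using 'from __future__ import annotations'):
#
#   @dataclasses.dataclass
#   class _Pos:
#       x: float = 0.0
#       y: float = 0.0
#
#   out = dataclass_to_dict(_Pos(1, 2))   # -> {'x': 1.0, 'y': 2.0}
#   pos = dataclass_from_dict(_Pos, out)  # -> _Pos(x=1.0, y=2.0)
#   dataclass_validate(pos)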
def _field_type_str(cls: Type, field: dataclasses.Field) -> str:
# We expect to be operating under 'from __future__ import annotations'
# so field types should always be strings for us; not actual types.
# (Can pull this check out once we get to Python 3.10)
typestr: str = field.type # type: ignore
if not isinstance(typestr, str):
raise RuntimeError(
f'Dataclass {cls.__name__} seems to have'
f' been created without "from __future__ import annotations";'
f' those dataclasses are unsupported here.')
return typestr
def _raise_type_error(fieldpath: str, valuetype: Type,
expected: Tuple[Type, ...]) -> None:
"""Raise an error when a field value's type does not match expected."""
assert isinstance(expected, tuple)
assert all(isinstance(e, type) for e in expected)
if len(expected) == 1:
expected_str = expected[0].__name__
else:
names = ', '.join(t.__name__ for t in expected)
expected_str = f'Union[{names}]'
raise TypeError(f'Invalid value type for "{fieldpath}";'
f' expected "{expected_str}", got'
f' "{valuetype.__name__}".')
class _Outputter:
def __init__(self, obj: Any, create: bool, coerce_to_float: bool) -> None:
self._obj = obj
self._create = create
self._coerce_to_float = coerce_to_float
def run(self) -> Any:
"""Do the thing."""
return self._dataclass_to_output(self._obj, '')
def _value_to_output(self, fieldpath: str, typestr: str,
value: Any) -> Any:
# pylint: disable=too-many-return-statements
# pylint: disable=too-many-branches
# For simple flat types, look for exact matches:
simpletype = SIMPLE_NAMES_TO_TYPES.get(typestr)
if simpletype is not None:
if type(value) is not simpletype:
# Special case: if they want to coerce ints to floats, do so.
if (self._coerce_to_float and simpletype is float
and type(value) is int):
return float(value) if self._create else None
_raise_type_error(fieldpath, type(value), (simpletype, ))
return value
if typestr.startswith('Optional[') and typestr.endswith(']'):
subtypestr = typestr[9:-1]
# Handle the 'None' case specially and do the default otherwise.
if value is None:
return None
return self._value_to_output(fieldpath, subtypestr, value)
if typestr.startswith('List[') and typestr.endswith(']'):
subtypestr = typestr[5:-1]
if not isinstance(value, list):
raise TypeError(f'Expected a list for {fieldpath};'
f' found a {type(value)}')
if self._create:
return [
self._value_to_output(fieldpath, subtypestr, x)
for x in value
]
for x in value:
self._value_to_output(fieldpath, subtypestr, x)
return None
if typestr.startswith('Set[') and typestr.endswith(']'):
subtypestr = typestr[4:-1]
if not isinstance(value, set):
raise TypeError(f'Expected a set for {fieldpath};'
f' found a {type(value)}')
if self._create:
# Note: we output json-friendly values so this becomes a list.
return [
self._value_to_output(fieldpath, subtypestr, x)
for x in value
]
for x in value:
self._value_to_output(fieldpath, subtypestr, x)
return None
if dataclasses.is_dataclass(value):
return self._dataclass_to_output(value, fieldpath)
if isinstance(value, Enum):
enumvalue = value.value
if type(enumvalue) not in SIMPLE_TYPES_TO_NAMES:
raise TypeError(f'Invalid enum value type {type(enumvalue)}'
f' for "{fieldpath}".')
return enumvalue
raise TypeError(
f"Field '{fieldpath}' of type '{typestr}' is unsupported here.")
def _dataclass_to_output(self, obj: Any, fieldpath: str) -> Any:
if not dataclasses.is_dataclass(obj):
raise TypeError(f'Passed obj {obj} is not a dataclass.')
fields = dataclasses.fields(obj)
out: Optional[Dict[str, Any]] = {} if self._create else None
for field in fields:
fieldname = field.name
if fieldpath:
subfieldpath = f'{fieldpath}.{fieldname}'
else:
subfieldpath = fieldname
typestr = _field_type_str(type(obj), field)
value = getattr(obj, fieldname)
outvalue = self._value_to_output(subfieldpath, typestr, value)
if self._create:
assert out is not None
out[fieldname] = outvalue
return out
class _Inputter(Generic[T]):
def __init__(self, cls: Type[T], coerce_to_float: bool):
self._cls = cls
self._coerce_to_float = coerce_to_float
def run(self, values: dict) -> T:
"""Do the thing."""
return self._dataclass_from_input( # type: ignore
self._cls, '', values)
def _value_from_input(self, cls: Type, fieldpath: str, typestr: str,
value: Any) -> Any:
"""Convert an assigned value to what a dataclass field expects."""
# pylint: disable=too-many-return-statements
simpletype = SIMPLE_NAMES_TO_TYPES.get(typestr)
if simpletype is not None:
if type(value) is not simpletype:
# Special case: if they want to coerce ints to floats, do so.
if (self._coerce_to_float and simpletype is float
and type(value) is int):
return float(value)
_raise_type_error(fieldpath, type(value), (simpletype, ))
return value
if typestr.startswith('List[') and typestr.endswith(']'):
return self._sequence_from_input(cls, fieldpath, typestr, value,
'List', list)
if typestr.startswith('Set[') and typestr.endswith(']'):
return self._sequence_from_input(cls, fieldpath, typestr, value,
'Set', set)
if typestr.startswith('Optional[') and typestr.endswith(']'):
subtypestr = typestr[9:-1]
# Handle the 'None' case specially and do the default thing otherwise.
if value is None:
return None
return self._value_from_input(cls, fieldpath, subtypestr, value)
# Ok, it's not a builtin type. It might be an enum or nested dataclass.
cls2 = getattr(inspect.getmodule(cls), typestr, None)
if cls2 is None:
raise RuntimeError(f"Unable to resolve '{typestr}'"
f" used by class '{cls.__name__}';"
f' make sure all nested types are declared'
f' in the global namespace of the module where'
f" '{cls.__name__} is defined.")
if dataclasses.is_dataclass(cls2):
return self._dataclass_from_input(cls2, fieldpath, value)
if issubclass(cls2, Enum):
return enum_by_value(cls2, value)
raise TypeError(
f"Field '{fieldpath}' of type '{typestr}' is unsupported here.")
def _dataclass_from_input(self, cls: Type, fieldpath: str,
values: dict) -> Any:
"""Given a dict, instantiates a dataclass of the given type.
The dict must be in the json-friendly format as emitted from
dataclass_to_dict. This means that sequence values such as tuples or
sets should be passed as lists, enums should be passed as their
associated values, and nested dataclasses should be passed as dicts.
"""
if not dataclasses.is_dataclass(cls):
raise TypeError(f'Passed class {cls} is not a dataclass.')
if not isinstance(values, dict):
raise TypeError("Expected a dict for 'values' arg.")
# noinspection PyDataclass
fields = dataclasses.fields(cls)
fields_by_name = {f.name: f for f in fields}
args: Dict[str, Any] = {}
for key, value in values.items():
field = fields_by_name.get(key)
if field is None:
raise AttributeError(f"'{cls.__name__}' has no '{key}' field.")
typestr = _field_type_str(cls, field)
subfieldpath = (f'{fieldpath}.{field.name}'
if fieldpath else field.name)
args[key] = self._value_from_input(cls, subfieldpath, typestr,
value)
return cls(**args)
def _sequence_from_input(self, cls: Type, fieldpath: str, typestr: str,
value: Any, seqtypestr: str,
seqtype: Type) -> Any:
# Because we are json-centric, we expect a list for all sequences.
if type(value) is not list:
raise TypeError(f'Invalid input value for "{fieldpath}";'
f' expected a list, got a {type(value).__name__}')
subtypestr = typestr[len(seqtypestr) + 1:-1]
return seqtype(
self._value_from_input(cls, fieldpath, subtypestr, i)
for i in value)

1351
dist/ba_data/python/efro/dataclassio.py vendored Normal file

File diff suppressed because it is too large.

50
dist/ba_data/python/efro/dataclassio/__init__.py vendored Normal file

@@ -0,0 +1,50 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for importing, exporting, and validating dataclasses.
This allows complex nested dataclasses to be flattened to json-compatible
data and restored from said data. It also gracefully handles and preserves
unrecognized attribute data, allowing older clients to interact with newer
data formats in a nondestructive manner.
"""
from __future__ import annotations
from efro.util import set_canonical_module
from efro.dataclassio._base import Codec, IOAttrs, IOExtendedData
from efro.dataclassio._prep import (
ioprep,
ioprepped,
will_ioprep,
is_ioprepped_dataclass,
)
from efro.dataclassio._pathcapture import DataclassFieldLookup
from efro.dataclassio._api import (
JsonStyle,
dataclass_to_dict,
dataclass_to_json,
dataclass_from_dict,
dataclass_from_json,
dataclass_validate,
)
__all__ = [
'JsonStyle',
'Codec',
'IOAttrs',
'IOExtendedData',
'ioprep',
'ioprepped',
'will_ioprep',
'is_ioprepped_dataclass',
'DataclassFieldLookup',
'dataclass_to_dict',
'dataclass_to_json',
'dataclass_from_dict',
'dataclass_from_json',
'dataclass_validate',
]
# Have these things present themselves cleanly as 'thismodule.SomeClass'
# instead of 'thismodule._internalmodule.SomeClass'
set_canonical_module(module_globals=globals(), names=__all__)
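# Quick usage sketch (illustrative; _Account is a hypothetical dataclass):
#
#   from dataclasses import dataclass
#
#   @ioprepped
#   @dataclass
#   class _Account:
#       name: str = ''
#       coins: int = 0
#
#   j = dataclass_to_json(_Account(name='zoe', coins=3))
#   assert dataclass_from_json(_Account, j) == _Account(name='zoe', coins=3)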

163
dist/ba_data/python/efro/dataclassio/_api.py vendored Normal file

@@ -0,0 +1,163 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for importing, exporting, and validating dataclasses.
This allows complex nested dataclasses to be flattened to json-compatible
data and restored from said data. It also gracefully handles and preserves
unrecognized attribute data, allowing older clients to interact with newer
data formats in a nondestructive manner.
"""
from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING, TypeVar
from efro.dataclassio._outputter import _Outputter
from efro.dataclassio._inputter import _Inputter
from efro.dataclassio._base import Codec
if TYPE_CHECKING:
from typing import Any
T = TypeVar('T')
class JsonStyle(Enum):
"""Different style types for json."""
# Single line, no spaces, no sorting. Not deterministic.
# Use this for most storage purposes.
FAST = 'fast'
# Single line, no spaces, sorted keys. Deterministic.
# Use this when output may be hashed or compared for equality.
SORTED = 'sorted'
# Multiple lines, spaces, sorted keys. Deterministic.
# Use this for pretty human readable output.
PRETTY = 'pretty'
def dataclass_to_dict(
obj: Any, codec: Codec = Codec.JSON, coerce_to_float: bool = True
) -> dict:
"""Given a dataclass object, return a json-friendly dict.
All values will be checked to ensure they match the types specified
on fields. Note that a limited set of types and data configurations is
supported.
Values with type Any will be checked to ensure they match types supported
directly by json. This does not include types such as tuples which are
implicitly translated by Python's json module (as this would break
the ability to do a lossless round-trip with data).
If coerce_to_float is True, integer values present on float typed fields
will be converted to float in the dict output. If False, a TypeError
will be triggered.
"""
out = _Outputter(
obj, create=True, codec=codec, coerce_to_float=coerce_to_float
).run()
assert isinstance(out, dict)
return out
def dataclass_to_json(
obj: Any,
coerce_to_float: bool = True,
pretty: bool = False,
sort_keys: bool | None = None,
) -> str:
"""Utility function; return a json string from a dataclass instance.
Basically json.dumps(dataclass_to_dict(...)).
By default, keys are sorted for pretty output and not otherwise, but
this can be overridden by supplying a value for the 'sort_keys' arg.
"""
import json
jdict = dataclass_to_dict(
obj=obj, coerce_to_float=coerce_to_float, codec=Codec.JSON
)
if sort_keys is None:
sort_keys = pretty
if pretty:
return json.dumps(jdict, indent=2, sort_keys=sort_keys)
return json.dumps(jdict, separators=(',', ':'), sort_keys=sort_keys)
def dataclass_from_dict(
cls: type[T],
values: dict,
codec: Codec = Codec.JSON,
coerce_to_float: bool = True,
allow_unknown_attrs: bool = True,
discard_unknown_attrs: bool = False,
) -> T:
"""Given a dict, return a dataclass of a given type.
The dict must be formatted to match the specified codec (generally
json-friendly object types). This means that sequence values such as
tuples or sets should be passed as lists, enums should be passed as their
associated values, nested dataclasses should be passed as dicts, etc.
All values are checked to ensure their types/values are valid.
Data for attributes of type Any will be checked to ensure they match
types supported directly by json. This does not include types such
as tuples which are implicitly translated by Python's json module
(as this would break the ability to do a lossless round-trip with data).
If coerce_to_float is True, int values passed for float typed fields
will be converted to float values. Otherwise, a TypeError is raised.
If allow_unknown_attrs is False, AttributeErrors will be raised for
attributes present in the dict but not on the data class. Otherwise, they
will be preserved as part of the instance and included if it is
exported back to a dict, unless discard_unknown_attrs is True, in which
case they will simply be discarded.
"""
return _Inputter(
cls,
codec=codec,
coerce_to_float=coerce_to_float,
allow_unknown_attrs=allow_unknown_attrs,
discard_unknown_attrs=discard_unknown_attrs,
).run(values)
def dataclass_from_json(
cls: type[T],
json_str: str,
coerce_to_float: bool = True,
allow_unknown_attrs: bool = True,
discard_unknown_attrs: bool = False,
) -> T:
"""Utility function; return a dataclass instance given a json string.
Basically dataclass_from_dict(json.loads(...))
"""
import json
return dataclass_from_dict(
cls=cls,
values=json.loads(json_str),
coerce_to_float=coerce_to_float,
allow_unknown_attrs=allow_unknown_attrs,
discard_unknown_attrs=discard_unknown_attrs,
)
def dataclass_validate(
obj: Any, coerce_to_float: bool = True, codec: Codec = Codec.JSON
) -> None:
"""Ensure that values in a dataclass instance are the correct types."""
# Simply run an output pass but tell it not to generate data;
# only run validation.
_Outputter(
obj, create=False, codec=codec, coerce_to_float=coerce_to_float
).run()
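# Sketch of unknown-attr preservation (illustrative; assumes a hypothetical
# _Config dataclass decorated with @ioprepped from efro.dataclassio):
#
#   data = {'name': 'server1', 'added_in_newer_version': 123}
#   cfg = dataclass_from_dict(_Config, data)   # Unknown key is kept aside.
#   assert dataclass_to_dict(cfg)['added_in_newer_version'] == 123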

276
dist/ba_data/python/efro/dataclassio/_base.py vendored Normal file

@@ -0,0 +1,276 @@
# Released under the MIT License. See LICENSE for details.
#
"""Core components of dataclassio."""
from __future__ import annotations
import dataclasses
import typing
import datetime
from enum import Enum
from typing import TYPE_CHECKING, get_args
# noinspection PyProtectedMember
from typing import _AnnotatedAlias # type: ignore
if TYPE_CHECKING:
from typing import Any, Callable
# Types which we can pass through as-is.
SIMPLE_TYPES = {int, bool, str, float, type(None)}
# Attr name for dict of extra attributes included on dataclass instances.
# Note that this is only added if extra attributes are present.
EXTRA_ATTRS_ATTR = '_DCIOEXATTRS'
def _raise_type_error(
fieldpath: str, valuetype: type, expected: tuple[type, ...]
) -> None:
"""Raise an error when a field value's type does not match expected."""
assert isinstance(expected, tuple)
assert all(isinstance(e, type) for e in expected)
if len(expected) == 1:
expected_str = expected[0].__name__
else:
expected_str = ' | '.join(t.__name__ for t in expected)
raise TypeError(
f'Invalid value type for "{fieldpath}";'
f' expected "{expected_str}", got'
f' "{valuetype.__name__}".'
)
class Codec(Enum):
"""Specifies expected data format exported to or imported from."""
# Use only types that will translate cleanly to/from json: lists,
# dicts with str keys, bools, ints, floats, and None.
JSON = 'json'
# Mostly like JSON but passes bytes and datetime objects through
# as-is instead of converting them to json-friendly types.
FIRESTORE = 'firestore'
class IOExtendedData:
"""A class that data types can inherit from for extra functionality."""
def will_output(self) -> None:
"""Called before data is sent to an outputter.
Can be overridden to validate or filter data before
sending it on its way.
"""
@classmethod
def will_input(cls, data: dict) -> None:
"""Called on raw data before a class instance is created from it.
Can be overridden to migrate old data formats to new, etc.
"""
def _is_valid_for_codec(obj: Any, codec: Codec) -> bool:
"""Return whether a value consists solely of json-supported types.
Note that this does not include things like tuples which are
implicitly translated to lists by python's json module.
"""
if obj is None:
return True
objtype = type(obj)
if objtype in (int, float, str, bool):
return True
if objtype is dict:
# JSON 'objects' support only string dict keys, but any value type.
return all(
isinstance(k, str) and _is_valid_for_codec(v, codec)
for k, v in obj.items()
)
if objtype is list:
return all(_is_valid_for_codec(elem, codec) for elem in obj)
# A few things are valid in firestore but not json.
if issubclass(objtype, datetime.datetime) or objtype is bytes:
return codec is Codec.FIRESTORE
return False
class IOAttrs:
"""For specifying io behavior in annotations.
'storagename', if passed, is the name used when storing to json/etc.
'store_default' can be set to False to avoid writing values when equal
to the default value. Note that this requires the dataclass field
to define a default or default_factory or for its IOAttrs to
define a soft_default value.
'whole_days', if True, requires datetime values to be exactly on day
boundaries (see efro.util.utc_today()).
'whole_hours', if True, requires datetime values to lie exactly on hour
boundaries (see efro.util.utc_this_hour()).
'whole_minutes', if True, requires datetime values to lie exactly on minute
boundaries (see efro.util.utc_this_minute()).
'soft_default', if passed, injects a default value into dataclass
instantiation when the field is not present in the input data.
This allows dataclasses to add new non-optional fields while
gracefully 'upgrading' old data. Note that when a soft_default is
present it will take precedence over field defaults when determining
whether to store a value for a field with store_default=False
(since the soft_default value is what we'll get when reading that
same data back in when the field is omitted).
'soft_default_factory' is similar to 'default_factory' in dataclass
fields; it should be used instead of 'soft_default' for mutable types
such as lists to prevent a single default object from unintentionally
changing over time.
"""
# A sentinel object to detect if a parameter is supplied or not. Use
# a class to give it a better repr.
class _MissingType:
pass
MISSING = _MissingType()
storagename: str | None = None
store_default: bool = True
whole_days: bool = False
whole_hours: bool = False
whole_minutes: bool = False
soft_default: Any = MISSING
soft_default_factory: Callable[[], Any] | _MissingType = MISSING
def __init__(
self,
storagename: str | None = storagename,
store_default: bool = store_default,
whole_days: bool = whole_days,
whole_hours: bool = whole_hours,
whole_minutes: bool = whole_minutes,
soft_default: Any = MISSING,
soft_default_factory: Callable[[], Any] | _MissingType = MISSING,
):
# Only store values that differ from class defaults to keep
# our instances nice and lean.
cls = type(self)
if storagename != cls.storagename:
self.storagename = storagename
if store_default != cls.store_default:
self.store_default = store_default
if whole_days != cls.whole_days:
self.whole_days = whole_days
if whole_hours != cls.whole_hours:
self.whole_hours = whole_hours
if whole_minutes != cls.whole_minutes:
self.whole_minutes = whole_minutes
if soft_default is not cls.soft_default:
# Do what dataclasses does with its default types and
# tell the user to use factory for mutable ones.
if isinstance(soft_default, (list, dict, set)):
raise ValueError(
f'mutable {type(soft_default)} is not allowed'
f' for soft_default; use soft_default_factory.'
)
self.soft_default = soft_default
if soft_default_factory is not cls.soft_default_factory:
self.soft_default_factory = soft_default_factory
if self.soft_default is not cls.soft_default:
raise ValueError(
'Cannot set both soft_default and soft_default_factory'
)
def validate_for_field(self, cls: type, field: dataclasses.Field) -> None:
"""Ensure the IOAttrs instance is ok to use with the provided field."""
# Turning off store_default requires the field to have either
# a default or a default_factory or for us to have soft equivalents.
if not self.store_default:
field_default_factory: Any = field.default_factory
if (
field_default_factory is dataclasses.MISSING
and field.default is dataclasses.MISSING
and self.soft_default is self.MISSING
and self.soft_default_factory is self.MISSING
):
raise TypeError(
f'Field {field.name} of {cls} has'
f' neither a default nor a default_factory'
f' and IOAttrs contains neither a soft_default'
f' nor a soft_default_factory;'
f' store_default=False cannot be set for it.'
)
def validate_datetime(
self, value: datetime.datetime, fieldpath: str
) -> None:
"""Ensure a datetime value meets our value requirements."""
if self.whole_days:
if any(
x != 0
for x in (
value.hour,
value.minute,
value.second,
value.microsecond,
)
):
raise ValueError(
f'Value {value} at {fieldpath} is not a whole day.'
)
elif self.whole_hours:
if any(
x != 0 for x in (value.minute, value.second, value.microsecond)
):
raise ValueError(
f'Value {value} at {fieldpath}' f' is not a whole hour.'
)
elif self.whole_minutes:
if any(x != 0 for x in (value.second, value.microsecond)):
raise ValueError(
f'Value {value} at {fieldpath}' f' is not a whole minute.'
)
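# Annotation sketch (illustrative; _Profile is a hypothetical dataclass):
#
#   from dataclasses import dataclass
#   from typing import Annotated
#
#   @dataclass
#   class _Profile:
#       # Stored under the short key 'n'; omitted from output when empty.
#       name: Annotated[str, IOAttrs('n', store_default=False)] = ''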
def _get_origin(anntype: Any) -> Any:
"""Given a type annotation, return its origin or itself if there is none.
This differs from typing.get_origin in that it will never return None.
This lets us use the same code path for handling typing.List
that we do for handling list, which is good since they can be used
interchangeably in annotations.
"""
origin = typing.get_origin(anntype)
return anntype if origin is None else origin
def _parse_annotated(anntype: Any) -> tuple[Any, IOAttrs | None]:
"""Parse Annotated() constructs, returning annotated type & IOAttrs."""
# If we get an Annotated[foo, bar, eep] we take
# foo as the actual type, and we look for IOAttrs instances in
# bar/eep to affect our behavior.
ioattrs: IOAttrs | None = None
if isinstance(anntype, _AnnotatedAlias):
annargs = get_args(anntype)
for annarg in annargs[1:]:
if isinstance(annarg, IOAttrs):
if ioattrs is not None:
raise RuntimeError(
'Multiple IOAttrs instances found for a'
' single annotation; this is not supported.'
)
ioattrs = annarg
# I occasionally just throw a 'x' down when I mean IOAttrs('x');
# catch these mistakes.
elif isinstance(annarg, (str, int, float, bool)):
raise RuntimeError(
f'Raw {type(annarg)} found in Annotated[] entry:'
f' {anntype}; this is probably not what you intended.'
)
anntype = annargs[0]
return anntype, ioattrs

555
dist/ba_data/python/efro/dataclassio/_inputter.py vendored Normal file

@@ -0,0 +1,555 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for dataclassio related to pulling data into dataclasses."""
# Note: We do lots of comparing of exact types here which is normally
# frowned upon (stuff like isinstance() is usually encouraged).
# pylint: disable=unidiomatic-typecheck
from __future__ import annotations
from enum import Enum
import dataclasses
import typing
import types
import datetime
from typing import TYPE_CHECKING, Generic, TypeVar
from efro.util import enum_by_value, check_utc
from efro.dataclassio._base import (
Codec,
_parse_annotated,
EXTRA_ATTRS_ATTR,
_is_valid_for_codec,
_get_origin,
SIMPLE_TYPES,
_raise_type_error,
IOExtendedData,
)
from efro.dataclassio._prep import PrepSession
if TYPE_CHECKING:
from typing import Any
from efro.dataclassio._base import IOAttrs
from efro.dataclassio._outputter import _Outputter
T = TypeVar('T')
class _Inputter(Generic[T]):
def __init__(
self,
cls: type[T],
codec: Codec,
coerce_to_float: bool,
allow_unknown_attrs: bool = True,
discard_unknown_attrs: bool = False,
):
self._cls = cls
self._codec = codec
self._coerce_to_float = coerce_to_float
self._allow_unknown_attrs = allow_unknown_attrs
self._discard_unknown_attrs = discard_unknown_attrs
self._soft_default_validator: _Outputter | None = None
if not allow_unknown_attrs and discard_unknown_attrs:
raise ValueError(
'discard_unknown_attrs cannot be True'
' when allow_unknown_attrs is False.'
)
def run(self, values: dict) -> T:
"""Do the thing."""
# For special extended data types, call their 'will_input' callback.
tcls = self._cls
if issubclass(tcls, IOExtendedData):
tcls.will_input(values)
out = self._dataclass_from_input(self._cls, '', values)
assert isinstance(out, self._cls)
return out
def _value_from_input(
self,
cls: type,
fieldpath: str,
anntype: Any,
value: Any,
ioattrs: IOAttrs | None,
) -> Any:
"""Convert an assigned value to what a dataclass field expects."""
# pylint: disable=too-many-return-statements
# pylint: disable=too-many-branches
origin = _get_origin(anntype)
if origin is typing.Any:
if not _is_valid_for_codec(value, self._codec):
raise TypeError(
f'Invalid value type for \'{fieldpath}\';'
f' \'Any\' typed values must contain only'
f' types directly supported by the specified'
f' codec ({self._codec.name}); found'
f' \'{type(value).__name__}\' which is not.'
)
return value
if origin is typing.Union or origin is types.UnionType:
# Currently, the only unions we support are None/Value
# (translated from Optional), which we verified on prep.
# So let's treat this as a simple optional case.
if value is None:
return None
childanntypes_l = [
c for c in typing.get_args(anntype) if c is not type(None)
] # noqa (pycodestyle complains about *is* with type)
assert len(childanntypes_l) == 1
return self._value_from_input(
cls, fieldpath, childanntypes_l[0], value, ioattrs
)
# Everything below this point assumes the annotation type resolves
# to a concrete type. (This should have been verified at prep time).
assert isinstance(origin, type)
if origin in SIMPLE_TYPES:
if type(value) is not origin:
# Special case: if they want to coerce ints to floats, do so.
if (
self._coerce_to_float
and origin is float
and type(value) is int
):
return float(value)
_raise_type_error(fieldpath, type(value), (origin,))
return value
if origin in {list, set}:
return self._sequence_from_input(
cls, fieldpath, anntype, value, origin, ioattrs
)
if origin is tuple:
return self._tuple_from_input(
cls, fieldpath, anntype, value, ioattrs
)
if origin is dict:
return self._dict_from_input(
cls, fieldpath, anntype, value, ioattrs
)
if dataclasses.is_dataclass(origin):
return self._dataclass_from_input(origin, fieldpath, value)
if issubclass(origin, Enum):
return enum_by_value(origin, value)
if issubclass(origin, datetime.datetime):
return self._datetime_from_input(cls, fieldpath, value, ioattrs)
if origin is bytes:
return self._bytes_from_input(origin, fieldpath, value)
raise TypeError(
f"Field '{fieldpath}' of type '{anntype}' is unsupported here."
)
def _bytes_from_input(self, cls: type, fieldpath: str, value: Any) -> bytes:
"""Given input data, returns bytes."""
import base64
# For firestore, bytes are passed as-is. Otherwise, they're encoded
# as base64.
if self._codec is Codec.FIRESTORE:
if not isinstance(value, bytes):
raise TypeError(
f'Expected a bytes object for {fieldpath}'
f' on {cls.__name__}; got a {type(value)}.'
)
return value
assert self._codec is Codec.JSON
if not isinstance(value, str):
raise TypeError(
f'Expected a string object for {fieldpath}'
f' on {cls.__name__}; got a {type(value)}.'
)
return base64.b64decode(value)
def _dataclass_from_input(
self, cls: type, fieldpath: str, values: dict
) -> Any:
"""Given a dict, instantiates a dataclass of the given type.
The dict must be in the json-friendly format as emitted from
dataclass_to_dict. This means that sequence values such as tuples or
sets should be passed as lists, enums should be passed as their
associated values, and nested dataclasses should be passed as dicts.
"""
# pylint: disable=too-many-locals
# pylint: disable=too-many-branches
if not isinstance(values, dict):
raise TypeError(
f'Expected a dict for {fieldpath} on {cls.__name__};'
f' got a {type(values)}.'
)
prep = PrepSession(explicit=False).prep_dataclass(
cls, recursion_level=0
)
assert prep is not None
extra_attrs = {}
# noinspection PyDataclass
fields = dataclasses.fields(cls)
fields_by_name = {f.name: f for f in fields}
# Preprocess all fields to convert Annotated[] to contained types
# and IOAttrs.
parsed_field_annotations = {
f.name: _parse_annotated(prep.annotations[f.name]) for f in fields
}
# Go through all data in the input, converting it to either dataclass
# args or extra data.
args: dict[str, Any] = {}
for rawkey, value in values.items():
key = prep.storage_names_to_attr_names.get(rawkey, rawkey)
field = fields_by_name.get(key)
# Store unknown attrs off to the side (or error if desired).
if field is None:
if self._allow_unknown_attrs:
if self._discard_unknown_attrs:
continue
# Treat this like 'Any' data; ensure that it is valid
# raw json.
if not _is_valid_for_codec(value, self._codec):
raise TypeError(
f'Unknown attr \'{key}\''
f' on {fieldpath} contains data type(s)'
f' not supported by the specified codec'
f' ({self._codec.name}).'
)
extra_attrs[key] = value
else:
raise AttributeError(
f"'{cls.__name__}' has no '{key}' field."
)
else:
fieldname = field.name
anntype, ioattrs = parsed_field_annotations[fieldname]
subfieldpath = (
f'{fieldpath}.{fieldname}' if fieldpath else fieldname
)
args[key] = self._value_from_input(
cls, subfieldpath, anntype, value, ioattrs
)
# Go through all fields looking for any not yet present in our data.
# If we find any such fields with a soft-default value or factory
# defined, inject that soft value into our args.
for key, aparsed in parsed_field_annotations.items():
if key in args:
continue
ioattrs = aparsed[1]
if ioattrs is not None and (
ioattrs.soft_default is not ioattrs.MISSING
or ioattrs.soft_default_factory is not ioattrs.MISSING
):
if ioattrs.soft_default is not ioattrs.MISSING:
soft_default = ioattrs.soft_default
else:
assert callable(ioattrs.soft_default_factory)
soft_default = ioattrs.soft_default_factory()
args[key] = soft_default
# Make sure these values are valid since we didn't run
# them through our normal input type checking.
self._type_check_soft_default(
value=soft_default,
anntype=aparsed[0],
fieldpath=(f'{fieldpath}.{key}' if fieldpath else key),
)
try:
out = cls(**args)
except Exception as exc:
raise ValueError(
f'Error instantiating class {cls.__name__}'
f' at {fieldpath}: {exc}'
) from exc
if extra_attrs:
setattr(out, EXTRA_ATTRS_ATTR, extra_attrs)
return out
def _type_check_soft_default(
self, value: Any, anntype: Any, fieldpath: str
) -> None:
from efro.dataclassio._outputter import _Outputter
# Counter-intuitively, we create an outputter as part of
# our inputter. Soft-default values are already internal types;
# we need to make sure they can go out from there.
if self._soft_default_validator is None:
self._soft_default_validator = _Outputter(
obj=None,
create=False,
codec=self._codec,
coerce_to_float=self._coerce_to_float,
)
self._soft_default_validator.soft_default_check(
value=value, anntype=anntype, fieldpath=fieldpath
)
def _dict_from_input(
self,
cls: type,
fieldpath: str,
anntype: Any,
value: Any,
ioattrs: IOAttrs | None,
) -> Any:
# pylint: disable=too-many-branches
# pylint: disable=too-many-locals
if not isinstance(value, dict):
raise TypeError(
f'Expected a dict for \'{fieldpath}\' on {cls.__name__};'
f' got a {type(value)}.'
)
childtypes = typing.get_args(anntype)
assert len(childtypes) in (0, 2)
out: dict
# We treat 'Any' dicts simply as json; we don't do any translating.
if not childtypes or childtypes[0] is typing.Any:
if not isinstance(value, dict) or not _is_valid_for_codec(
value, self._codec
):
raise TypeError(
f'Got invalid value for Dict[Any, Any]'
f' at \'{fieldpath}\' on {cls.__name__};'
f' all keys and values must be'
f' compatible with the specified codec'
f' ({self._codec.name}).'
)
out = value
else:
out = {}
keyanntype, valanntype = childtypes
# Ok; we've got definite key/value types (which we verified as
# valid during prep). Run all keys/values through it.
# str keys we just take directly since that's supported by json.
if keyanntype is str:
for key, val in value.items():
if not isinstance(key, str):
raise TypeError(
f'Got invalid key type {type(key)} for'
f' dict key at \'{fieldpath}\' on {cls.__name__};'
f' expected a str.'
)
out[key] = self._value_from_input(
cls, fieldpath, valanntype, val, ioattrs
)
# int keys are stored in json as str versions of themselves.
elif keyanntype is int:
for key, val in value.items():
if not isinstance(key, str):
raise TypeError(
f'Got invalid key type {type(key)} for'
f' dict key at \'{fieldpath}\' on {cls.__name__};'
f' expected a str.'
)
try:
keyint = int(key)
except ValueError as exc:
raise TypeError(
f'Got invalid key value {key} for'
f' dict key at \'{fieldpath}\' on {cls.__name__};'
f' expected an int in string form.'
) from exc
out[keyint] = self._value_from_input(
cls, fieldpath, valanntype, val, ioattrs
)
elif issubclass(keyanntype, Enum):
# In prep, we verified that all these enums' values have
# the same type, so we can just look at the first to see if
# this is a string enum or an int enum.
enumvaltype = type(next(iter(keyanntype)).value)
assert enumvaltype in (int, str)
if enumvaltype is str:
for key, val in value.items():
try:
enumval = enum_by_value(keyanntype, key)
except ValueError as exc:
raise ValueError(
f'Got invalid key value {key} for'
f' dict key at \'{fieldpath}\''
f' on {cls.__name__};'
f' expected a value corresponding to'
f' a {keyanntype}.'
) from exc
out[enumval] = self._value_from_input(
cls, fieldpath, valanntype, val, ioattrs
)
else:
for key, val in value.items():
try:
enumval = enum_by_value(keyanntype, int(key))
except (ValueError, TypeError) as exc:
raise ValueError(
f'Got invalid key value {key} for'
f' dict key at \'{fieldpath}\''
f' on {cls.__name__};'
f' expected {keyanntype} value (though'
f' in string form).'
) from exc
out[enumval] = self._value_from_input(
cls, fieldpath, valanntype, val, ioattrs
)
else:
raise RuntimeError(f'Unhandled dict in-key-type {keyanntype}')
return out
def _sequence_from_input(
self,
cls: type,
fieldpath: str,
anntype: Any,
value: Any,
seqtype: type,
ioattrs: IOAttrs | None,
) -> Any:
# Because we are json-centric, we expect a list for all sequences.
if type(value) is not list:
raise TypeError(
f'Invalid input value for "{fieldpath}";'
f' expected a list, got a {type(value).__name__}'
)
childanntypes = typing.get_args(anntype)
# 'Any' type children; make sure they are valid json values
# and then just grab them.
if len(childanntypes) == 0 or childanntypes[0] is typing.Any:
for i, child in enumerate(value):
if not _is_valid_for_codec(child, self._codec):
raise TypeError(
f'Item {i} of {fieldpath} contains'
f' data type(s) not supported by json.'
)
return value if type(value) is seqtype else seqtype(value)
# We contain elements of some specified type.
assert len(childanntypes) == 1
childanntype = childanntypes[0]
return seqtype(
self._value_from_input(cls, fieldpath, childanntype, i, ioattrs)
for i in value
)
def _datetime_from_input(
self, cls: type, fieldpath: str, value: Any, ioattrs: IOAttrs | None
) -> Any:
# For firestore we expect a datetime object.
if self._codec is Codec.FIRESTORE:
# Don't compare exact type here, as firestore can give us
# a subclass with extended precision.
if not isinstance(value, datetime.datetime):
raise TypeError(
f'Invalid input value for "{fieldpath}" on'
f' "{cls.__name__}";'
f' expected a datetime, got a {type(value).__name__}'
)
check_utc(value)
return value
assert self._codec is Codec.JSON
# We expect a list of 7 ints.
if type(value) is not list:
raise TypeError(
f'Invalid input value for "{fieldpath}" on "{cls.__name__}";'
f' expected a list, got a {type(value).__name__}'
)
if len(value) != 7 or not all(isinstance(x, int) for x in value):
raise ValueError(
f'Invalid input value for "{fieldpath}" on "{cls.__name__}";'
f' expected a list of 7 ints, got {[type(v) for v in value]}.'
)
out = datetime.datetime( # type: ignore
*value, tzinfo=datetime.timezone.utc
)
if ioattrs is not None:
ioattrs.validate_datetime(out, fieldpath)
return out
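# Wire-format sketch (illustrative): under the JSON codec, an aware utc
# datetime such as
#   datetime.datetime(2024, 2, 26, 0, 17, 10, 0, tzinfo=datetime.timezone.utc)
# travels as the 7-int list [2024, 2, 26, 0, 17, 10, 0]
# (year, month, day, hour, minute, second, microsecond).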
def _tuple_from_input(
self,
cls: type,
fieldpath: str,
anntype: Any,
value: Any,
ioattrs: IOAttrs | None,
) -> Any:
out: list = []
# Because we are json-centric, we expect a list for all sequences.
if type(value) is not list:
raise TypeError(
f'Invalid input value for "{fieldpath}";'
f' expected a list, got a {type(value).__name__}'
)
childanntypes = typing.get_args(anntype)
# We should have verified this to be non-zero at prep-time.
assert childanntypes
if len(value) != len(childanntypes):
raise ValueError(
f'Invalid tuple input for "{fieldpath}";'
f' expected {len(childanntypes)} values,'
f' found {len(value)}.'
)
for i, childanntype in enumerate(childanntypes):
childval = value[i]
# 'Any' type children; make sure they are valid json values
# and then just grab them.
if childanntype is typing.Any:
if not _is_valid_for_codec(childval, self._codec):
raise TypeError(
f'Item {i} of {fieldpath} contains'
f' data type(s) not supported by json.'
)
out.append(childval)
else:
out.append(
self._value_from_input(
cls, fieldpath, childanntype, childval, ioattrs
)
)
assert len(out) == len(childanntypes)
return tuple(out)

457
dist/ba_data/python/efro/dataclassio/_outputter.py vendored Normal file

@@ -0,0 +1,457 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for dataclassio related to exporting data from dataclasses."""
# Note: We do lots of comparing of exact types here which is normally
# frowned upon (stuff like isinstance() is usually encouraged).
# pylint: disable=unidiomatic-typecheck
from __future__ import annotations
from enum import Enum
import dataclasses
import typing
import types
import datetime
from typing import TYPE_CHECKING
from efro.util import check_utc
from efro.dataclassio._base import (
Codec,
_parse_annotated,
EXTRA_ATTRS_ATTR,
_is_valid_for_codec,
_get_origin,
SIMPLE_TYPES,
_raise_type_error,
IOExtendedData,
)
from efro.dataclassio._prep import PrepSession
if TYPE_CHECKING:
from typing import Any
from efro.dataclassio._base import IOAttrs
class _Outputter:
"""Validates or exports data contained in a dataclass instance."""
def __init__(
self, obj: Any, create: bool, codec: Codec, coerce_to_float: bool
) -> None:
self._obj = obj
self._create = create
self._codec = codec
self._coerce_to_float = coerce_to_float
def run(self) -> Any:
"""Do the thing."""
assert dataclasses.is_dataclass(self._obj)
# For special extended data types, call their 'will_output' callback.
if isinstance(self._obj, IOExtendedData):
self._obj.will_output()
return self._process_dataclass(type(self._obj), self._obj, '')
def soft_default_check(
self, value: Any, anntype: Any, fieldpath: str
) -> None:
"""(internal)"""
self._process_value(
type(value),
fieldpath=fieldpath,
anntype=anntype,
value=value,
ioattrs=None,
)
def _process_dataclass(self, cls: type, obj: Any, fieldpath: str) -> Any:
# pylint: disable=too-many-locals
# pylint: disable=too-many-branches
prep = PrepSession(explicit=False).prep_dataclass(
type(obj), recursion_level=0
)
assert prep is not None
fields = dataclasses.fields(obj)
out: dict[str, Any] | None = {} if self._create else None
for field in fields:
fieldname = field.name
if fieldpath:
subfieldpath = f'{fieldpath}.{fieldname}'
else:
subfieldpath = fieldname
anntype = prep.annotations[fieldname]
value = getattr(obj, fieldname)
anntype, ioattrs = _parse_annotated(anntype)
# If we're not storing default values for this fella,
# we can skip all output processing if we've got a default value.
if ioattrs is not None and not ioattrs.store_default:
# If both soft_defaults and regular field defaults
# are present we want to go with soft_defaults since
# those same values would be re-injected when reading
# the same data back in if we've omitted the field.
default_factory: Any = field.default_factory
if ioattrs.soft_default is not ioattrs.MISSING:
if ioattrs.soft_default == value:
continue
elif ioattrs.soft_default_factory is not ioattrs.MISSING:
assert callable(ioattrs.soft_default_factory)
if ioattrs.soft_default_factory() == value:
continue
elif field.default is not dataclasses.MISSING:
if field.default == value:
continue
elif default_factory is not dataclasses.MISSING:
if default_factory() == value:
continue
else:
raise RuntimeError(
f'Field {fieldname} of {cls.__name__} has'
f' no source of default values; store_default=False'
f' cannot be set for it. (AND THIS SHOULD HAVE BEEN'
f' CAUGHT IN PREP!)'
)
outvalue = self._process_value(
cls, subfieldpath, anntype, value, ioattrs
)
if self._create:
assert out is not None
storagename = (
fieldname
if (ioattrs is None or ioattrs.storagename is None)
else ioattrs.storagename
)
out[storagename] = outvalue
# If there's extra-attrs stored on us, check/include them.
extra_attrs = getattr(obj, EXTRA_ATTRS_ATTR, None)
if isinstance(extra_attrs, dict):
if not _is_valid_for_codec(extra_attrs, self._codec):
raise TypeError(
f'Extra attrs on \'{fieldpath}\' contains data type(s)'
f' not supported by \'{self._codec.value}\' codec:'
f' {extra_attrs}.'
)
if self._create:
assert out is not None
out.update(extra_attrs)
return out
def _process_value(
self,
cls: type,
fieldpath: str,
anntype: Any,
value: Any,
ioattrs: IOAttrs | None,
) -> Any:
# pylint: disable=too-many-return-statements
# pylint: disable=too-many-branches
# pylint: disable=too-many-statements
origin = _get_origin(anntype)
if origin is typing.Any:
if not _is_valid_for_codec(value, self._codec):
raise TypeError(
f'Invalid value type for \'{fieldpath}\';'
f" 'Any' typed values must contain types directly"
f' supported by the specified codec ({self._codec.name});'
f' found \'{type(value).__name__}\' which is not.'
)
return value if self._create else None
if origin is typing.Union or origin is types.UnionType:
# Currently, the only unions we support are None/Value
# (translated from Optional), which we verified on prep.
# So let's treat this as a simple optional case.
if value is None:
return None
childanntypes_l = [
c for c in typing.get_args(anntype) if c is not type(None)
] # noqa (pycodestyle complains about *is* with type)
assert len(childanntypes_l) == 1
return self._process_value(
cls, fieldpath, childanntypes_l[0], value, ioattrs
)
# Everything below this point assumes the annotation type resolves
# to a concrete type. (This should have been verified at prep time).
assert isinstance(origin, type)
# For simple flat types, look for exact matches:
if origin in SIMPLE_TYPES:
if type(value) is not origin:
# Special case: if they want to coerce ints to floats, do so.
if (
self._coerce_to_float
and origin is float
and type(value) is int
):
return float(value) if self._create else None
_raise_type_error(fieldpath, type(value), (origin,))
return value if self._create else None
if origin is tuple:
if not isinstance(value, tuple):
raise TypeError(
f'Expected a tuple for {fieldpath};'
f' found a {type(value)}'
)
childanntypes = typing.get_args(anntype)
# We should have verified this was non-zero at prep-time
assert childanntypes
if len(value) != len(childanntypes):
raise TypeError(
f'Tuple at {fieldpath} contains'
f' {len(value)} values; type specifies'
f' {len(childanntypes)}.'
)
if self._create:
return [
self._process_value(
cls, fieldpath, childanntypes[i], x, ioattrs
)
for i, x in enumerate(value)
]
for i, x in enumerate(value):
self._process_value(
cls, fieldpath, childanntypes[i], x, ioattrs
)
return None
if origin is list:
if not isinstance(value, list):
raise TypeError(
f'Expected a list for {fieldpath};'
f' found a {type(value)}'
)
childanntypes = typing.get_args(anntype)
# 'Any' type children; make sure they are valid values for
# the specified codec.
if len(childanntypes) == 0 or childanntypes[0] is typing.Any:
for i, child in enumerate(value):
if not _is_valid_for_codec(child, self._codec):
raise TypeError(
f'Item {i} of {fieldpath} contains'
f' data type(s) not supported by the specified'
f' codec ({self._codec.name}).'
)
# Hmm; should we do a copy here?
return value if self._create else None
# We contain elements of some specified type.
assert len(childanntypes) == 1
if self._create:
return [
self._process_value(
cls, fieldpath, childanntypes[0], x, ioattrs
)
for x in value
]
for x in value:
self._process_value(
cls, fieldpath, childanntypes[0], x, ioattrs
)
return None
if origin is set:
if not isinstance(value, set):
raise TypeError(
f'Expected a set for {fieldpath};' f' found a {type(value)}'
)
childanntypes = typing.get_args(anntype)
# 'Any' type children; make sure they are valid Any values.
if len(childanntypes) == 0 or childanntypes[0] is typing.Any:
for child in value:
if not _is_valid_for_codec(child, self._codec):
raise TypeError(
f'Set at {fieldpath} contains'
f' data type(s) not supported by the'
f' specified codec ({self._codec.name}).'
)
return list(value) if self._create else None
# We contain elements of some specified type.
assert len(childanntypes) == 1
if self._create:
# Note: we output json-friendly values so this becomes
# a list.
return [
self._process_value(
cls, fieldpath, childanntypes[0], x, ioattrs
)
for x in value
]
for x in value:
self._process_value(
cls, fieldpath, childanntypes[0], x, ioattrs
)
return None
if origin is dict:
return self._process_dict(cls, fieldpath, anntype, value, ioattrs)
if dataclasses.is_dataclass(origin):
if not isinstance(value, origin):
raise TypeError(
f'Expected a {origin} for {fieldpath};'
f' found a {type(value)}.'
)
return self._process_dataclass(cls, value, fieldpath)
if issubclass(origin, Enum):
if not isinstance(value, origin):
raise TypeError(
f'Expected a {origin} for {fieldpath};'
f' found a {type(value)}.'
)
# At prep-time we verified that these enums had valid value
# types, so we can blindly return it here.
return value.value if self._create else None
if issubclass(origin, datetime.datetime):
if not isinstance(value, origin):
raise TypeError(
f'Expected a {origin} for {fieldpath};'
f' found a {type(value)}.'
)
check_utc(value)
if ioattrs is not None:
ioattrs.validate_datetime(value, fieldpath)
if self._codec is Codec.FIRESTORE:
return value
assert self._codec is Codec.JSON
return (
[
value.year,
value.month,
value.day,
value.hour,
value.minute,
value.second,
value.microsecond,
]
if self._create
else None
)
if origin is bytes:
return self._process_bytes(cls, fieldpath, value)
raise TypeError(
f"Field '{fieldpath}' of type '{anntype}' is unsupported here."
)
def _process_bytes(self, cls: type, fieldpath: str, value: bytes) -> Any:
import base64
if not isinstance(value, bytes):
raise TypeError(
f'Expected bytes for {fieldpath} on {cls.__name__};'
f' found a {type(value)}.'
)
if not self._create:
return None
# In JSON we convert to base64, but firestore directly supports bytes.
if self._codec is Codec.JSON:
return base64.b64encode(value).decode()
assert self._codec is Codec.FIRESTORE
return value
def _process_dict(
self,
cls: type,
fieldpath: str,
anntype: Any,
value: dict,
ioattrs: IOAttrs | None,
) -> Any:
# pylint: disable=too-many-branches
if not isinstance(value, dict):
raise TypeError(
f'Expected a dict for {fieldpath}; found a {type(value)}.'
)
childtypes = typing.get_args(anntype)
assert len(childtypes) in (0, 2)
# We treat 'Any' dicts simply as json; we don't do any translating.
if not childtypes or childtypes[0] is typing.Any:
if not isinstance(value, dict) or not _is_valid_for_codec(
value, self._codec
):
raise TypeError(
f'Invalid value for Dict[Any, Any]'
f' at \'{fieldpath}\' on {cls.__name__};'
f' all keys and values must be directly compatible'
f' with the specified codec ({self._codec.name})'
f' when dict type is Any.'
)
return value if self._create else None
# Ok; we've got a definite key type (which we verified as valid
# during prep). Make sure all keys match it.
out: dict | None = {} if self._create else None
keyanntype, valanntype = childtypes
# str keys we just export directly since that's supported by json.
if keyanntype is str:
for key, val in value.items():
if not isinstance(key, str):
raise TypeError(
f'Got invalid key type {type(key)} for'
f' dict key at \'{fieldpath}\' on {cls.__name__};'
f' expected {keyanntype}.'
)
outval = self._process_value(
cls, fieldpath, valanntype, val, ioattrs
)
if self._create:
assert out is not None
out[key] = outval
# int keys are stored as str versions of themselves.
elif keyanntype is int:
for key, val in value.items():
if not isinstance(key, int):
raise TypeError(
f'Got invalid key type {type(key)} for'
f' dict key at \'{fieldpath}\' on {cls.__name__};'
f' expected an int.'
)
outval = self._process_value(
cls, fieldpath, valanntype, val, ioattrs
)
if self._create:
assert out is not None
out[str(key)] = outval
elif issubclass(keyanntype, Enum):
for key, val in value.items():
if not isinstance(key, keyanntype):
raise TypeError(
f'Got invalid key type {type(key)} for'
f' dict key at \'{fieldpath}\' on {cls.__name__};'
f' expected a {keyanntype}.'
)
outval = self._process_value(
cls, fieldpath, valanntype, val, ioattrs
)
if self._create:
assert out is not None
out[str(key.value)] = outval
else:
raise RuntimeError(f'Unhandled dict out-key-type {keyanntype}')
return out
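# --- Illustrative usage sketch (not part of the original file; written as
# standalone user code): how the dict-key handling above shows up in output
# data. Assumes the public dataclass_to_dict() helper and @ioprepped
# decorator exposed by efro.dataclassio; the Scores class is hypothetical.
from dataclasses import dataclass, field
from efro.dataclassio import ioprepped, dataclass_to_dict

@ioprepped
@dataclass
class Scores:
    # int dict keys are stored as str versions of themselves (see above).
    by_level: dict[int, float] = field(default_factory=dict)

# Expected output (roughly): {'by_level': {'1': 10.5, '2': 3.0}}
print(dataclass_to_dict(Scores(by_level={1: 10.5, 2: 3.0})))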

View file

@ -0,0 +1,115 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality related to capturing nested dataclass paths."""
from __future__ import annotations
import dataclasses
from typing import TYPE_CHECKING, TypeVar, Generic
from efro.dataclassio._base import _parse_annotated, _get_origin
from efro.dataclassio._prep import PrepSession
if TYPE_CHECKING:
from typing import Any, Callable
T = TypeVar('T')
class _PathCapture:
"""Utility for obtaining dataclass storage paths in a type safe way."""
def __init__(self, obj: Any, pathparts: list[str] | None = None):
self._is_dataclass = dataclasses.is_dataclass(obj)
if pathparts is None:
pathparts = []
self._cls = obj if isinstance(obj, type) else type(obj)
self._pathparts = pathparts
def __getattr__(self, name: str) -> _PathCapture:
# We only allow diving into sub-objects if we are a dataclass.
if not self._is_dataclass:
raise TypeError(
f"Field path cannot include attribute '{name}' "
f'under parent {self._cls}; parent types must be dataclasses.'
)
prep = PrepSession(explicit=False).prep_dataclass(
self._cls, recursion_level=0
)
assert prep is not None
try:
anntype = prep.annotations[name]
except KeyError as exc:
raise AttributeError(f'{type(self)} has no {name} field.') from exc
anntype, ioattrs = _parse_annotated(anntype)
storagename = (
name
if (ioattrs is None or ioattrs.storagename is None)
else ioattrs.storagename
)
origin = _get_origin(anntype)
return _PathCapture(origin, pathparts=self._pathparts + [storagename])
@property
def path(self) -> str:
"""The final output path."""
return '.'.join(self._pathparts)
class DataclassFieldLookup(Generic[T]):
"""Get info about nested dataclass fields in type-safe way."""
def __init__(self, cls: type[T]) -> None:
self.cls = cls
def path(self, callback: Callable[[T], Any]) -> str:
"""Look up a path on child dataclass fields.
example:
DataclassFieldLookup(MyType).path(lambda obj: obj.foo.bar)
The above example will return the string 'foo.bar' or something
like 'f.b' if the dataclasses have custom storage names set.
It will also be static-type-checked, triggering an error if
MyType.foo.bar is not a valid path. Note, however, that the
callback technically allows any return value but only nested
dataclasses and their fields will succeed.
"""
# We tell the type system that we are returning an instance
# of our class, which allows it to perform type checking on
# member lookups. In reality, however, we are providing a
# special object which captures path lookups, so we can build
# a string from them.
if not TYPE_CHECKING:
out = callback(_PathCapture(self.cls))
if not isinstance(out, _PathCapture):
raise TypeError(
f'Expected a valid path under'
f' the provided object; got a {type(out)}.'
)
return out.path
return ''
def paths(self, callback: Callable[[T], list[Any]]) -> list[str]:
"""Look up multiple paths on child dataclass fields.
Functionality is identical to path() but for multiple paths at once.
example:
DataclassFieldLookup(MyType).paths(lambda obj: [obj.foo, obj.bar])
"""
outvals: list[str] = []
if not TYPE_CHECKING:
outs = callback(_PathCapture(self.cls))
assert isinstance(outs, list)
for out in outs:
if not isinstance(out, _PathCapture):
raise TypeError(
f'Expected a valid path under'
f' the provided object; got a {type(out)}.'
)
outvals.append(out.path)
return outvals
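# --- Illustrative usage sketch (not part of the original file; written as
# standalone user code): DataclassFieldLookup resolving a storage path.
# The Engine/Car classes and the 'hp' storage name are hypothetical;
# assumes ioprepped, IOAttrs and DataclassFieldLookup are exposed by
# efro.dataclassio.
from dataclasses import dataclass, field
from typing import Annotated
from efro.dataclassio import ioprepped, IOAttrs, DataclassFieldLookup

@ioprepped
@dataclass
class Engine:
    horsepower: Annotated[int, IOAttrs('hp')] = 0

@ioprepped
@dataclass
class Car:
    engine: Engine = field(default_factory=Engine)

# Prints 'engine.hp' (attr names mapped to storage names); a typo such
# as 'obj.engine.horsepowr' would be flagged by the static type checker.
print(DataclassFieldLookup(Car).path(lambda obj: obj.engine.horsepower))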

View file

@ -0,0 +1,459 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for prepping types for use with dataclassio."""
# Note: We do lots of comparing of exact types here which is normally
# frowned upon (stuff like isinstance() is usually encouraged).
# pylint: disable=unidiomatic-typecheck
from __future__ import annotations
import logging
from enum import Enum
import dataclasses
import typing
import types
import datetime
from typing import TYPE_CHECKING, TypeVar, get_type_hints
# noinspection PyProtectedMember
from efro.dataclassio._base import _parse_annotated, _get_origin, SIMPLE_TYPES
if TYPE_CHECKING:
from typing import Any
from efro.dataclassio._base import IOAttrs
T = TypeVar('T')
# How deep we go when prepping nested types
# (basically for detecting recursive types)
MAX_RECURSION = 10
# Attr name for data we store on dataclass types that have been prepped.
PREP_ATTR = '_DCIOPREP'
# We also store the prep-session while the prep is in progress.
# (necessary to support recursive types).
PREP_SESSION_ATTR = '_DCIOPREPSESSION'
def ioprep(cls: type, globalns: dict | None = None) -> None:
"""Prep a dataclass type for use with this module's functionality.
Prepping ensures that all types contained in a data class as well as
the usage of said types are supported by this module and pre-builds
necessary constructs needed for encoding/decoding/etc.
Prepping will happen on-the-fly as needed, but a warning will be
emitted in such cases, as it is better to explicitly prep all used types
early in a process to ensure any invalid types or configuration are caught
immediately.
Prepping a dataclass involves evaluating its type annotations, which,
as of PEP 563, are stored simply as strings. This evaluation is done
with localns set to the class dict (so that types defined in the class
can be used) and globalns set to the containing module's dict.
It is possible to override globalns for special cases such as when
prepping happens as part of an execed string instead of within a
module.
"""
PrepSession(explicit=True, globalns=globalns).prep_dataclass(
cls, recursion_level=0
)
def ioprepped(cls: type[T]) -> type[T]:
"""Class decorator for easily prepping a dataclass at definition time.
Note that in some cases it may not be possible to prep a dataclass
immediately (such as when its type annotations refer to forward-declared
types). In these cases, ioprep() should be explicitly called for
the class as soon as possible; ideally at module import time to expose any
errors as early as possible in execution.
"""
ioprep(cls)
return cls
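# --- Illustrative usage sketch (not part of the original file): typical use
# of the decorator above. Any unsupported annotation raises at
# class-definition time instead of at first encode/decode. The Settings
# class is hypothetical.
@ioprepped
@dataclasses.dataclass
class Settings:
    name: str = ''
    volume: float = 1.0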
def will_ioprep(cls: type[T]) -> type[T]:
"""Class decorator hinting that we will prep a class later.
In some cases (such as recursive types) we cannot use the @ioprepped
decorator and must instead call ioprep() explicitly later. However,
some of our custom pylint checking behaves differently when the
@ioprepped decorator is present, in that case requiring type annotations
to be present and not simply forward-declared under an "if TYPE_CHECKING"
block (since they are used at runtime).
The @will_ioprep decorator triggers the same pylint behavior
differences as @ioprepped (which are necessary for the later ioprep() call
to work correctly) but without actually running any prep itself.
"""
return cls
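# --- Illustrative usage sketch (not part of the original file): a
# self-referential dataclass where @ioprepped cannot run at definition time,
# so it is tagged with @will_ioprep and prepped explicitly right after.
# The Node class is hypothetical.
@will_ioprep
@dataclasses.dataclass
class Node:
    children: list['Node'] = dataclasses.field(default_factory=list)

ioprep(Node)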
def is_ioprepped_dataclass(obj: Any) -> bool:
"""Return whether the obj is an ioprepped dataclass type or instance."""
cls = obj if isinstance(obj, type) else type(obj)
return dataclasses.is_dataclass(cls) and hasattr(cls, PREP_ATTR)
@dataclasses.dataclass
class PrepData:
"""Data we prepare and cache for a class during prep.
This data is used as part of the encoding/decoding/validating process.
"""
# Resolved annotation data with 'live' classes.
annotations: dict[str, Any]
# Map of storage names to attr names.
storage_names_to_attr_names: dict[str, str]
class PrepSession:
"""Context for a prep."""
def __init__(self, explicit: bool, globalns: dict | None = None):
self.explicit = explicit
self.globalns = globalns
def prep_dataclass(
self, cls: type, recursion_level: int
) -> PrepData | None:
"""Run prep on a dataclass if necessary and return its prep data.
The only case where this will return None is for recursive types
if the type is already being prepped higher in the call order.
"""
# pylint: disable=too-many-locals
# pylint: disable=too-many-branches
# We should only need to do this once per dataclass.
existing_data = getattr(cls, PREP_ATTR, None)
if existing_data is not None:
assert isinstance(existing_data, PrepData)
return existing_data
# Sanity check.
# Note that we now support recursive types via the PREP_SESSION_ATTR,
# so we theoretically shouldn't run into this.
if recursion_level > MAX_RECURSION:
raise RuntimeError('Max recursion exceeded.')
# We should only be passed classes which are dataclasses.
if not isinstance(cls, type) or not dataclasses.is_dataclass(cls):
raise TypeError(f'Passed arg {cls} is not a dataclass type.')
# Add a pointer to the prep-session while doing the prep.
# This way we can ignore types that we're already in the process
# of prepping and can support recursive types.
existing_prep = getattr(cls, PREP_SESSION_ATTR, None)
if existing_prep is not None:
if existing_prep is self:
return None
# We shouldn't need to support failed preps
# or preps from multiple threads at once.
raise RuntimeError('Found existing in-progress prep.')
setattr(cls, PREP_SESSION_ATTR, self)
# Generate a warning on non-explicit preps; we prefer prep to
# happen explicitly at runtime so errors can be detected early on.
if not self.explicit:
logging.warning(
'efro.dataclassio: implicitly prepping dataclass: %s.'
' It is highly recommended to explicitly prep dataclasses'
' as soon as possible after definition (via'
' efro.dataclassio.ioprep() or the'
' @efro.dataclassio.ioprepped decorator).',
cls,
)
try:
# NOTE: Now passing the class' __dict__ (vars()) as locals
# which allows us to pick up nested classes, etc.
resolved_annotations = get_type_hints(
cls,
localns=vars(cls),
globalns=self.globalns,
include_extras=True,
)
# pylint: enable=unexpected-keyword-arg
except Exception as exc:
raise TypeError(
f'dataclassio prep for {cls} failed with error: {exc}.'
f' Make sure all types used in annotations are defined'
f' at the module or class level or add them as part of an'
f' explicit prep call.'
) from exc
# noinspection PyDataclass
fields = dataclasses.fields(cls)
fields_by_name = {f.name: f for f in fields}
all_storage_names: set[str] = set()
storage_names_to_attr_names: dict[str, str] = {}
# Ok; we've resolved actual types for this dataclass.
# Now recurse through them, verifying that we support all contained
# types and prepping any contained dataclass types.
for attrname, anntype in resolved_annotations.items():
anntype, ioattrs = _parse_annotated(anntype)
# If we found attached IOAttrs data, make sure it contains
# valid values for the field it is attached to.
if ioattrs is not None:
ioattrs.validate_for_field(cls, fields_by_name[attrname])
if ioattrs.storagename is not None:
storagename = ioattrs.storagename
storage_names_to_attr_names[ioattrs.storagename] = attrname
else:
storagename = attrname
else:
storagename = attrname
# Make sure we don't have any clashes in our storage names.
if storagename in all_storage_names:
raise TypeError(
f'Multiple attrs on {cls} are using'
f' storage-name \'{storagename}\''
)
all_storage_names.add(storagename)
self.prep_type(
cls,
attrname,
anntype,
ioattrs=ioattrs,
recursion_level=recursion_level + 1,
)
# Success! Store our resolved stuff with the class and we're done.
prepdata = PrepData(
annotations=resolved_annotations,
storage_names_to_attr_names=storage_names_to_attr_names,
)
setattr(cls, PREP_ATTR, prepdata)
# Clear our prep-session tag.
assert getattr(cls, PREP_SESSION_ATTR, None) is self
delattr(cls, PREP_SESSION_ATTR)
return prepdata
def prep_type(
self,
cls: type,
attrname: str,
anntype: Any,
ioattrs: IOAttrs | None,
recursion_level: int,
) -> None:
"""Run prep on a dataclass."""
# pylint: disable=too-many-return-statements
# pylint: disable=too-many-branches
# pylint: disable=too-many-statements
if recursion_level > MAX_RECURSION:
raise RuntimeError('Max recursion exceeded.')
origin = _get_origin(anntype)
if origin is typing.Union or origin is types.UnionType:
self.prep_union(
cls, attrname, anntype, recursion_level=recursion_level + 1
)
return
if anntype is typing.Any:
return
# Everything below this point assumes the annotation type resolves
# to a concrete type.
if not isinstance(origin, type):
raise TypeError(
f'Unsupported type found for \'{attrname}\' on {cls}:'
f' {anntype}'
)
# If a soft_default value/factory was passed, we do some basic
# type checking on the top-level value here. We also run full
# recursive validation on values later during inputting, but this
# should catch at least some errors early on, which can be
# useful since soft_defaults are not static type checked.
if ioattrs is not None:
have_soft_default = False
soft_default: Any = None
if ioattrs.soft_default is not ioattrs.MISSING:
have_soft_default = True
soft_default = ioattrs.soft_default
elif ioattrs.soft_default_factory is not ioattrs.MISSING:
assert callable(ioattrs.soft_default_factory)
have_soft_default = True
soft_default = ioattrs.soft_default_factory()
# Do a simple type check for the top level to catch basic
# soft_default mismatches early; full check will happen at
# input time.
if have_soft_default:
if not isinstance(soft_default, origin):
raise TypeError(
f'{cls} attr {attrname} has type {origin}'
f' but soft_default value is type {type(soft_default)}'
)
if origin in SIMPLE_TYPES:
return
# For sets and lists, check out their single contained type (if any).
if origin in (list, set):
childtypes = typing.get_args(anntype)
if len(childtypes) == 0:
# This is equivalent to Any; nothing else needs checking.
return
if len(childtypes) > 1:
raise TypeError(
f'Unrecognized typing arg count {len(childtypes)}'
f" for {anntype} attr '{attrname}' on {cls}"
)
self.prep_type(
cls,
attrname,
childtypes[0],
ioattrs=None,
recursion_level=recursion_level + 1,
)
return
if origin is dict:
childtypes = typing.get_args(anntype)
assert len(childtypes) in (0, 2)
# For key types we support Any, str, int,
# and Enums with uniform str/int values.
if not childtypes or childtypes[0] is typing.Any:
# 'Any' needs no further checks (just checked per-instance).
pass
elif childtypes[0] in (str, int):
# str and int are all good as keys.
pass
elif issubclass(childtypes[0], Enum):
# Allow our usual str or int enum types as keys.
self.prep_enum(childtypes[0])
else:
raise TypeError(
f'Dict key type {childtypes[0]} for \'{attrname}\''
f' on {cls.__name__} is not supported by dataclassio.'
)
# For value types we support any of our normal types.
if not childtypes or _get_origin(childtypes[1]) is typing.Any:
# 'Any' needs no further checks (just checked per-instance).
pass
else:
self.prep_type(
cls,
attrname,
childtypes[1],
ioattrs=None,
recursion_level=recursion_level + 1,
)
return
# For Tuples, simply check individual member types.
# (and, for now, explicitly disallow zero member types or usage
# of ellipsis)
if origin is tuple:
childtypes = typing.get_args(anntype)
if not childtypes:
raise TypeError(
f'Tuple at \'{attrname}\''
f' has no type args; dataclassio requires type args.'
)
if childtypes[-1] is ...:
raise TypeError(
f'Found ellipsis as part of type for'
f' \'{attrname}\' on {cls.__name__};'
f' these are not'
f' supported by dataclassio.'
)
for childtype in childtypes:
self.prep_type(
cls,
attrname,
childtype,
ioattrs=None,
recursion_level=recursion_level + 1,
)
return
if issubclass(origin, Enum):
self.prep_enum(origin)
return
# We allow datetime objects (and google's extended subclass of them
# used in firestore, which is why we don't look for exact type here).
if issubclass(origin, datetime.datetime):
return
if dataclasses.is_dataclass(origin):
self.prep_dataclass(origin, recursion_level=recursion_level + 1)
return
if origin is bytes:
return
raise TypeError(
f"Attr '{attrname}' on {cls.__name__} contains"
f" type '{anntype}'"
f' which is not supported by dataclassio.'
)
def prep_union(
self, cls: type, attrname: str, anntype: Any, recursion_level: int
) -> None:
"""Run prep on a Union type."""
typeargs = typing.get_args(anntype)
if (
len(typeargs) != 2
or len([c for c in typeargs if c is type(None)]) != 1
): # noqa
raise TypeError(
f'Union {anntype} for attr \'{attrname}\' on'
f' {cls.__name__} is not supported by dataclassio;'
f' only 2 member Unions with one type being None'
f' are supported.'
)
for childtype in typeargs:
self.prep_type(
cls,
attrname,
childtype,
None,
recursion_level=recursion_level + 1,
)
def prep_enum(self, enumtype: type[Enum]) -> None:
"""Run prep on an enum type."""
valtype: Any = None
# We currently support enums with str or int values; fail if we
# find any others.
for enumval in enumtype:
if not isinstance(enumval.value, (str, int)):
raise TypeError(
f'Enum value {enumval} has value type'
f' {type(enumval.value)}; only str and int are'
f' supported by dataclassio.'
)
if valtype is None:
valtype = type(enumval.value)
else:
if type(enumval.value) is not valtype:
raise TypeError(
f'Enum type {enumtype} has multiple'
f' value types; dataclassio requires'
f' them to be uniform.'
)
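# --- Illustrative sketch (not part of the original file): prep_enum()
# requires uniform value types, so dataclassio accepts the first enum below
# but would raise a TypeError for the second. Both enums are hypothetical.
from enum import Enum

class Suit(Enum):
    HEARTS = 'h'
    SPADES = 's'

class BadSuit(Enum):
    HEARTS = 'h'
    SPADES = 2  # Mixed str/int values; prep would reject this enum.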

View file

@ -0,0 +1,71 @@
# Released under the MIT License. See LICENSE for details.
#
"""Extra rarely-needed functionality related to dataclasses."""
from __future__ import annotations
import dataclasses
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Any
def dataclass_diff(obj1: Any, obj2: Any) -> str:
"""Generate a string showing differences between two dataclass instances.
Both must be of the exact same type.
"""
diff = _diff(obj1, obj2, 2)
return ' <no differences>' if diff == '' else diff
class DataclassDiff:
"""Wraps dataclass_diff() in an object for efficiency.
It is preferable to pass this to logging calls instead of the
final diff string since the diff will never be generated if
the associated logging level is not being emitted.
"""
def __init__(self, obj1: Any, obj2: Any):
self._obj1 = obj1
self._obj2 = obj2
def __repr__(self) -> str:
return dataclass_diff(self._obj1, self._obj2)
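# --- Illustrative usage sketch (not part of the original file): pass a
# DataclassDiff to a logging call so the diff string is only built if the
# message is actually emitted. The Config class is hypothetical.
import logging
from dataclasses import dataclass

@dataclass
class Config:
    volume: float = 1.0
    fullscreen: bool = False

old, new = Config(), Config(volume=0.5)
# %-style args are only formatted (calling __repr__) if DEBUG is enabled.
logging.debug('Config changed:\n%s', DataclassDiff(old, new))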
def _diff(obj1: Any, obj2: Any, indent: int) -> str:
assert dataclasses.is_dataclass(obj1)
assert dataclasses.is_dataclass(obj2)
if type(obj1) is not type(obj2):
raise TypeError(
f'Passed objects are not of the same'
f' type ({type(obj1)} and {type(obj2)}).'
)
bits: list[str] = []
indentstr = ' ' * indent
fields = dataclasses.fields(obj1)
for field in fields:
fieldname = field.name
val1 = getattr(obj1, fieldname)
val2 = getattr(obj2, fieldname)
# For nested dataclasses, dive in and do nice piecewise compares.
if (
dataclasses.is_dataclass(val1)
and dataclasses.is_dataclass(val2)
and type(val1) is type(val2)
):
diff = _diff(val1, val2, indent + 2)
if diff != '':
bits.append(f'{indentstr}{fieldname}:')
bits.append(diff)
# For all else just do a single line
# (perhaps we could improve on this for other complex types)
else:
if val1 != val2:
bits.append(f'{indentstr}{fieldname}: {val1} -> {val2}')
return '\n'.join(bits)

349
dist/ba_data/python/efro/debug.py vendored Normal file
View file

@ -0,0 +1,349 @@
# Released under the MIT License. See LICENSE for details.
#
"""Utilities for debugging memory leaks or other issues.
IMPORTANT - these functions use the gc module which looks 'under the hood'
at Python and sometimes returns not-fully-initialized objects, which may
cause crashes or errors due to suddenly having references to them that they
didn't expect, etc. See https://github.com/python/cpython/issues/59313.
For this reason, these methods should NEVER be called in production code.
Enable them only for debugging situations and be aware that their use may
itself cause problems. The same is true for the gc module itself.
"""
from __future__ import annotations
import gc
import sys
import types
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Any, TextIO
ABS_MAX_LEVEL = 10
# NOTE: In general we want this toolset to allow us to explore
# which objects are holding references to others so we can diagnose
# leaks/etc. It is a bit tricky to do that, however, without
# affecting the objects we are looking at by adding temporary references
# from module dicts, function scopes, etc. So we need to try to be
# careful about cleaning up after ourselves and explicitly avoiding
# returning these temporary references wherever possible.
# A good test is running printrefs() repeatedly on some object that is
# known to be static. If the list of references, or the ids of any of
# the listed references, changes with each run, it's a good sign that
# we're showing some temporary objects that we should be ignoring.
def getobjs(
cls: type | str, contains: str | None = None, expanded: bool = False
) -> list[Any]:
"""Return all garbage-collected objects matching criteria.
'cls' can be an actual type or a string, in which case objects
whose types contain that string will be returned.
If 'contains' is provided, objects will be filtered to those
containing that in their str() representations.
"""
# Don't wanna return stuff waiting to be garbage-collected.
gc.collect()
if not isinstance(cls, type | str):
raise TypeError('Expected a type or string for cls')
if not isinstance(contains, str | None):
raise TypeError('Expected a string or None for contains')
allobjs = _get_all_objects(expanded=expanded)
if isinstance(cls, str):
objs = [o for o in allobjs if cls in str(type(o))]
else:
objs = [o for o in allobjs if isinstance(o, cls)]
if contains is not None:
objs = [o for o in objs if contains in str(o)]
return objs
# Recursively expand objects in slist into olist, using seen to track
# already processed objects.
def _getr(slist: list[Any], olist: list[Any], seen: set[int]) -> None:
for obj in slist:
if id(obj) in seen:
continue
seen.add(id(obj))
olist.append(obj)
tll = gc.get_referents(obj)
if tll:
_getr(tll, olist, seen)
def _get_all_objects(expanded: bool) -> list[Any]:
"""Return an expanded list of all objects.
See https://utcc.utoronto.ca/~cks/space/blog/python/GetAllObjects
"""
gcl = gc.get_objects()
if not expanded:
return gcl
olist: list[Any] = []
seen: set[int] = set()
# Just in case:
seen.add(id(gcl))
seen.add(id(olist))
seen.add(id(seen))
# _getr does the real work.
_getr(gcl, olist, seen)
return olist
def getobj(objid: int, expanded: bool = False) -> Any:
"""Return a garbage-collected object by its id.
Remember that this is VERY inefficient and should only ever be used
for debugging.
"""
if not isinstance(objid, int):
raise TypeError(f'Expected an int for objid; got a {type(objid)}.')
# Don't wanna return stuff waiting to be garbage-collected.
gc.collect()
allobjs = _get_all_objects(expanded=expanded)
for obj in allobjs:
if id(obj) == objid:
return obj
raise RuntimeError(f'Object with id {objid} not found.')
def getrefs(obj: Any) -> list[Any]:
"""Given an object, return things referencing it."""
v = vars() # Ignore ref coming from locals.
return [o for o in gc.get_referrers(obj) if o is not v]
def printfiles(file: TextIO | None = None) -> None:
"""Print info about open files in the current app."""
import io
file = sys.stderr if file is None else file
try:
import psutil
except ImportError:
print(
"Error: printfiles requires the 'psutil' module to be installed.",
file=file,
)
return
proc = psutil.Process()
# Let's grab all Python file handles so we can associate raw files
# with their Python objects when possible.
fileio_ids = {obj.fileno(): obj for obj in getobjs(io.FileIO)}
textio_ids = {obj.fileno(): obj for obj in getobjs(io.TextIOWrapper)}
# FIXME: we could do a more limited version of this when psutil is
# not present that simply includes Python's files.
print('Files open by this app (not limited to Python\'s):', file=file)
for i, ofile in enumerate(proc.open_files()):
# Mypy doesn't know about mode apparently.
# (and can't use type: ignore because we don't require psutil
# and then mypy complains about an unused ignore comment when it's
# not present)
mode = getattr(ofile, 'mode')
assert isinstance(mode, str)
textio = textio_ids.get(ofile.fd)
textio_s = id(textio) if textio is not None else '<not found>'
fileio = fileio_ids.get(ofile.fd)
fileio_s = id(fileio) if fileio is not None else '<not found>'
print(
f'#{i+1}: path={ofile.path!r},'
f' fd={ofile.fd}, mode={mode!r}, TextIOWrapper={textio_s},'
f' FileIO={fileio_s}'
)
def printrefs(
obj: Any,
max_level: int = 2,
exclude_objs: list[Any] | None = None,
expand_ids: list[int] | None = None,
file: TextIO | None = None,
) -> None:
"""Print human readable list of objects referring to an object.
'max_level' specifies how many levels of recursion are printed.
'exclude_objs' can be a list of exact objects to skip if found in the
referrers list. This can be useful to avoid printing the local context
where the object was passed in from (locals(), etc).
'expand_ids' can be a list of object ids; if that particular object is
found, it will always be expanded even if max_level has been reached.
"""
_printrefs(
obj,
level=0,
max_level=max_level,
exclude_objs=[] if exclude_objs is None else exclude_objs,
expand_ids=[] if expand_ids is None else expand_ids,
file=sys.stderr if file is None else file,
)
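# --- Illustrative usage sketch (not part of the original file): a typical
# leak-hunting session using the helpers above. Only for interactive
# debugging; never ship calls to these in production code. 'MyWidget' is a
# hypothetical type name.
suspects = getobjs('MyWidget')  # Match by substring of the type's name.
if suspects:
    # Show what is still holding references to the first match.
    printrefs(suspects[0], max_level=3)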
def printtypes(
limit: int = 50, file: TextIO | None = None, expanded: bool = False
) -> None:
"""Print a human readable list of which types have the most instances."""
assert limit > 0
objtypes: dict[str, int] = {}
gc.collect() # Recommended before get_objects().
allobjs = _get_all_objects(expanded=expanded)
allobjc = len(allobjs)
for obj in allobjs:
modname = type(obj).__module__
tpname = type(obj).__qualname__
if modname != 'builtins':
tpname = f'{modname}.{tpname}'
objtypes[tpname] = objtypes.get(tpname, 0) + 1
# Presumably allobjs contains stack-frame/dict type stuff
# from this function call which in turn contain refs to allobjs.
# Let's try to prevent these huge lists from accumulating until
# the cyclical collector (hopefully) gets to them.
allobjs.clear()
del allobjs
print(f'Types most allocated ({allobjc} total objects):', file=file)
for i, tpitem in enumerate(
sorted(objtypes.items(), key=lambda x: x[1], reverse=True)[:limit]
):
tpname, tpval = tpitem
percent = tpval / allobjc * 100.0
print(f'{i+1}: {tpname}: {tpval} ({percent:.2f}%)', file=file)
def printsizes(
limit: int = 50, file: TextIO | None = None, expanded: bool = False
) -> None:
"""Print total allocated sizes of different types."""
assert limit > 0
objsizes: dict[str, int] = {}
gc.collect() # Recommended before get_objects().
allobjs = _get_all_objects(expanded=expanded)
totalobjsize = 0
for obj in allobjs:
modname = type(obj).__module__
tpname = type(obj).__qualname__
if modname != 'builtins':
tpname = f'{modname}.{tpname}'
objsize = sys.getsizeof(obj)
objsizes[tpname] = objsizes.get(tpname, 0) + objsize
totalobjsize += objsize
totalobjmb = totalobjsize / (1024 * 1024)
print(
f'Types with most allocated bytes ({totalobjmb:.2f} mb total):',
file=file,
)
for i, tpitem in enumerate(
sorted(objsizes.items(), key=lambda x: x[1], reverse=True)[:limit]
):
tpname, tpval = tpitem
percent = tpval / totalobjsize * 100.0
print(f'{i+1}: {tpname}: {tpval} ({percent:.2f}%)', file=file)
def _desctype(obj: Any) -> str:
cls = type(obj)
if cls is types.ModuleType:
return f'{type(obj).__name__} {obj.__name__}'
if cls is types.MethodType:
bnd = 'bound' if hasattr(obj, '__self__') else 'unbound'
return f'{bnd} {type(obj).__name__} {obj.__name__}'
return f'{type(obj).__name__}'
def _desc(obj: Any) -> str:
extra: str | None = None
if isinstance(obj, list | tuple):
# Print length and the first few types.
tps = [_desctype(i) for i in obj[:3]]
tpsj = ', '.join(tps)
tpss = (
f', contains [{tpsj}, ...]'
if len(obj) > 3
else f', contains [{tpsj}]'
if tps
else ''
)
extra = f' (len {len(obj)}{tpss})'
elif isinstance(obj, dict):
# If it seems to be the vars() for a type or module,
# try to identify what.
for ref in getrefs(obj):
if hasattr(ref, '__dict__') and vars(ref) is obj:
extra = f' (vars for {_desctype(ref)} @ {id(ref)})'
# Generic dict: print length and the first few key:type pairs.
if extra is None:
pairs = [
f'{repr(n)}: {_desctype(v)}' for n, v in list(obj.items())[:3]
]
pairsj = ', '.join(pairs)
pairss = (
f', contains {{{pairsj}, ...}}'
if len(obj) > 3
else f', contains {{{pairsj}}}'
if pairs
else ''
)
extra = f' (len {len(obj)}{pairss})'
if extra is None:
extra = ''
return f'{_desctype(obj)} @ {id(obj)}{extra}'
def _printrefs(
obj: Any,
level: int,
max_level: int,
exclude_objs: list,
expand_ids: list[int],
file: TextIO,
) -> None:
ind = ' ' * level
print(ind + _desc(obj), file=file)
v = vars()
if level < max_level or (id(obj) in expand_ids and level < ABS_MAX_LEVEL):
refs = getrefs(obj)
for ref in refs:
# It seems we tend to get a transient cell object with contents
# set to obj. Would be nice to understand why that happens
# but just ignoring it for now.
if isinstance(ref, types.CellType) and ref.cell_contents is obj:
continue
# Ignore anything we were asked to ignore.
if exclude_objs is not None:
if any(ref is eobj for eobj in exclude_objs):
continue
# Ignore references from our locals.
if ref is v:
continue
# The 'refs' list we just made will be listed as a referrer
# of this obj, so explicitly exclude it from the obj's listing.
_printrefs(
ref,
level=level + 1,
max_level=max_level,
exclude_objs=exclude_objs + [refs],
expand_ids=expand_ids,
file=file,
)

View file

@ -0,0 +1,36 @@
# Released under the MIT License. See LICENSE for details.
#
"""Entity functionality.
A system for defining structured data, supporting both static and runtime
type safety, serialization, efficient/sparse storage, per-field value
limits, etc. This is a heavyweight option in comparison to things such as
dataclasses, but the increased features can make the overhead worth it for
certain use cases.
Advantages compared to nested dataclasses:
- Field names separated from their data representation so can get more
concise json data, change variable names while preserving back-compat, etc.
- Can wrap and preserve unmapped data (so fields can be added to new versions
of something without breaking old versions' ability to read the data)
- Incorrectly typed data is caught at runtime (for dataclasses we rely on
type-checking and explicit validation calls)
Disadvantages compared to nested dataclasses:
- More complex to use
- Significantly more heavyweight (roughly 10 times slower in quick tests)
- Can't currently be initialized in constructors (this would probably require
a Mypy plugin to do in a type-safe way)
"""
# pylint: disable=unused-import
from efro.entity._entity import EntityMixin, Entity
from efro.entity._field import (Field, CompoundField, ListField, DictField,
CompoundListField, CompoundDictField)
from efro.entity._value import (
EnumValue, OptionalEnumValue, IntValue, OptionalIntValue, StringValue,
OptionalStringValue, BoolValue, OptionalBoolValue, FloatValue,
OptionalFloatValue, DateTimeValue, OptionalDateTimeValue, Float3Value,
CompoundValue)
from efro.entity._support import FieldInspector
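# --- Illustrative usage sketch (not part of the original package file): a
# minimal Entity built from the pieces imported above. The Character entity,
# its storage keys, and the default-constructed values are hypothetical;
# exact Value constructor arguments may differ.
class Character(Entity):
    name = Field('n', StringValue())
    level = Field('l', IntValue())

char = Character()
char.name = 'Zoe'
char.level = 3
# Produces compact json keyed by the short storage names,
# e.g. something like '{"n":"Zoe","l":3}'.
print(char.to_json_str())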

133
dist/ba_data/python/efro/entity/_base.py vendored Normal file
View file

@ -0,0 +1,133 @@
# Released under the MIT License. See LICENSE for details.
#
"""Base classes for the entity system."""
from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING
from efro.util import enum_by_value
if TYPE_CHECKING:
from typing import Any, Type
def dict_key_to_raw(key: Any, keytype: Type) -> Any:
"""Given a key value from the world, filter to stored key."""
if not isinstance(key, keytype):
raise TypeError(
f'Invalid key type; expected {keytype}, got {type(key)}.')
if issubclass(keytype, Enum):
val = key.value
# We convert int enums to string since that is what firestore supports.
if isinstance(val, int):
val = str(val)
return val
return key
def dict_key_from_raw(key: Any, keytype: Type) -> Any:
"""Given internal key, filter to world visible type."""
if issubclass(keytype, Enum):
# We store all enum keys as strings; if the enum uses
# int keys, convert back.
for enumval in keytype:
if isinstance(enumval.value, int):
return enum_by_value(keytype, int(key))
break
return enum_by_value(keytype, key)
return key
class DataHandler:
"""Base class for anything that can wrangle entity data.
This contains common functionality shared by Fields and Values.
"""
def get_default_data(self) -> Any:
"""Return the default internal data value for this object.
This will be inserted when initing nonexistent entity data.
"""
raise RuntimeError(f'get_default_data() unimplemented for {self}')
def filter_input(self, data: Any, error: bool) -> Any:
"""Given arbitrary input data, return valid internal data.
If error is True, exceptions should be thrown for any non-trivial
mismatch (more than just int vs float/etc.). Otherwise the invalid
data should be replaced with valid defaults and the problem noted
via the logging module.
The passed-in data can be modified in-place or returned as-is, or
completely new data can be returned. Compound types are responsible
for setting defaults and/or calling this recursively for their
children. Data that is not used by the field (such as orphaned values
in a dict field) can be left alone.
Supported types for internal data are:
- anything that works with json (lists, dicts, bools, floats, ints,
strings, None) - no tuples!
- datetime.datetime objects
"""
del error # unused
return data
def filter_output(self, data: Any) -> Any:
"""Given valid internal data, return user-facing data.
Note that entity data is expected to be filtered to correctness on
input, so if internal and external entity data are the same type,
this can simply return the data unchanged.
Value types such as Vec3 may store data internally as simple float
tuples but return Vec3 objects to the user; this is the mechanism
by which they do so.
"""
return data
def prune_data(self, data: Any) -> bool:
"""Prune internal data to strip out default values/etc.
Should return a bool indicating whether root data itself can be pruned.
The object is responsible for pruning any sub-fields before returning.
"""
class BaseField(DataHandler):
"""Base class for all field types."""
def __init__(self, d_key: str = None) -> None:
# Key for this field's data in parent dict/list (when applicable;
# some fields such as the child field under a list field represent
# more than a single field entry so this is unused)
self.d_key = d_key
# IMPORTANT: this method should only be overridden in the eyes of the
# type-checker (to specify exact return types). Subclasses should instead
# override get_with_data() for doing the actual work, since that method
# may sometimes be called explicitly instead of through __get__
def __get__(self, obj: Any, type_in: Any = None) -> Any:
if obj is None:
# when called on the type, we return the field
return self
return self.get_with_data(obj.d_data)
# IMPORTANT: same deal as __get__() (see note above)
def __set__(self, obj: Any, value: Any) -> None:
assert obj is not None
self.set_with_data(obj.d_data, value, error=True)
def get_with_data(self, data: Any) -> Any:
"""Get the field value given an explicit data source."""
assert self.d_key is not None
return self.filter_output(data[self.d_key])
def set_with_data(self, data: Any, value: Any, error: bool) -> Any:
"""Set the field value given an explicit data target.
If error is True, exceptions should be thrown for invalid data;
otherwise the problem should be logged but corrected.
"""
assert self.d_key is not None
data[self.d_key] = self.filter_input(value, error=error)
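# --- Illustrative sketch (not part of the original file): a minimal
# DataHandler subclass following the filter_input() contract described in
# DataHandler above. PositiveIntHandler is hypothetical.
import logging
from typing import Any

class PositiveIntHandler(DataHandler):
    """Wrangles a single non-negative int value."""

    def get_default_data(self) -> Any:
        return 0

    def filter_input(self, data: Any, error: bool) -> Any:
        if isinstance(data, int) and not isinstance(data, bool) and data >= 0:
            return data
        if error:
            raise ValueError(f'Expected a non-negative int; got {data!r}.')
        logging.error('Replacing invalid value %r with default.', data)
        return self.get_default_data()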

View file

@ -0,0 +1,222 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for the actual Entity types."""
from __future__ import annotations
import json
from typing import TYPE_CHECKING, TypeVar
from efro.entity._support import FieldInspector, BoundCompoundValue
from efro.entity._value import CompoundValue
from efro.json import ExtendedJSONEncoder, ExtendedJSONDecoder
if TYPE_CHECKING:
from typing import Dict, Any, Type, Union, Optional
T = TypeVar('T', bound='EntityMixin')
class EntityMixin:
"""Mixin class to add data-storage to CompoundValue, forming an Entity.
Distinct Entity types should inherit from this first and a CompoundValue
(sub)type second. This order ensures that constructor arguments for this
class are accessible on the new type.
"""
def __init__(self,
d_data: Dict[str, Any] = None,
error: bool = True) -> None:
super().__init__()
if not isinstance(self, CompoundValue):
raise RuntimeError('EntityMixin class must be combined'
' with a CompoundValue class.')
# Underlying data for this entity; fields simply operate on this.
self.d_data: Dict[str, Any] = {}
assert isinstance(self, EntityMixin)
self.set_data(d_data if d_data is not None else {}, error=error)
def reset(self) -> None:
"""Resets data to default."""
self.set_data({}, error=True)
def set_data(self, data: Dict, error: bool = True) -> None:
"""Set the data for this entity and apply all value filters to it.
Note that it is more efficient to pass data to an Entity's constructor
than it is to create a default Entity and then call this on it.
"""
assert isinstance(self, CompoundValue)
self.d_data = self.filter_input(data, error=error)
def copy_data(self, target: Union[CompoundValue,
BoundCompoundValue]) -> None:
"""Copy data from a target Entity or compound-value.
This first verifies that the target has a matching set of fields
and then copies its data into ourself. To copy data into a nested
compound field, the assignment operator can be used.
"""
import copy
from efro.entity.util import have_matching_fields
tvalue: CompoundValue
if isinstance(target, CompoundValue):
tvalue = target
elif isinstance(target, BoundCompoundValue):
tvalue = target.d_value
else:
raise TypeError(
'Target must be a CompoundValue or BoundCompoundValue')
target_data = getattr(target, 'd_data', None)
if target_data is None:
raise ValueError('Target is not bound to data.')
assert isinstance(self, CompoundValue)
if not have_matching_fields(self, tvalue):
raise ValueError(
f'Fields for target {type(tvalue)} do not match ours'
f" ({type(self)}); can't copy data.")
self.d_data = copy.deepcopy(target_data)
def steal_data(self, target: EntityMixin) -> None:
"""Steal data from another entity.
This is more efficient than copy_data, as data is moved instead
of copied. However this leaves the target object in an invalid
state, and it must no longer be used after this call.
This can be convenient for entities to use to update themselves
with the result of a database transaction (which generally return
fresh entities).
"""
from efro.entity.util import have_matching_fields
if not isinstance(target, EntityMixin):
raise TypeError('EntityMixin is required.')
assert isinstance(target, CompoundValue)
assert isinstance(self, CompoundValue)
if not have_matching_fields(self, target):
raise ValueError(
f'Fields for target {type(target)} do not match ours'
f" ({type(self)}); can't steal data.")
assert target.d_data is not None
self.d_data = target.d_data
# Make sure target blows up if someone tries to use it.
# noinspection PyTypeHints
target.d_data = None # type: ignore
def pruned_data(self) -> Dict[str, Any]:
"""Return a pruned version of this instance's data.
This varies from d_data in that values may be stripped out if
they are equal to defaults (for fields with that option enabled).
"""
import copy
data = copy.deepcopy(self.d_data)
assert isinstance(self, CompoundValue)
self.prune_fields_data(data)
return data
def to_json_str(self,
prune: bool = True,
pretty: bool = False,
sort_keys_override: Optional[bool] = None) -> str:
"""Convert the entity to a json string.
This uses efro.json.ExtendedJSONEncoder/Decoder
to support data types not natively storable in json.
Be sure to use the corresponding loading functions here for
this same reason.
By default, keys are sorted when pretty-printing and not otherwise,
but this can be overridden by passing a bool as sort_keys_override.
"""
if prune:
data = self.pruned_data()
else:
data = self.d_data
if pretty:
return json.dumps(
data,
indent=2,
sort_keys=(sort_keys_override
if sort_keys_override is not None else True),
cls=ExtendedJSONEncoder)
# When not doing pretty, go for quick and compact.
return json.dumps(data,
separators=(',', ':'),
sort_keys=(sort_keys_override if sort_keys_override
is not None else False),
cls=ExtendedJSONEncoder)
@staticmethod
def json_loads(s: Union[str, bytes]) -> Any:
"""Load a json string using our special extended decoder.
Note that this simply returns loaded json data; no
Entities are involved.
"""
return json.loads(s, cls=ExtendedJSONDecoder)
def load_from_json_str(self,
s: Union[str, bytes],
error: bool = True) -> None:
"""Set the entity's data in-place from a json string.
The 'error' argument determines whether Exceptions will be raised
for invalid data values. Values will be reset/conformed to valid ones
if error is False. Note that Exceptions will always be raised
in the case of invalidly formatted json.
"""
data = self.json_loads(s)
self.set_data(data, error=error)
@classmethod
def from_json_str(cls: Type[T],
s: Union[str, bytes],
error: bool = True) -> T:
"""Instantiate a new instance with provided json string.
The 'error' argument determines whether exceptions will be raised
on invalid data values. Values will be reset/conformed to valid ones
if error is False. Note that exceptions will always be raised
in the case of invalidly formatted json.
"""
obj = cls(d_data=cls.json_loads(s), error=error)
return obj
# Note: though d_fields actually returns a FieldInspector,
# in type-checking-land we currently just say it returns self.
# This allows the type-checker to at least validate subfield access,
# though the types will be incorrect (values instead of inspectors).
# This means that anything taking FieldInspectors needs to take 'Any'
# at the moment. Hopefully we can make this cleaner via a mypy
# plugin at some point.
if TYPE_CHECKING:
@property
def d_fields(self: T) -> T:
"""For accessing entity field objects (as opposed to values)."""
...
else:
@property
def d_fields(self):
"""For accessing entity field objects (as opposed to values)."""
return FieldInspector(self, self, [], [])
class Entity(EntityMixin, CompoundValue):
"""A data class consisting of Fields and their underlying data.
Fields and Values simply define a data layout; Entities are concrete
objects using those layouts.
Inherit from this class and add Fields to define a simple Entity type.
Alternately, combine an EntityMixin with any CompoundValue child class
to accomplish the same. The latter allows sharing CompoundValue
layouts between different concrete Entity types. For example, a
'Weapon' CompoundValue could be embedded as part of a 'Character'
Entity but also exist as a distinct 'WeaponEntity' in an armory
database.
"""

View file

@ -0,0 +1,602 @@
# Released under the MIT License. See LICENSE for details.
#
"""Field types for the entity system."""
from __future__ import annotations
import copy
import logging
from enum import Enum
from typing import TYPE_CHECKING, Generic, TypeVar, overload
# from efro.util import enum_by_value
from efro.entity._base import BaseField, dict_key_to_raw, dict_key_from_raw
from efro.entity._support import (BoundCompoundValue, BoundListField,
BoundDictField, BoundCompoundListField,
BoundCompoundDictField)
from efro.entity.util import have_matching_fields
if TYPE_CHECKING:
from typing import Dict, Type, List, Any
from efro.entity._value import TypedValue, CompoundValue
T = TypeVar('T')
TK = TypeVar('TK')
TC = TypeVar('TC', bound='CompoundValue')
class Field(BaseField, Generic[T]):
"""Field consisting of a single value."""
def __init__(self,
d_key: str,
value: TypedValue[T],
store_default: bool = True) -> None:
super().__init__(d_key)
self.d_value = value
self._store_default = store_default
def __repr__(self) -> str:
return f'<Field "{self.d_key}" with {self.d_value}>'
def get_default_data(self) -> Any:
return self.d_value.get_default_data()
def filter_input(self, data: Any, error: bool) -> Any:
return self.d_value.filter_input(data, error)
def filter_output(self, data: Any) -> Any:
return self.d_value.filter_output(data)
def prune_data(self, data: Any) -> bool:
return self.d_value.prune_data(data)
if TYPE_CHECKING:
# Use default runtime get/set but let type-checker know our types.
# Note: we actually return a bound-field when accessed on
# a type instead of an instance, but we don't reflect that here yet
# (would need to write a mypy plugin so sub-field access works first)
@overload
def __get__(self, obj: None, cls: Any = None) -> Field[T]:
...
@overload
def __get__(self, obj: Any, cls: Any = None) -> T:
...
def __get__(self, obj: Any, cls: Any = None) -> Any:
...
def __set__(self, obj: Any, value: T) -> None:
...
class CompoundField(BaseField, Generic[TC]):
"""Field consisting of a single compound value."""
def __init__(self,
d_key: str,
value: TC,
store_default: bool = True) -> None:
super().__init__(d_key)
if __debug__:
from efro.entity._value import CompoundValue
assert isinstance(value, CompoundValue)
assert not hasattr(value, 'd_data')
self.d_value = value
self._store_default = store_default
def get_default_data(self) -> dict:
return self.d_value.get_default_data()
def filter_input(self, data: Any, error: bool) -> dict:
return self.d_value.filter_input(data, error)
def prune_data(self, data: Any) -> bool:
return self.d_value.prune_data(data)
# Note:
# Currently, to the type-checker we just return a simple instance
# of our CompoundValue so it can properly type-check access to its
# attrs. However at runtime we return a FieldInspector or
# BoundCompoundField which both use magic to provide the same attrs
# dynamically (but which the type-checker doesn't understand).
# Perhaps at some point we can write a mypy plugin to correct this.
if TYPE_CHECKING:
def __get__(self, obj: Any, cls: Any = None) -> TC:
...
# Theoretically this type-checking may be too tight;
# we can support assigning a parent class to a child class if
# their fields match. Not sure if that'll ever come up though;
# gonna leave this for now as I prefer to have *some* checking.
# Also once we get BoundCompoundValues working with mypy we'll
# need to accept those too.
def __set__(self: CompoundField[TC], obj: Any, value: TC) -> None:
...
def get_with_data(self, data: Any) -> Any:
assert self.d_key in data
return BoundCompoundValue(self.d_value, data[self.d_key])
def set_with_data(self, data: Any, value: Any, error: bool) -> Any:
from efro.entity._value import CompoundValue
# Ok here's the deal: our type checking above allows any subtype
# of our CompoundValue in here, but we want to be more picky than
# that. Let's check fields for equality. This way we'll allow
# assigning something like a CarEntity to a Car field
# (where the data is the same), but won't allow assigning a Car
# to a Vehicle field (as Car probably adds more fields).
value1: CompoundValue
if isinstance(value, BoundCompoundValue):
value1 = value.d_value
elif isinstance(value, CompoundValue):
value1 = value
else:
raise ValueError(f"Can't assign from object type {type(value)}")
dataval = getattr(value, 'd_data', None)
if dataval is None:
raise ValueError(f"Can't assign from unbound object {value}")
if self.d_value.get_fields() != value1.get_fields():
raise ValueError(f"Can't assign to {self.d_value} from"
f' incompatible type {value.d_value}; '
f'sub-fields do not match.')
# If we're allowing this to go through, we can simply copy the
# data from the passed in value. The fields match so it should
# be in a valid state already.
data[self.d_key] = copy.deepcopy(dataval)
class ListField(BaseField, Generic[T]):
"""Field consisting of repeated values."""
def __init__(self,
d_key: str,
value: TypedValue[T],
store_default: bool = True) -> None:
super().__init__(d_key)
self.d_value = value
self._store_default = store_default
def get_default_data(self) -> list:
return []
def filter_input(self, data: Any, error: bool) -> Any:
# If we were passed a BoundListField, operate on its raw values
if isinstance(data, BoundListField):
data = data.d_data
if not isinstance(data, list):
if error:
raise TypeError(f'list value expected; got {type(data)}')
logging.error('Ignoring non-list data for %s: %s', self, data)
data = []
for i, entry in enumerate(data):
data[i] = self.d_value.filter_input(entry, error=error)
return data
def prune_data(self, data: Any) -> bool:
# We never prune individual values since that would fundamentally
# change the list, but we can prune completely if empty (and allowed).
return not data and not self._store_default
# When accessed on a FieldInspector we return a sub-field FieldInspector.
# When accessed on an instance we return a BoundListField.
if TYPE_CHECKING:
# Access via type gives our field; via an instance gives a bound field.
@overload
def __get__(self, obj: None, cls: Any = None) -> ListField[T]:
...
@overload
def __get__(self, obj: Any, cls: Any = None) -> BoundListField[T]:
...
def __get__(self, obj: Any, cls: Any = None) -> Any:
...
# Allow setting via a raw value list or a bound list field
@overload
def __set__(self, obj: Any, value: List[T]) -> None:
...
@overload
def __set__(self, obj: Any, value: BoundListField[T]) -> None:
...
def __set__(self, obj: Any, value: Any) -> None:
...
def get_with_data(self, data: Any) -> Any:
return BoundListField(self, data[self.d_key])
class DictField(BaseField, Generic[TK, T]):
"""A field of values in a dict with a specified index type."""
def __init__(self,
d_key: str,
keytype: Type[TK],
field: TypedValue[T],
store_default: bool = True) -> None:
super().__init__(d_key)
self.d_value = field
self._store_default = store_default
self._keytype = keytype
def get_default_data(self) -> dict:
return {}
def filter_input(self, data: Any, error: bool) -> Any:
# If we were passed a BoundDictField, operate on its raw values
if isinstance(data, BoundDictField):
data = data.d_data
if not isinstance(data, dict):
if error:
raise TypeError('dict value expected')
logging.error('Ignoring non-dict data for %s: %s', self, data)
data = {}
data_out = {}
for key, val in data.items():
# For enum keys, make sure it's a valid enum.
if issubclass(self._keytype, Enum):
# Our input data can either be an enum or the underlying type.
if isinstance(key, self._keytype):
key = dict_key_to_raw(key, self._keytype)
# key = key.value
else:
try:
_enumval = dict_key_from_raw(key, self._keytype)
# _enumval = enum_by_value(self._keytype, key)
except Exception as exc:
if error:
raise ValueError(
f'No enum of type {self._keytype}'
f' exists with value {key}') from exc
logging.error('Ignoring invalid key type for %s: %s',
self, data)
continue
# For all other keys we can check for exact types.
elif not isinstance(key, self._keytype):
if error:
raise TypeError(
f'Invalid key type; expected {self._keytype},'
f' got {type(key)}.')
logging.error('Ignoring invalid key type for %s: %s', self,
data)
continue
data_out[key] = self.d_value.filter_input(val, error=error)
return data_out
def prune_data(self, data: Any) -> bool:
# We never prune individual values since that would fundamentally
# change the dict, but we can prune completely if empty (and allowed)
return not data and not self._store_default
if TYPE_CHECKING:
# Return our field if accessed via type and bound-dict-field
# if via instance.
@overload
def __get__(self, obj: None, cls: Any = None) -> DictField[TK, T]:
...
@overload
def __get__(self, obj: Any, cls: Any = None) -> BoundDictField[TK, T]:
...
def __get__(self, obj: Any, cls: Any = None) -> Any:
...
# Allow setting via matching dict values or BoundDictFields
@overload
def __set__(self, obj: Any, value: Dict[TK, T]) -> None:
...
@overload
def __set__(self, obj: Any, value: BoundDictField[TK, T]) -> None:
...
def __set__(self, obj: Any, value: Any) -> None:
...
def get_with_data(self, data: Any) -> Any:
return BoundDictField(self._keytype, self, data[self.d_key])
class CompoundListField(BaseField, Generic[TC]):
"""A field consisting of repeated instances of a compound-value.
Element access returns the sub-field, allowing nested field access.
ie: mylist[10].fieldattr = 'foo'
"""
def __init__(self,
d_key: str,
valuetype: TC,
store_default: bool = True) -> None:
super().__init__(d_key)
self.d_value = valuetype
# This doesn't actually exist for us, but we want the type-checker
# to think it does (see TYPE_CHECKING note below).
self.d_data: Any
self._store_default = store_default
def filter_input(self, data: Any, error: bool) -> list:
if not isinstance(data, list):
if error:
raise TypeError('list value expected')
logging.error('Ignoring non-list data for %s: %s', self, data)
data = []
assert isinstance(data, list)
# Ok we've got a list; now run everything in it through validation.
for i, subdata in enumerate(data):
data[i] = self.d_value.filter_input(subdata, error=error)
return data
def get_default_data(self) -> list:
return []
def prune_data(self, data: Any) -> bool:
# Run pruning on all individual entries' data through our child field.
# However we don't *completely* prune values from the list since that
# would change it.
for subdata in data:
self.d_value.prune_fields_data(subdata)
# We can also optionally prune the whole list if empty and allowed.
return not data and not self._store_default
if TYPE_CHECKING:
@overload
def __get__(self, obj: None, cls: Any = None) -> CompoundListField[TC]:
...
@overload
def __get__(self,
obj: Any,
cls: Any = None) -> BoundCompoundListField[TC]:
...
def __get__(self, obj: Any, cls: Any = None) -> Any:
...
# Note:
# When setting the list, we tell the type-checker that we also accept
# a raw list of CompoundValue objects, but at runtime we actually
# always deal with BoundCompoundValue objects (see note in
# BoundCompoundListField for why we accept CompoundValue objs)
@overload
def __set__(self, obj: Any, value: List[TC]) -> None:
...
@overload
def __set__(self, obj: Any, value: BoundCompoundListField[TC]) -> None:
...
def __set__(self, obj: Any, value: Any) -> None:
...
def get_with_data(self, data: Any) -> Any:
assert self.d_key in data
return BoundCompoundListField(self, data[self.d_key])
def set_with_data(self, data: Any, value: Any, error: bool) -> Any:
# If we were passed a BoundCompoundListField,
# simply convert it to a flat list of BoundCompoundValue objects which
# is what we work with natively here.
if isinstance(value, BoundCompoundListField):
value = list(value)
if not isinstance(value, list):
raise TypeError(f'CompoundListField expected list value on set;'
f' got {type(value)}.')
# Allow assigning only from a sequence of our existing children.
# (could look into expanding this to other children if we can
# be sure the underlying data will line up; for example two
# CompoundListFields with different child_field values should not
# be inter-assignable.)
if not all(isinstance(i, BoundCompoundValue) for i in value):
raise ValueError('CompoundListField assignment must be a '
'list containing only BoundCompoundValue objs.')
# Make sure the data all has the same CompoundValue type and
# compare that type against ours once to make sure its fields match.
# (this will not allow passing CompoundValues from multiple sources
# but I don't know if that would ever come up..)
for i, val in enumerate(value):
if i == 0:
# Do the full field comparison on the first value only..
if not have_matching_fields(val.d_value, self.d_value):
raise ValueError(
'CompoundListField assignment must be a '
'list containing matching CompoundValues.')
else:
# For all remaining values, just ensure they match the first.
if val.d_value is not value[0].d_value:
raise ValueError(
'CompoundListField assignment cannot contain '
'multiple CompoundValue types as sources.')
data[self.d_key] = self.filter_input([i.d_data for i in value],
error=error)
class CompoundDictField(BaseField, Generic[TK, TC]):
"""A field consisting of key-indexed instances of a compound-value.
Element access returns the sub-field, allowing nested field access.
ie: mydict['mykey'].fieldattr = 'foo'
"""
def __init__(self,
d_key: str,
keytype: Type[TK],
valuetype: TC,
store_default: bool = True) -> None:
super().__init__(d_key)
self.d_value = valuetype
# This doesn't actually exist for us, but we want the type-checker
# to think it does (see TYPE_CHECKING note below).
self.d_data: Any
self.d_keytype = keytype
self._store_default = store_default
def filter_input(self, data: Any, error: bool) -> dict:
if not isinstance(data, dict):
if error:
raise TypeError('dict value expected')
logging.error('Ignoring non-dict data for %s: %s', self, data)
data = {}
data_out = {}
for key, val in data.items():
# For enum keys, make sure it's a valid enum.
if issubclass(self.d_keytype, Enum):
# Our input data can either be an enum or the underlying type.
if isinstance(key, self.d_keytype):
key = dict_key_to_raw(key, self.d_keytype)
# key = key.value
else:
try:
_enumval = dict_key_from_raw(key, self.d_keytype)
# _enumval = enum_by_value(self.d_keytype, key)
except Exception as exc:
if error:
raise ValueError(
f'No enum of type {self.d_keytype}'
f' exists with value {key}') from exc
logging.error('Ignoring invalid key type for %s: %s',
self, data)
continue
# For all other keys we can check for exact types.
elif not isinstance(key, self.d_keytype):
if error:
raise TypeError(
f'Invalid key type; expected {self.d_keytype},'
f' got {type(key)}.')
logging.error('Ignoring invalid key type for %s: %s', self,
data)
continue
data_out[key] = self.d_value.filter_input(val, error=error)
return data_out
def get_default_data(self) -> dict:
return {}
def prune_data(self, data: Any) -> bool:
# Run pruning on all individual entries' data through our child field.
# However we don't *completely* prune values from the dict since that
# would change it.
for subdata in data.values():
self.d_value.prune_fields_data(subdata)
# We can also optionally prune the whole dict if empty and allowed.
return not data and not self._store_default
# ONLY overriding these in type-checker land to clarify types.
# (see note in BaseField)
if TYPE_CHECKING:
@overload
def __get__(self,
obj: None,
cls: Any = None) -> CompoundDictField[TK, TC]:
...
@overload
def __get__(self,
obj: Any,
cls: Any = None) -> BoundCompoundDictField[TK, TC]:
...
def __get__(self, obj: Any, cls: Any = None) -> Any:
...
# Note:
# When setting the dict, we tell the type-checker that we also accept
# a raw dict of CompoundValue objects, but at runtime we actually
# always deal with BoundCompoundValue objects (see note in
# BoundCompoundDictField for why we accept CompoundValue objs)
@overload
def __set__(self, obj: Any, value: Dict[TK, TC]) -> None:
...
@overload
def __set__(self, obj: Any, value: BoundCompoundDictField[TK,
TC]) -> None:
...
def __set__(self, obj: Any, value: Any) -> None:
...
def get_with_data(self, data: Any) -> Any:
assert self.d_key in data
return BoundCompoundDictField(self, data[self.d_key])
def set_with_data(self, data: Any, value: Any, error: bool) -> Any:
# If we were passed a BoundCompoundDictField,
# simply convert it to a flat dict of BoundCompoundValue objects which
# is what we work with natively here.
if isinstance(value, BoundCompoundDictField):
value = dict(value.items())
if not isinstance(value, dict):
raise TypeError('CompoundDictField expected dict value on set.')
# Allow assigning only from a sequence of our existing children.
# (could look into expanding this to other children if we can
# be sure the underlying data will line up; for example two
# CompoundDictFields with different child_field values should not
# be inter-assignable.)
if (not all(isinstance(i, BoundCompoundValue)
for i in value.values())):
raise ValueError('CompoundDictField assignment must be a '
'dict containing only BoundCompoundValues.')
# Make sure the data all has the same CompoundValue type and
# compare that type against ours once to make sure its fields match.
# (this will not allow passing CompoundValues from multiple sources
# but I don't know if that would ever come up..)
first_value: Any = None
for i, val in enumerate(value.values()):
if i == 0:
first_value = val.d_value
# Do the full field comparison on the first value only..
if not have_matching_fields(val.d_value, self.d_value):
raise ValueError(
'CompoundDictField assignment must be a '
'dict containing matching CompoundValues.')
else:
# For all remaining values, just ensure they match the first.
if val.d_value is not first_value:
raise ValueError(
'CompoundDictField assignment cannot contain '
'multiple CompoundValue types as sources.')
data[self.d_key] = self.filter_input(
{key: val.d_data
for key, val in value.items()}, error=error)
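For orientation, here is a rough usage sketch showing how the compound fields above combine with the value types and bound wrappers defined elsewhere in this package. It is illustrative only: the Field class and its (d_key, value) constructor come from the portion of this module not shown in this commit view, so treat those details as assumptions.

from efro.entity._field import Field, CompoundListField
from efro.entity._support import BoundCompoundValue
from efro.entity._value import CompoundValue, IntValue, StringValue

class Item(CompoundValue):
    # 'Field' is assumed to map a single value type to a data key.
    name = Field('n', StringValue())
    count = Field('c', IntValue(0))

class Inventory(CompoundValue):
    items = CompoundListField('i', Item())

data: dict = {}
inv_value = Inventory()
inv_value.apply_fields_to_data(data, error=True)  # fill in defaults
inv = BoundCompoundValue(inv_value, data)

entry = inv.items.append()  # BoundCompoundListField grows the underlying list
entry.name = 'sword'
entry.count = 3
print(data)  # roughly {'i': [{'n': 'sword', 'c': 3}]}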

View file

@ -0,0 +1,468 @@
# Released under the MIT License. See LICENSE for details.
#
"""Various support classes for accessing data and info on fields and values."""
from __future__ import annotations
from typing import TYPE_CHECKING, TypeVar, Generic, overload
from efro.entity._base import (BaseField, dict_key_to_raw, dict_key_from_raw)
if TYPE_CHECKING:
from typing import (Optional, Tuple, Type, Any, Dict, List, Union)
from efro.entity._value import CompoundValue
from efro.entity._field import (ListField, DictField, CompoundListField,
CompoundDictField)
T = TypeVar('T')
TKey = TypeVar('TKey')
TCompound = TypeVar('TCompound', bound='CompoundValue')
TBoundList = TypeVar('TBoundList', bound='BoundCompoundListField')
class BoundCompoundValue:
"""Wraps a CompoundValue object and its entity data.
Allows access to its values through our own equivalent attributes.
"""
def __init__(self, value: CompoundValue, d_data: Union[List[Any],
Dict[str, Any]]):
self.d_value: CompoundValue
self.d_data: Union[List[Any], Dict[str, Any]]
# Need to use base setters to avoid triggering our own overrides.
object.__setattr__(self, 'd_value', value)
object.__setattr__(self, 'd_data', d_data)
def __eq__(self, other: Any) -> Any:
# Allow comparing to compound and bound-compound objects.
from efro.entity.util import compound_eq
return compound_eq(self, other)
def __getattr__(self, name: str, default: Any = None) -> Any:
# If this attribute corresponds to a field on our compound value's
# unbound type, ask it to give us a value using our data
d_value = type(object.__getattribute__(self, 'd_value'))
field = getattr(d_value, name, None)
if isinstance(field, BaseField):
return field.get_with_data(self.d_data)
raise AttributeError
def __setattr__(self, name: str, value: Any) -> None:
# Same deal as __getattr__ basically.
field = getattr(type(object.__getattribute__(self, 'd_value')), name,
None)
if isinstance(field, BaseField):
field.set_with_data(self.d_data, value, error=True)
return
super().__setattr__(name, value)
def reset(self) -> None:
"""Reset this field's data to defaults."""
value = object.__getattribute__(self, 'd_value')
data = object.__getattribute__(self, 'd_data')
assert isinstance(data, dict)
# Need to clear our dict in-place since we have no access to our
# parent (which we would need in order to assign a new empty dict).
data.clear()
# Now fill in default data.
value.apply_fields_to_data(data, error=True)
def __repr__(self) -> str:
fstrs: List[str] = []
for field in self.d_value.get_fields():
try:
fstrs.append(str(field) + '=' + repr(getattr(self, field)))
except Exception:
fstrs.append('FAIL' + str(field) + ' ' + str(type(self)))
return type(self.d_value).__name__ + '(' + ', '.join(fstrs) + ')'
class FieldInspector:
"""Used for inspecting fields."""
def __init__(self, root: Any, obj: Any, path: List[str],
dbpath: List[str]) -> None:
self._root = root
self._obj = obj
self._path = path
self._dbpath = dbpath
def __repr__(self) -> str:
path = '.'.join(self._path)
typename = type(self._root).__name__
if path == '':
return f'<FieldInspector: {typename}>'
return f'<FieldInspector: {typename}: {path}>'
def __getattr__(self, name: str, default: Any = None) -> Any:
# pylint: disable=cyclic-import
from efro.entity._field import CompoundField
# If this attribute corresponds to a field on our obj's
# unbound type, return a new inspector for it.
if isinstance(self._obj, CompoundField):
target = self._obj.d_value
else:
target = self._obj
field = getattr(type(target), name, None)
if isinstance(field, BaseField):
newpath = list(self._path)
newpath.append(name)
newdbpath = list(self._dbpath)
assert field.d_key is not None
newdbpath.append(field.d_key)
return FieldInspector(self._root, field, newpath, newdbpath)
raise AttributeError
def get_root(self) -> Any:
"""Return the root object this inspector is targeting."""
return self._root
def get_path(self) -> List[str]:
"""Return the python path components of this inspector."""
return self._path
def get_db_path(self) -> List[str]:
"""Return the database path components of this inspector."""
return self._dbpath
class BoundListField(Generic[T]):
"""ListField bound to data; used for accessing field values."""
def __init__(self, field: ListField[T], d_data: List[Any]):
self.d_field = field
assert isinstance(d_data, list)
self.d_data = d_data
self._i = 0
def __eq__(self, other: Any) -> Any:
# Just convert us into a regular list and run a compare with that.
flattened = [
self.d_field.d_value.filter_output(value) for value in self.d_data
]
return flattened == other
def __repr__(self) -> str:
return '[' + ', '.join(
repr(self.d_field.d_value.filter_output(i))
for i in self.d_data) + ']'
def __len__(self) -> int:
return len(self.d_data)
def __iter__(self) -> Any:
self._i = 0
return self
def append(self, val: T) -> None:
"""Append the provided value to the list."""
self.d_data.append(self.d_field.d_value.filter_input(val, error=True))
def __next__(self) -> T:
if self._i < len(self.d_data):
self._i += 1
val: T = self.d_field.d_value.filter_output(self.d_data[self._i -
1])
return val
raise StopIteration
@overload
def __getitem__(self, key: int) -> T:
...
@overload
def __getitem__(self, key: slice) -> List[T]:
...
def __getitem__(self, key: Any) -> Any:
if isinstance(key, slice):
dofilter = self.d_field.d_value.filter_output
return [
dofilter(self.d_data[i])
for i in range(*key.indices(len(self)))
]
assert isinstance(key, int)
return self.d_field.d_value.filter_output(self.d_data[key])
def __setitem__(self, key: int, value: T) -> None:
if not isinstance(key, int):
raise TypeError('Expected int index.')
self.d_data[key] = self.d_field.d_value.filter_input(value, error=True)
class BoundDictField(Generic[TKey, T]):
"""DictField bound to its data; used for accessing its values."""
def __init__(self, keytype: Type[TKey], field: DictField[TKey, T],
d_data: Dict[TKey, T]):
self._keytype = keytype
self.d_field = field
assert isinstance(d_data, dict)
self.d_data = d_data
def __eq__(self, other: Any) -> Any:
# Just convert us into a regular dict and run a compare with that.
flattened = {
key: self.d_field.d_value.filter_output(value)
for key, value in self.d_data.items()
}
return flattened == other
def __repr__(self) -> str:
return '{' + ', '.join(
repr(dict_key_from_raw(key, self._keytype)) + ': ' +
repr(self.d_field.d_value.filter_output(val))
for key, val in self.d_data.items()) + '}'
def __len__(self) -> int:
return len(self.d_data)
def __getitem__(self, key: TKey) -> T:
keyfilt = dict_key_to_raw(key, self._keytype)
typedval: T = self.d_field.d_value.filter_output(self.d_data[keyfilt])
return typedval
def get(self, key: TKey, default: Optional[T] = None) -> Optional[T]:
"""Get a value if present, or a default otherwise."""
keyfilt = dict_key_to_raw(key, self._keytype)
if keyfilt not in self.d_data:
return default
typedval: T = self.d_field.d_value.filter_output(self.d_data[keyfilt])
return typedval
def __setitem__(self, key: TKey, value: T) -> None:
keyfilt = dict_key_to_raw(key, self._keytype)
self.d_data[keyfilt] = self.d_field.d_value.filter_input(value,
error=True)
def __contains__(self, key: TKey) -> bool:
keyfilt = dict_key_to_raw(key, self._keytype)
return keyfilt in self.d_data
def __delitem__(self, key: TKey) -> None:
keyfilt = dict_key_to_raw(key, self._keytype)
del self.d_data[keyfilt]
def keys(self) -> List[TKey]:
"""Return a list of our keys."""
return [
dict_key_from_raw(k, self._keytype) for k in self.d_data.keys()
]
def values(self) -> List[T]:
"""Return a list of our values."""
return [
self.d_field.d_value.filter_output(value)
for value in self.d_data.values()
]
def items(self) -> List[Tuple[TKey, T]]:
"""Return a list of item/value pairs."""
return [(dict_key_from_raw(key, self._keytype),
self.d_field.d_value.filter_output(value))
for key, value in self.d_data.items()]
class BoundCompoundListField(Generic[TCompound]):
"""A CompoundListField bound to its entity sub-data."""
def __init__(self, field: CompoundListField[TCompound], d_data: List[Any]):
self.d_field = field
self.d_data = d_data
self._i = 0
def __eq__(self, other: Any) -> Any:
from efro.entity.util import have_matching_fields
# We can only be compared to other bound-compound-fields
if not isinstance(other, BoundCompoundListField):
return NotImplemented
# If our compound values have differing fields, we're unequal.
if not have_matching_fields(self.d_field.d_value,
other.d_field.d_value):
return False
# Ok our data schemas match; now just compare our data..
return self.d_data == other.d_data
def __len__(self) -> int:
return len(self.d_data)
def __repr__(self) -> str:
return '[' + ', '.join(
repr(BoundCompoundValue(self.d_field.d_value, i))
for i in self.d_data) + ']'
# Note: to the type checker our gets/sets simply deal with CompoundValue
# objects so the type-checker can cleanly handle their sub-fields.
# However at runtime we deal in BoundCompoundValue objects which use magic
# to tie the CompoundValue object to its data but which the type checker
# can't understand.
if TYPE_CHECKING:
@overload
def __getitem__(self, key: int) -> TCompound:
...
@overload
def __getitem__(self, key: slice) -> List[TCompound]:
...
def __getitem__(self, key: Any) -> Any:
...
def __next__(self) -> TCompound:
...
def append(self) -> TCompound:
"""Append and return a new field entry to the array."""
...
else:
def __getitem__(self, key: Any) -> Any:
if isinstance(key, slice):
return [
BoundCompoundValue(self.d_field.d_value, self.d_data[i])
for i in range(*key.indices(len(self)))
]
assert isinstance(key, int)
return BoundCompoundValue(self.d_field.d_value, self.d_data[key])
def __next__(self):
if self._i < len(self.d_data):
self._i += 1
return BoundCompoundValue(self.d_field.d_value,
self.d_data[self._i - 1])
raise StopIteration
def append(self) -> Any:
"""Append and return a new field entry to the array."""
# push the entity default into data and then let it fill in
# any children/etc.
self.d_data.append(
self.d_field.d_value.filter_input(
self.d_field.d_value.get_default_data(), error=True))
return BoundCompoundValue(self.d_field.d_value, self.d_data[-1])
def __iter__(self: TBoundList) -> TBoundList:
self._i = 0
return self
class BoundCompoundDictField(Generic[TKey, TCompound]):
"""A CompoundDictField bound to its entity sub-data."""
def __init__(self, field: CompoundDictField[TKey, TCompound],
d_data: Dict[Any, Any]):
self.d_field = field
self.d_data = d_data
def __eq__(self, other: Any) -> Any:
from efro.entity.util import have_matching_fields
# We can only be compared to other bound-compound-fields
if not isinstance(other, BoundCompoundDictField):
return NotImplemented
# If our compound values have differing fields, we're unequal.
if not have_matching_fields(self.d_field.d_value,
other.d_field.d_value):
return False
# Ok our data schemas match; now just compare our data..
return self.d_data == other.d_data
def __repr__(self) -> str:
return '{' + ', '.join(
repr(key) + ': ' +
repr(BoundCompoundValue(self.d_field.d_value, value))
for key, value in self.d_data.items()) + '}'
# In the typechecker's eyes, gets/sets on us simply deal in
# CompoundValue objects. This allows type-checking to work nicely
# for its sub-fields.
# However in real-life we return BoundCompoundValues which use magic
# to tie the CompoundValue to its data (but which the typechecker
# would not be able to make sense of)
if TYPE_CHECKING:
def get(self, key: TKey) -> Optional[TCompound]:
"""Return a value if present; otherwise None."""
def __getitem__(self, key: TKey) -> TCompound:
...
def values(self) -> List[TCompound]:
"""Return a list of our values."""
def items(self) -> List[Tuple[TKey, TCompound]]:
"""Return key/value pairs for all dict entries."""
def add(self, key: TKey) -> TCompound:
"""Add an entry into the dict, returning it.
Any existing value is replaced."""
else:
def get(self, key):
"""return a value if present; otherwise None."""
keyfilt = dict_key_to_raw(key, self.d_field.d_keytype)
data = self.d_data.get(keyfilt)
if data is not None:
return BoundCompoundValue(self.d_field.d_value, data)
return None
def __getitem__(self, key):
keyfilt = dict_key_to_raw(key, self.d_field.d_keytype)
return BoundCompoundValue(self.d_field.d_value,
self.d_data[keyfilt])
def values(self):
"""Return a list of our values."""
return list(
BoundCompoundValue(self.d_field.d_value, i)
for i in self.d_data.values())
def items(self):
"""Return key/value pairs for all dict entries."""
return [(dict_key_from_raw(key, self.d_field.d_keytype),
BoundCompoundValue(self.d_field.d_value, value))
for key, value in self.d_data.items()]
def add(self, key: TKey) -> TCompound:
"""Add an entry into the dict, returning it.
Any existing value is replaced."""
keyfilt = dict_key_to_raw(key, self.d_field.d_keytype)
# Push the entity default into data and then let it fill in
# any children/etc.
self.d_data[keyfilt] = (self.d_field.d_value.filter_input(
self.d_field.d_value.get_default_data(), error=True))
return BoundCompoundValue(self.d_field.d_value,
self.d_data[keyfilt])
def __len__(self) -> int:
return len(self.d_data)
def __contains__(self, key: TKey) -> bool:
keyfilt = dict_key_to_raw(key, self.d_field.d_keytype)
return keyfilt in self.d_data
def __delitem__(self, key: TKey) -> None:
keyfilt = dict_key_to_raw(key, self.d_field.d_keytype)
del self.d_data[keyfilt]
def keys(self) -> List[TKey]:
"""Return a list of our keys."""
return [
dict_key_from_raw(k, self.d_field.d_keytype)
for k in self.d_data.keys()
]
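To make the bound wrappers above a bit more concrete, here is a small hedged sketch of a DictField keyed by an enum. The DictField constructor signature (d_key, keytype, valuetype) is an assumption inferred from its filter_input/get_with_data code; everything else appears in these modules.

from enum import Enum

from efro.entity._field import DictField
from efro.entity._value import IntValue

class Slot(Enum):
    HEAD = 'h'
    BODY = 'b'

field = DictField('slots', Slot, IntValue(0))
data = {'slots': {}}

bound = field.get_with_data(data)  # -> BoundDictField
bound[Slot.HEAD] = 5               # enum keys are stored by raw value
print(Slot.HEAD in bound)          # True
print(bound[Slot.HEAD])            # 5
print(data)                        # {'slots': {'h': 5}}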

View file

@ -0,0 +1,537 @@
# Released under the MIT License. See LICENSE for details.
#
"""Value types for the entity system."""
from __future__ import annotations
import datetime
import inspect
import logging
from collections import abc
from enum import Enum
from typing import TYPE_CHECKING, TypeVar, Generic
# Our Pylint class_generics_filter gives us a false-positive unused-import.
from typing import Tuple, Optional # pylint: disable=W0611
from efro.entity._base import DataHandler, BaseField
from efro.entity.util import compound_eq
if TYPE_CHECKING:
from typing import Optional, Set, List, Dict, Any, Type
T = TypeVar('T')
TE = TypeVar('TE', bound=Enum)
_sanity_tested_types: Set[Type] = set()
_type_field_cache: Dict[Type, Dict[str, BaseField]] = {}
class TypedValue(DataHandler, Generic[T]):
"""Base class for all value types dealing with a single data type."""
class SimpleValue(TypedValue[T]):
"""Standard base class for simple single-value types.
This class provides enough functionality to handle most simple
types such as int/float/etc without too many subclass overrides.
"""
def __init__(self,
default: T,
store_default: bool,
target_type: Type = None,
convert_source_types: Tuple[Type, ...] = (),
allow_none: bool = False) -> None:
"""Init the value field.
If store_default is False, the field value will not be included
in final entity data if it is a default value. Be sure to set
this to True for any fields that will be used for server-side
queries so they are included in indexing.
target_type and convert_source_types are used in the default
filter_input implementation; if the passed-in data's type is present
in convert_source_types, a target_type will be instantiated
using it. (allows for simple conversions to bool, int, etc)
Data will also be allowed through untouched if it matches target_type.
(types needing further introspection should override filter_input).
Lastly, the value of allow_none is also used in filter_input for
whether values of None should be allowed.
"""
super().__init__()
self._store_default = store_default
self._target_type = target_type
self._convert_source_types = convert_source_types
self._allow_none = allow_none
# We store _default_data in our internal data format so need
# to run user-facing values through our input filter.
# Make sure we do this last since filter_input depends on above vals.
self._default_data: T = self.filter_input(default, error=True)
def __repr__(self) -> str:
if self._target_type is not None:
return f'<Value of type {self._target_type.__name__}>'
return '<Value of unknown type>'
def get_default_data(self) -> Any:
return self._default_data
def prune_data(self, data: Any) -> bool:
return not self._store_default and data == self._default_data
def filter_input(self, data: Any, error: bool) -> Any:
# Let data pass through untouched if it's already our target type.
if self._target_type is not None:
if isinstance(data, self._target_type):
return data
# ...and also if it's None and we're into that sort of thing.
if self._allow_none and data is None:
return data
# If it's one of our convertible types, convert.
if (self._convert_source_types
and isinstance(data, self._convert_source_types)):
assert self._target_type is not None
return self._target_type(data)
if error:
errmsg = (f'value of type {self._target_type} or None expected'
if self._allow_none else
f'value of type {self._target_type} expected')
errmsg += f'; got {type(data)}'
raise TypeError(errmsg)
errmsg = f'Ignoring incompatible data for {self};'
errmsg += (f' expected {self._target_type} or None;'
if self._allow_none else f' expected {self._target_type};')
errmsg += f' got {type(data)}'
logging.error(errmsg)
return self.get_default_data()
class StringValue(SimpleValue[str]):
"""Value consisting of a single string."""
def __init__(self, default: str = '', store_default: bool = True) -> None:
super().__init__(default, store_default, str)
class OptionalStringValue(SimpleValue[Optional[str]]):
"""Value consisting of a single string or None."""
def __init__(self,
default: Optional[str] = None,
store_default: bool = True) -> None:
super().__init__(default, store_default, str, allow_none=True)
class BoolValue(SimpleValue[bool]):
"""Value consisting of a single bool."""
def __init__(self,
default: bool = False,
store_default: bool = True) -> None:
super().__init__(default, store_default, bool, (int, float))
class OptionalBoolValue(SimpleValue[Optional[bool]]):
"""Value consisting of a single bool or None."""
def __init__(self,
default: Optional[bool] = None,
store_default: bool = True) -> None:
super().__init__(default,
store_default,
bool, (int, float),
allow_none=True)
def verify_time_input(data: Any, error: bool, allow_none: bool) -> Any:
"""Checks input data for time values."""
pytz_utc: Any
# We don't *require* pytz since it must be installed through pip
# but it is used by firestore client for its date values
# (in which case it should be installed as a dependency anyway).
try:
import pytz
pytz_utc = pytz.utc
except ModuleNotFoundError:
pytz_utc = None
# Filter unallowed None values.
if not allow_none and data is None:
if error:
raise ValueError('datetime value cannot be None')
logging.error('ignoring datetime value of None')
data = (None if allow_none else datetime.datetime.now(
datetime.timezone.utc))
# Parent filter_input does what we need, but let's just make
# sure we *only* accept datetime values that know they're UTC.
elif (isinstance(data, datetime.datetime)
and data.tzinfo is not datetime.timezone.utc
and (pytz_utc is None or data.tzinfo is not pytz_utc)):
if error:
raise ValueError(
'datetime values must have timezone set as timezone.utc')
logging.error(
'ignoring datetime value without timezone.utc set: %s %s',
type(datetime.timezone.utc), type(data.tzinfo))
data = (None if allow_none else datetime.datetime.now(
datetime.timezone.utc))
return data
class DateTimeValue(SimpleValue[datetime.datetime]):
"""Value consisting of a datetime.datetime object.
The default value for this is always the current time in UTC.
"""
def __init__(self, store_default: bool = True) -> None:
# Pass dummy datetime value as default just to satisfy constructor;
# we override get_default_data though so this doesn't get used.
dummy_default = datetime.datetime.now(datetime.timezone.utc)
super().__init__(dummy_default, store_default, datetime.datetime)
def get_default_data(self) -> Any:
# For this class we don't use a static default value;
# default is always now.
return datetime.datetime.now(datetime.timezone.utc)
def filter_input(self, data: Any, error: bool) -> Any:
data = verify_time_input(data, error, allow_none=False)
return super().filter_input(data, error)
class OptionalDateTimeValue(SimpleValue[Optional[datetime.datetime]]):
"""Value consisting of a datetime.datetime object or None."""
def __init__(self, store_default: bool = True) -> None:
super().__init__(None,
store_default,
datetime.datetime,
allow_none=True)
def filter_input(self, data: Any, error: bool) -> Any:
data = verify_time_input(data, error, allow_none=True)
return super().filter_input(data, error)
class IntValue(SimpleValue[int]):
"""Value consisting of a single int."""
def __init__(self, default: int = 0, store_default: bool = True) -> None:
super().__init__(default, store_default, int, (bool, float))
class OptionalIntValue(SimpleValue[Optional[int]]):
"""Value consisting of a single int or None"""
def __init__(self,
default: Optional[int] = None,
store_default: bool = True) -> None:
super().__init__(default,
store_default,
int, (bool, float),
allow_none=True)
class FloatValue(SimpleValue[float]):
"""Value consisting of a single float."""
def __init__(self,
default: float = 0.0,
store_default: bool = True) -> None:
super().__init__(default, store_default, float, (bool, int))
class OptionalFloatValue(SimpleValue[Optional[float]]):
"""Value consisting of a single float or None."""
def __init__(self,
default: Optional[float] = None,
store_default: bool = True) -> None:
super().__init__(default,
store_default,
float, (bool, int),
allow_none=True)
class Float3Value(SimpleValue[Tuple[float, float, float]]):
"""Value consisting of 3 floats."""
def __init__(self,
default: Tuple[float, float, float] = (0.0, 0.0, 0.0),
store_default: bool = True) -> None:
super().__init__(default, store_default)
def __repr__(self) -> str:
return '<Value of type float3>'
def filter_input(self, data: Any, error: bool) -> Any:
if (not isinstance(data, abc.Sequence) or len(data) != 3
or any(not isinstance(i, (int, float)) for i in data)):
if error:
raise TypeError('Sequence of 3 float values expected.')
logging.error('Ignoring non-3-float-sequence data for %s: %s',
self, data)
data = self.get_default_data()
# Actually store as list.
return [float(data[0]), float(data[1]), float(data[2])]
def filter_output(self, data: Any) -> Any:
"""Override."""
assert len(data) == 3
return tuple(data)
class BaseEnumValue(TypedValue[T]):
"""Value class for storing Python Enums.
Internally enums are stored as their corresponding int/str/etc. values.
"""
def __init__(self,
enumtype: Type[T],
default: Optional[T] = None,
store_default: bool = True,
allow_none: bool = False) -> None:
super().__init__()
assert issubclass(enumtype, Enum)
vals: List[T] = list(enumtype)
# Bit of sanity checking: make sure this enum has at least
# one value and that its underlying values are all of simple
# json-friendly types.
if not vals:
raise TypeError(f'enum {enumtype} has no values')
for val in vals:
assert isinstance(val, Enum)
if not isinstance(val.value, (int, bool, float, str)):
raise TypeError(f'enum value {val} has an invalid'
f' value type {type(val.value)}')
self._enumtype: Type[Enum] = enumtype
self._store_default: bool = store_default
self._allow_none: bool = allow_none
# We store default data in internal format so we need to run the
# user-provided value through our input filter.
# Make sure to set this last since it could depend on other
# stuff we set here.
if default is None and not self._allow_none:
# Special case: we allow passing None as default even if
# we don't support None as a value; in that case we sub
# in the first enum value.
default = vals[0]
self._default_data: Enum = self.filter_input(default, error=True)
def get_default_data(self) -> Any:
return self._default_data
def prune_data(self, data: Any) -> bool:
return not self._store_default and data == self._default_data
def filter_input(self, data: Any, error: bool) -> Any:
# Allow passing in enum objects directly of course.
if isinstance(data, self._enumtype):
data = data.value
elif self._allow_none and data is None:
pass
else:
# At this point we assume it's an enum value
try:
self._enumtype(data)
except ValueError:
if error:
raise ValueError(
f'Invalid value for {self._enumtype}: {data}'
) from None
logging.error('Ignoring invalid value for %s: %s',
self._enumtype, data)
data = self._default_data
return data
def filter_output(self, data: Any) -> Any:
if self._allow_none and data is None:
return None
return self._enumtype(data)
class EnumValue(BaseEnumValue[TE]):
"""Value class for storing Python Enums.
Internally enums are stored as their corresponding int/str/etc. values.
"""
def __init__(self,
enumtype: Type[TE],
default: Optional[TE] = None,
store_default: bool = True) -> None:
super().__init__(enumtype, default, store_default, allow_none=False)
class OptionalEnumValue(BaseEnumValue[Optional[TE]]):
"""Value class for storing Python Enums (or None).
Internally enums are stored as their corresponding int/str/etc. values.
"""
def __init__(self,
enumtype: Type[TE],
default: Optional[TE] = None,
store_default: bool = True) -> None:
super().__init__(enumtype, default, store_default, allow_none=True)
class CompoundValue(DataHandler):
"""A value containing one or more named child fields of its own.
Custom classes can be defined that inherit from this and include
any number of Field instances within themself.
"""
def __init__(self, store_default: bool = True) -> None:
super().__init__()
self._store_default = store_default
# Run sanity checks on this type if we haven't.
self.run_type_sanity_checks()
def __eq__(self, other: Any) -> Any:
# Allow comparing to compound and bound-compound objects.
return compound_eq(self, other)
def get_default_data(self) -> dict:
return {}
# NOTE: once we've got bound-compound-fields working in mypy
# we should get rid of this here.
# For now it needs to be here though since bound-compound fields
# come across as these in type-land.
def reset(self) -> None:
"""Resets data to default."""
raise ValueError('Unbound CompoundValue cannot be reset.')
def filter_input(self, data: Any, error: bool) -> dict:
if not isinstance(data, dict):
if error:
raise TypeError('dict value expected')
logging.error('Ignoring non-dict data for %s: %s', self, data)
data = {}
assert isinstance(data, dict)
self.apply_fields_to_data(data, error=error)
return data
def prune_data(self, data: Any) -> bool:
# Let all of our sub-fields prune themselves..
self.prune_fields_data(data)
# Now we can optionally prune ourself completely if there's
# nothing left in our data dict...
return not data and not self._store_default
def prune_fields_data(self, d_data: Dict[str, Any]) -> None:
"""Given a CompoundValue and data, prune any unnecessary data.
This will include data for fields that are set to their default
values when store_default is False.
"""
# Allow all fields to take a pruning pass.
assert isinstance(d_data, dict)
for field in self.get_fields().values():
assert isinstance(field.d_key, str)
# This is supposed to be valid data so there should be *something*
# there for all fields.
if field.d_key not in d_data:
raise RuntimeError(f'expected to find {field.d_key} in data'
f' for {self}; got data {d_data}')
# Now ask the field if this data is necessary. If not, prune it.
if field.prune_data(d_data[field.d_key]):
del d_data[field.d_key]
def apply_fields_to_data(self, d_data: Dict[str, Any],
error: bool) -> None:
"""Apply all of our fields to target data.
If error is True, exceptions will be raised for invalid data;
otherwise it will be overwritten (with logging notices emitted).
"""
assert isinstance(d_data, dict)
for field in self.get_fields().values():
assert isinstance(field.d_key, str)
# First off, make sure *something* is there for this field.
if field.d_key not in d_data:
d_data[field.d_key] = field.get_default_data()
# Now let the field tweak the data as needed so it's valid.
d_data[field.d_key] = field.filter_input(d_data[field.d_key],
error=error)
def __repr__(self) -> str:
if not hasattr(self, 'd_data'):
return f'<unbound {type(self).__name__} at {hex(id(self))}>'
fstrs: List[str] = []
assert isinstance(self, CompoundValue)
for field in self.get_fields():
fstrs.append(str(field) + '=' + repr(getattr(self, field)))
return type(self).__name__ + '(' + ', '.join(fstrs) + ')'
@classmethod
def get_fields(cls) -> Dict[str, BaseField]:
"""Return all field instances for this type."""
assert issubclass(cls, CompoundValue)
# If we haven't yet, calculate and cache a complete list of fields
# for this exact type.
if cls not in _type_field_cache:
fields: Dict[str, BaseField] = {}
for icls in inspect.getmro(cls):
for name, field in icls.__dict__.items():
if isinstance(field, BaseField):
fields[name] = field
_type_field_cache[cls] = fields
retval: Dict[str, BaseField] = _type_field_cache[cls]
assert isinstance(retval, dict)
return retval
@classmethod
def run_type_sanity_checks(cls) -> None:
"""Given a type, run one-time sanity checks on it.
These tests ensure child fields are using valid
non-repeating names/etc.
"""
if cls not in _sanity_tested_types:
_sanity_tested_types.add(cls)
# Make sure all embedded fields have a key set and there are no
# duplicates.
field_keys: Set[str] = set()
for field in cls.get_fields().values():
if field.d_key is None:
raise RuntimeError(f'Child field {field} under {cls}'
' has d_key None')
if field.d_key == '':
raise RuntimeError(f'Child field {field} under {cls}'
' has empty d_key')
# Allow alphanumeric and underscore only.
if not field.d_key.replace('_', '').isalnum():
raise RuntimeError(
f'Child field "{field.d_key}" under {cls}'
f' contains invalid characters; only alphanumeric'
f' and underscore allowed.')
if field.d_key in field_keys:
raise RuntimeError('Multiple child fields with key'
f' "{field.d_key}" found in {cls}')
field_keys.add(field.d_key)
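A quick illustration of the filter behavior described above, using only names from this module: filter_input converts user-facing values to the internal storage form, and filter_output converts back.

from enum import Enum

from efro.entity._value import EnumValue, Float3Value, IntValue

iv = IntValue(default=7)
print(iv.filter_input(2.5, error=True))  # float converts to int -> 2
print(iv.get_default_data())             # 7

class Color(Enum):
    RED = 'r'
    BLUE = 'b'

cv = EnumValue(Color, default=Color.BLUE)
print(cv.filter_input(Color.RED, error=True))  # stored internally as 'r'
print(cv.filter_output('b'))                   # Color.BLUE

f3 = Float3Value()
print(f3.filter_input((1, 2, 3), error=True))  # [1.0, 2.0, 3.0]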

131
dist/ba_data/python/efro/entity/util.py vendored Normal file
View file

@ -0,0 +1,131 @@
# Released under the MIT License. See LICENSE for details.
#
"""Misc utility functionality related to the entity system."""
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Any, Union, Tuple, List
from efro.entity._value import CompoundValue
from efro.entity._support import BoundCompoundValue
def diff_compound_values(
obj1: Union[BoundCompoundValue, CompoundValue],
obj2: Union[BoundCompoundValue, CompoundValue]) -> str:
"""Generate a string showing differences between two compound values.
Both must be associated with data and have the same set of fields.
"""
# Ensure fields match and both are attached to data...
value1, data1 = get_compound_value_and_data(obj1)
if data1 is None:
raise ValueError(f'Invalid unbound compound value: {obj1}')
value2, data2 = get_compound_value_and_data(obj2)
if data2 is None:
raise ValueError(f'Invalid unbound compound value: {obj2}')
if not have_matching_fields(value1, value2):
raise ValueError(
f"Can't diff objs with non-matching fields: {value1} and {value2}")
# Ok; let 'er rip...
diff = _diff(obj1, obj2, 2)
return ' <no differences>' if diff == '' else diff
class CompoundValueDiff:
"""Wraps diff_compound_values() in an object for efficiency.
It is preferable to pass this to logging calls instead of the
final diff string since the diff will never be generated if
the associated logging level is not being emitted.
"""
def __init__(self, obj1: Union[BoundCompoundValue, CompoundValue],
obj2: Union[BoundCompoundValue, CompoundValue]):
self._obj1 = obj1
self._obj2 = obj2
def __repr__(self) -> str:
return diff_compound_values(self._obj1, self._obj2)
def _diff(obj1: Union[BoundCompoundValue, CompoundValue],
obj2: Union[BoundCompoundValue, CompoundValue], indent: int) -> str:
from efro.entity._support import BoundCompoundValue
bits: List[str] = []
indentstr = ' ' * indent
vobj1, _data1 = get_compound_value_and_data(obj1)
fields = sorted(vobj1.get_fields().keys())
for field in fields:
val1 = getattr(obj1, field)
val2 = getattr(obj2, field)
# for nested compounds, dive in and do nice piecewise compares
if isinstance(val1, BoundCompoundValue):
assert isinstance(val2, BoundCompoundValue)
diff = _diff(val1, val2, indent + 2)
if diff != '':
bits.append(f'{indentstr}{field}:')
bits.append(diff)
# for all else just do a single line
# (perhaps we could improve on this for other complex types)
else:
if val1 != val2:
bits.append(f'{indentstr}{field}: {val1} -> {val2}')
return '\n'.join(bits)
def have_matching_fields(val1: CompoundValue, val2: CompoundValue) -> bool:
"""Return whether two compound-values have matching sets of fields.
Note this just refers to the field configuration; not data.
"""
# Quick-out: matching types will always have identical fields.
if type(val1) is type(val2):
return True
# Otherwise do a full comparison.
return val1.get_fields() == val2.get_fields()
def get_compound_value_and_data(
obj: Union[BoundCompoundValue,
CompoundValue]) -> Tuple[CompoundValue, Any]:
"""Return value and data for bound or unbound compound values."""
# pylint: disable=cyclic-import
from efro.entity._support import BoundCompoundValue
from efro.entity._value import CompoundValue
if isinstance(obj, BoundCompoundValue):
value = obj.d_value
data = obj.d_data
elif isinstance(obj, CompoundValue):
value = obj
data = getattr(obj, 'd_data', None) # may not exist
else:
raise TypeError(
f'Expected a BoundCompoundValue or CompoundValue; got {type(obj)}')
return value, data
def compound_eq(obj1: Union[BoundCompoundValue, CompoundValue],
obj2: Union[BoundCompoundValue, CompoundValue]) -> Any:
"""Compare two compound value/bound-value objects for equality."""
# Criteria for comparison: both need to be a compound value
# and both must have data (which implies they are either an entity
# or bound to a subfield in an entity).
value1, data1 = get_compound_value_and_data(obj1)
if data1 is None:
return NotImplemented
value2, data2 = get_compound_value_and_data(obj2)
if data2 is None:
return NotImplemented
# Ok we can compare them. To consider them equal we look for
# matching sets of fields and matching data. Note that there
# could be unbound data causing inequality despite their field
# values all matching; not sure if that's what we want.
return have_matching_fields(value1, value2) and data1 == data2
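A hedged sketch of diff_compound_values() and compound_eq() in action for two bound values sharing a schema; as before, Field is an assumption from the part of efro.entity._field not shown in this commit view.

from efro.entity._field import Field
from efro.entity._support import BoundCompoundValue
from efro.entity._value import CompoundValue, IntValue
from efro.entity.util import diff_compound_values

class Point(CompoundValue):
    x = Field('x', IntValue(0))
    y = Field('y', IntValue(0))

pval = Point()
data_a: dict = {}
data_b: dict = {}
pval.apply_fields_to_data(data_a, error=True)
pval.apply_fields_to_data(data_b, error=True)
a = BoundCompoundValue(pval, data_a)
b = BoundCompoundValue(pval, data_b)
b.y = 4

print(a == b)                      # False (compound_eq compares field data)
print(diff_compound_values(a, b))  # roughly '  y: 0 -> 4'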

244
dist/ba_data/python/efro/error.py vendored Normal file
View file

@ -0,0 +1,244 @@
# Released under the MIT License. See LICENSE for details.
#
"""Common errors and related functionality."""
from __future__ import annotations
from typing import TYPE_CHECKING
import errno
if TYPE_CHECKING:
pass
class CleanError(Exception):
"""An error that can be presented to the user as a simple message.
These errors should be completely self-explanatory, to the point where
a traceback or other context would not be useful.
A CleanError with no message can be used to inform a script to fail
without printing any message.
This should generally be limited to errors that will *always* be
presented to the user (such as those in high level tool code).
Exceptions that may be caught and handled by other code should use
more descriptive exception types.
"""
def pretty_print(self, flush: bool = False) -> None:
"""Print the error to stdout, using red colored output if available.
If the error has an empty message, prints nothing (not even a newline).
"""
from efro.terminal import Clr
errstr = str(self)
if errstr:
print(f'{Clr.SRED}{errstr}{Clr.RST}', flush=flush)
class CommunicationError(Exception):
"""A communication related error has occurred.
This covers anything network-related going wrong in the sending
of data or receiving of a response. Basically anything that is out
of our control should get lumped in here. This error does not imply
that data was not received on the other end; only that a full
acknowledgement round trip was not completed.
These errors should be gracefully handled whenever possible, as
occasional network issues are unavoidable.
"""
class RemoteError(Exception):
"""An error occurred on the other end of some connection.
This occurs when communication succeeds but another type of error
occurs remotely. The error string can consist of a remote stack
trace or a simple message depending on the context.
Communication systems should raise more specific error types locally
when more introspection/control is needed; this is intended somewhat
as a catch-all.
"""
def __init__(self, msg: str, peer_desc: str):
super().__init__(msg)
self._peer_desc = peer_desc
def __str__(self) -> str:
s = ''.join(str(arg) for arg in self.args)
# Indent so we can more easily tell what is the remote part when
# this is in the middle of a long exception chain.
padding = ' '
s = ''.join(padding + line for line in s.splitlines(keepends=True))
return f'The following occurred on {self._peer_desc}:\n{s}'
class IntegrityError(ValueError):
"""Data has been tampered with or corrupted in some form."""
class AuthenticationError(Exception):
"""Authentication has failed for some operation.
This can be raised if server-side-verification does not match
client-supplied credentials, if an invalid password is supplied
for a sign-in attempt, etc.
"""
def is_urllib_communication_error(exc: BaseException, url: str | None) -> bool:
"""Is the provided exception from urllib a communication-related error?
The url, if provided, gives extra context for deciding whether to treat
an error as a communication error.
This should be passed an exception which resulted from opening or
reading a urllib Request. It returns True for any errors that could
conceivably arise due to unavailable/poor network connections,
firewall/connectivity issues, or other issues out of our control.
These errors can often be safely ignored or presented to the user
as general 'network-unavailable' states.
"""
import urllib.error
import http.client
import socket
if isinstance(
exc,
(
urllib.error.URLError,
ConnectionError,
http.client.IncompleteRead,
http.client.BadStatusLine,
http.client.RemoteDisconnected,
socket.timeout,
),
):
# Special case: although an HTTPError is a subclass of URLError,
# we don't consider it a communication error. It generally means we
# have successfully communicated with the server but what we are asking
# for is not there/etc.
if isinstance(exc, urllib.error.HTTPError):
# Special sub-case: appspot.com hosting seems to give 403 errors
# (forbidden) to some countries. I'm assuming for legal reasons?..
# Let's consider that a communication error since its out of our
# control so we don't fill up logs with it.
if exc.code == 403 and url is not None and '.appspot.com' in url:
return True
return False
return True
if isinstance(exc, OSError):
if exc.errno == 10051: # Windows unreachable network error.
return True
if exc.errno in {
errno.ETIMEDOUT,
errno.EHOSTUNREACH,
errno.ENETUNREACH,
}:
return True
return False
def is_requests_communication_error(exc: BaseException) -> bool:
"""Is the provided exception a communication-related error from requests?"""
import requests
# Looks like this maps pretty well onto requests' ConnectionError
return isinstance(exc, requests.ConnectionError)
def is_udp_communication_error(exc: BaseException) -> bool:
"""Should this udp-related exception be considered a communication error?
This should be passed an exception which resulted from creating and
using a socket.SOCK_DGRAM type socket. It should return True for any
errors that could conceivably arise due to unavailable/poor network
conditions, firewall/connectivity issues, etc. These issues can often
be safely ignored or presented to the user as general
'network-unavailable' states.
"""
if isinstance(exc, ConnectionRefusedError | TimeoutError):
return True
if isinstance(exc, OSError):
if exc.errno == 10051: # Windows unreachable network error.
return True
if exc.errno in {
errno.EADDRNOTAVAIL,
errno.ETIMEDOUT,
errno.EHOSTUNREACH,
errno.ENETUNREACH,
errno.EINVAL,
errno.EPERM,
errno.EACCES,
# Windows 'invalid argument' error.
10022,
# Windows 'a socket operation was attempted to'
# 'an unreachable network' error.
10051,
}:
return True
return False
def is_asyncio_streams_communication_error(exc: BaseException) -> bool:
"""Should this streams error be considered a communication error?
This should be passed an exception which resulted from creating and
using asyncio streams. It should return True for any errors that could
conceivably arise due to unavailable/poor network connections,
firewall/connectivity issues, etc. These issues can often be safely
ignored or presented to the user as general 'connection-lost' events.
"""
# pylint: disable=too-many-return-statements
import ssl
if isinstance(
exc,
(
ConnectionError,
TimeoutError,
EOFError,
),
):
return True
# Also some specific errno ones.
if isinstance(exc, OSError):
if exc.errno == 10051: # Windows unreachable network error.
return True
if exc.errno in {
errno.ETIMEDOUT,
errno.EHOSTUNREACH,
errno.ENETUNREACH,
}:
return True
# Am occasionally getting a specific SSL error on shutdown which I
# believe is harmless (APPLICATION_DATA_AFTER_CLOSE_NOTIFY).
# It sounds like it may soon be ignored by Python (as of March 2022).
# Let's still complain, however, if we get any SSL errors besides
# this one. https://bugs.python.org/issue39951
if isinstance(exc, ssl.SSLError):
excstr = str(exc)
if 'APPLICATION_DATA_AFTER_CLOSE_NOTIFY' in excstr:
return True
# Also occasionally am getting WRONG_VERSION_NUMBER ssl errors;
# Assuming this just means client is attempting to connect from some
# outdated browser or whatnot.
if 'SSL: WRONG_VERSION_NUMBER' in excstr:
return True
# And seeing this very rarely; assuming its just data corruption?
if 'SSL: DECRYPTION_FAILED_OR_BAD_RECORD_MAC' in excstr:
return True
return False
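A hypothetical usage sketch for the helpers above: treat probable network flakiness as a soft failure and let everything else propagate. The fetch() wrapper itself is invented for illustration.

import urllib.request

from efro.error import is_urllib_communication_error

def fetch(url: str) -> bytes | None:
    """Return response bytes, or None on (probable) network trouble."""
    try:
        with urllib.request.urlopen(url, timeout=10) as resp:
            return resp.read()
    except Exception as exc:
        if is_urllib_communication_error(exc, url=url):
            # Out of our control; ignore quietly or retry later.
            return None
        raise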

72
dist/ba_data/python/efro/json.py vendored Normal file
View file

@ -0,0 +1,72 @@
# Released under the MIT License. See LICENSE for details.
#
"""Custom json compressor/decompressor with support for more data times/etc."""
from __future__ import annotations
import datetime
import json
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Any
# Special attr we included for our extended type information
# (extended-json-type)
TYPE_TAG = '_xjtp'
_pytz_utc: Any
# We don't *require* pytz since it must be installed through pip
# but it is used by firestore client for its utc tzinfos.
# (in which case it should be installed as a dependency anyway)
try:
import pytz
_pytz_utc = pytz.utc
except ModuleNotFoundError:
_pytz_utc = None # pylint: disable=invalid-name
class ExtendedJSONEncoder(json.JSONEncoder):
"""Custom json encoder supporting additional types."""
def default(self, obj: Any) -> Any: # pylint: disable=W0221
if isinstance(obj, datetime.datetime):
# We only support timezone-aware utc times.
if (obj.tzinfo is not datetime.timezone.utc
and (_pytz_utc is None or obj.tzinfo is not _pytz_utc)):
raise ValueError(
'datetime values must have timezone set as timezone.utc')
return {
TYPE_TAG:
'dt',
'v': [
obj.year, obj.month, obj.day, obj.hour, obj.minute,
obj.second, obj.microsecond
],
}
return super().default(obj)
class ExtendedJSONDecoder(json.JSONDecoder):
"""Custom json decoder supporting extended types."""
def __init__(self, *args: Any, **kwargs: Any):
json.JSONDecoder.__init__(self,
object_hook=self.object_hook,
*args,
**kwargs)
def object_hook(self, obj: Any) -> Any: # pylint: disable=E0202
"""Custom hook."""
if TYPE_TAG not in obj:
return obj
objtype = obj[TYPE_TAG]
if objtype == 'dt':
vals = obj.get('v', [])
if len(vals) != 7:
raise ValueError('malformed datetime value')
return datetime.datetime( # type: ignore
*vals, tzinfo=datetime.timezone.utc)
return obj
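A minimal round-trip sketch for the encoder/decoder above; note that datetimes must be timezone-aware UTC or the encoder raises.

import datetime
import json

from efro.json import ExtendedJSONDecoder, ExtendedJSONEncoder

now = datetime.datetime.now(datetime.timezone.utc)
blob = json.dumps({'when': now}, cls=ExtendedJSONEncoder)
out = json.loads(blob, cls=ExtendedJSONDecoder)
assert out['when'] == now
print(blob)  # e.g. {"when": {"_xjtp": "dt", "v": [...]}}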

626
dist/ba_data/python/efro/log.py vendored Normal file
View file

@ -0,0 +1,626 @@
# Released under the MIT License. See LICENSE for details.
#
"""Logging functionality."""
from __future__ import annotations
import sys
import time
import asyncio
import logging
import datetime
import itertools
from enum import Enum
from collections import deque
from dataclasses import dataclass
from typing import TYPE_CHECKING, Annotated
from threading import Thread, current_thread, Lock
from efro.util import utc_now
from efro.call import tpartial
from efro.terminal import TerminalColor
from efro.dataclassio import ioprepped, IOAttrs, dataclass_to_json
if TYPE_CHECKING:
from pathlib import Path
from typing import Any, Callable, TextIO
class LogLevel(Enum):
"""Severity level for a log entry.
These enums have numeric values so they can be compared in severity.
Note that these values are not currently interchangeable with the
logging.ERROR, logging.DEBUG, etc. values.
"""
DEBUG = 0
INFO = 1
WARNING = 2
ERROR = 3
CRITICAL = 4
@property
def python_logging_level(self) -> int:
"""Give the corresponding logging level."""
return LOG_LEVEL_LEVELNOS[self]
@classmethod
def from_python_logging_level(cls, levelno: int) -> LogLevel:
"""Given a Python logging level, return a LogLevel."""
return LEVELNO_LOG_LEVELS[levelno]
# Python logging levels from LogLevels
LOG_LEVEL_LEVELNOS = {
LogLevel.DEBUG: logging.DEBUG,
LogLevel.INFO: logging.INFO,
LogLevel.WARNING: logging.WARNING,
LogLevel.ERROR: logging.ERROR,
LogLevel.CRITICAL: logging.CRITICAL,
}
# LogLevels from Python logging levels
LEVELNO_LOG_LEVELS = {
logging.DEBUG: LogLevel.DEBUG,
logging.INFO: LogLevel.INFO,
logging.WARNING: LogLevel.WARNING,
logging.ERROR: LogLevel.ERROR,
logging.CRITICAL: LogLevel.CRITICAL,
}
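# A tiny illustration of the two-way mapping above (editorial example,
# harmless if executed since both assertions hold by construction):
assert LogLevel.WARNING.python_logging_level == logging.WARNING
assert LogLevel.from_python_logging_level(logging.ERROR) is LogLevel.ERROR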
LEVELNO_COLOR_CODES: dict[int, tuple[str, str]] = {
logging.DEBUG: (TerminalColor.CYAN.value, TerminalColor.RESET.value),
logging.INFO: ('', ''),
logging.WARNING: (TerminalColor.YELLOW.value, TerminalColor.RESET.value),
logging.ERROR: (TerminalColor.RED.value, TerminalColor.RESET.value),
logging.CRITICAL: (
TerminalColor.STRONG_MAGENTA.value
+ TerminalColor.BOLD.value
+ TerminalColor.BG_BLACK.value,
TerminalColor.RESET.value,
),
}
@ioprepped
@dataclass
class LogEntry:
"""Single logged message."""
name: Annotated[str, IOAttrs('n', soft_default='root', store_default=False)]
message: Annotated[str, IOAttrs('m')]
level: Annotated[LogLevel, IOAttrs('l')]
time: Annotated[datetime.datetime, IOAttrs('t')]
@ioprepped
@dataclass
class LogArchive:
"""Info and data for a log."""
# Total number of entries submitted to the log.
log_size: Annotated[int, IOAttrs('t')]
# Offset for the entries contained here.
# (10 means our first entry is the 10th in the log, etc.)
start_index: Annotated[int, IOAttrs('c')]
entries: Annotated[list[LogEntry], IOAttrs('e')]
class LogHandler(logging.Handler):
"""Fancy-pants handler for logging output.
Writes logs to disk in structured json format and echoes them
to stdout/stderr with pretty colors.
"""
_event_loop: asyncio.AbstractEventLoop
# IMPORTANT: Any debug prints we do here should ONLY go to echofile.
# Otherwise we can get infinite loops as those prints come back to us
# as new log entries.
def __init__(
self,
path: str | Path | None,
echofile: TextIO | None,
suppress_non_root_debug: bool,
cache_size_limit: int,
cache_time_limit: datetime.timedelta | None,
):
super().__init__()
# pylint: disable=consider-using-with
self._file = None if path is None else open(path, 'w', encoding='utf-8')
self._echofile = echofile
self._callbacks_lock = Lock()
self._callbacks: list[Callable[[LogEntry], None]] = []
self._suppress_non_root_debug = suppress_non_root_debug
self._file_chunks: dict[str, list[str]] = {'stdout': [], 'stderr': []}
self._file_chunk_ship_task: dict[str, asyncio.Task | None] = {
'stdout': None,
'stderr': None,
}
self._cache_size = 0
assert cache_size_limit >= 0
self._cache_size_limit = cache_size_limit
self._cache_time_limit = cache_time_limit
self._cache = deque[tuple[int, LogEntry]]()
self._cache_index_offset = 0
self._cache_lock = Lock()
self._printed_callback_error = False
self._thread_bootstrapped = False
self._thread = Thread(target=self._log_thread_main, daemon=True)
if __debug__:
self._last_slow_emit_warning_time: float | None = None
self._thread.start()
# Spin until our thread is up and running; otherwise we could
# wind up trying to push stuff to our event loop before the
# loop exists.
while not self._thread_bootstrapped:
time.sleep(0.001)
def add_callback(self, call: Callable[[LogEntry], None]) -> None:
"""Add a callback to be run for each LogEntry.
Note that this callback will always run in a background thread.
"""
with self._callbacks_lock:
self._callbacks.append(call)
def _log_thread_main(self) -> None:
self._event_loop = asyncio.new_event_loop()
# In our background thread event loop we do a fair amount of
# slow synchronous stuff such as mucking with the log cache.
# Let's avoid getting tons of warnings about this in debug mode.
self._event_loop.slow_callback_duration = 2.0 # Default is 0.1
# NOTE: if we ever use default threadpool at all we should allow
# setting it for our loop.
asyncio.set_event_loop(self._event_loop)
self._thread_bootstrapped = True
try:
if self._cache_time_limit is not None:
self._event_loop.create_task(self._time_prune_cache())
self._event_loop.run_forever()
except BaseException:
# If this ever goes down we're in trouble.
# We won't be able to log about it though...
# Try to make some noise however we can.
print('LogHandler died!!!', file=sys.stderr)
import traceback
traceback.print_exc()
raise
async def _time_prune_cache(self) -> None:
assert self._cache_time_limit is not None
while bool(True):
await asyncio.sleep(61.27)
now = utc_now()
with self._cache_lock:
# Prune the oldest entry as long as there is a first one that
# is too old.
while (
self._cache
and (now - self._cache[0][1].time) >= self._cache_time_limit
):
popped = self._cache.popleft()
self._cache_size -= popped[0]
self._cache_index_offset += 1
def get_cached(
self, start_index: int = 0, max_entries: int | None = None
) -> LogArchive:
"""Build and return an archive of cached log entries.
This will only include entries that have been processed by the
background thread, so may not include just-submitted logs or
entries for partially written stdout/stderr lines.
Entries from the range [start_index:start_index+max_entries]
which are still present in the cache will be returned.
"""
assert start_index >= 0
if max_entries is not None:
assert max_entries >= 0
with self._cache_lock:
# Transform start_index to our present cache space.
start_index -= self._cache_index_offset
# Calc end-index in our present cache space.
end_index = (
len(self._cache)
if max_entries is None
else start_index + max_entries
)
# Clamp both indexes to both ends of our present space.
start_index = max(0, min(start_index, len(self._cache)))
end_index = max(0, min(end_index, len(self._cache)))
return LogArchive(
log_size=self._cache_index_offset + len(self._cache),
start_index=start_index + self._cache_index_offset,
entries=self._cache_slice(start_index, end_index),
)
def _cache_slice(
self, start: int, end: int, step: int = 1
) -> list[LogEntry]:
# Deque doesn't natively support slicing but we can do it manually.
# It sounds like rotating the deque and pulling from the beginning
# is the most efficient way to do this. The downside is the deque
# gets temporarily modified in the process so we need to make sure
# we're holding the lock.
assert self._cache_lock.locked()
cache = self._cache
cache.rotate(-start)
slc = [e[1] for e in itertools.islice(cache, 0, end - start, step)]
cache.rotate(start)
return slc
@classmethod
def _is_immutable_log_data(cls, data: Any) -> bool:
if isinstance(data, (str, bool, int, float, bytes)):
return True
if isinstance(data, tuple):
return all(cls._is_immutable_log_data(x) for x in data)
return False
def emit(self, record: logging.LogRecord) -> None:
if __debug__:
starttime = time.monotonic()
# Called by logging to send us records.
# Special case: filter out this common extra-chatty category.
# TODO - perhaps should use a standard logging.Filter for this.
if (
self._suppress_non_root_debug
and record.name != 'root'
and record.levelname == 'DEBUG'
):
return
# Optimization: if our log args are all simple immutable values,
# we can just kick the whole thing over to our background thread to
# be formatted there at our leisure. If anything is mutable and
# thus could possibly change between now and then or if we want
# to do immediate file echoing then we need to bite the bullet
# and do that stuff here at the call site.
fast_path = self._echofile is None and self._is_immutable_log_data(
record.args
)
if fast_path:
if __debug__:
formattime = echotime = time.monotonic()
self._event_loop.call_soon_threadsafe(
tpartial(
self._emit_in_thread,
record.name,
record.levelno,
record.created,
record,
)
)
else:
# Slow case; do formatting and echoing here at the log call
# site.
msg = self.format(record)
if __debug__:
formattime = time.monotonic()
# Also immediately print pretty colored output to our echo file
# (generally stderr). We do this part here instead of in our bg
# thread because the delay can throw off command line prompts or
# make tight debugging harder.
if self._echofile is not None:
ends = LEVELNO_COLOR_CODES.get(record.levelno)
if ends is not None:
self._echofile.write(f'{ends[0]}{msg}{ends[1]}\n')
else:
self._echofile.write(f'{msg}\n')
if __debug__:
echotime = time.monotonic()
self._event_loop.call_soon_threadsafe(
tpartial(
self._emit_in_thread,
record.name,
record.levelno,
record.created,
msg,
)
)
if __debug__:
# Make noise if we're taking a significant amount of time here.
# Limit the noise to once every so often though; otherwise we
# could get a feedback loop where every log emit results in a
# warning log which results in another, etc.
now = time.monotonic()
# noinspection PyUnboundLocalVariable
duration = now - starttime
# noinspection PyUnboundLocalVariable
format_duration = formattime - starttime
# noinspection PyUnboundLocalVariable
echo_duration = echotime - formattime
if duration > 0.05 and (
self._last_slow_emit_warning_time is None
or now > self._last_slow_emit_warning_time + 10.0
):
# Logging calls from *within* a logging handler
# sounds sketchy, so let's just kick this over to
# the bg event loop thread we've already got.
self._last_slow_emit_warning_time = now
self._event_loop.call_soon_threadsafe(
tpartial(
logging.warning,
'efro.log.LogHandler emit took too long'
' (%.2fs total; %.2fs format, %.2fs echo,'
' fast_path=%s).',
duration,
format_duration,
echo_duration,
fast_path,
)
)
def _emit_in_thread(
self,
name: str,
levelno: int,
created: float,
message: str | logging.LogRecord,
) -> None:
try:
# If they passed a raw record here, bake it down to a string.
if isinstance(message, logging.LogRecord):
message = self.format(message)
self._emit_entry(
LogEntry(
name=name,
message=message,
level=LEVELNO_LOG_LEVELS.get(levelno, LogLevel.INFO),
time=datetime.datetime.fromtimestamp(
created, datetime.timezone.utc
),
)
)
except Exception:
import traceback
traceback.print_exc(file=self._echofile)
def file_write(self, name: str, output: str) -> None:
"""Send raw stdout/stderr output to the logger to be collated."""
self._event_loop.call_soon_threadsafe(
tpartial(self._file_write_in_thread, name, output)
)
def _file_write_in_thread(self, name: str, output: str) -> None:
try:
assert name in ('stdout', 'stderr')
# Here we try to be somewhat smart about breaking arbitrary
# print output into discrete log entries.
self._file_chunks[name].append(output)
# Individual parts of a print come across as separate writes,
# and the end of a print will be a standalone '\n' by default.
# Let's use that as a hint that we're likely at the end of
# a full print statement and ship what we've got.
if output == '\n':
self._ship_file_chunks(name, cancel_ship_task=True)
else:
# By default just keep adding chunks.
# However we keep a timer running anytime we've got
# unshipped chunks so that we can ship what we've got
# after a short bit if we never get a newline.
ship_task = self._file_chunk_ship_task[name]
if ship_task is None:
self._file_chunk_ship_task[
name
] = self._event_loop.create_task(
self._ship_chunks_task(name),
name='log ship file chunks',
)
except Exception:
import traceback
traceback.print_exc(file=self._echofile)
def file_flush(self, name: str) -> None:
"""Send raw stdout/stderr flush to the logger to be collated."""
self._event_loop.call_soon_threadsafe(
tpartial(self._file_flush_in_thread, name)
)
def _file_flush_in_thread(self, name: str) -> None:
try:
assert name in ('stdout', 'stderr')
# Immediately ship whatever chunks we've got.
if self._file_chunks[name]:
self._ship_file_chunks(name, cancel_ship_task=True)
except Exception:
import traceback
traceback.print_exc(file=self._echofile)
async def _ship_chunks_task(self, name: str) -> None:
self._ship_file_chunks(name, cancel_ship_task=False)
def _ship_file_chunks(self, name: str, cancel_ship_task: bool) -> None:
# Note: Raw print input generally ends in a newline, but that is
# redundant when we break things into log entries and results
# in extra empty lines. So strip off a single trailing newline.
text = ''.join(self._file_chunks[name]).removesuffix('\n')
self._emit_entry(
LogEntry(
name=name, message=text, level=LogLevel.INFO, time=utc_now()
)
)
self._file_chunks[name] = []
ship_task = self._file_chunk_ship_task[name]
if cancel_ship_task and ship_task is not None:
ship_task.cancel()
self._file_chunk_ship_task[name] = None
def _emit_entry(self, entry: LogEntry) -> None:
assert current_thread() is self._thread
# Store to our cache.
if self._cache_size_limit > 0:
with self._cache_lock:
# Do a rough calc of how many bytes this entry consumes.
entry_size = sum(
sys.getsizeof(x)
for x in (
entry,
entry.name,
entry.message,
entry.level,
entry.time,
)
)
self._cache.append((entry_size, entry))
self._cache_size += entry_size
# Prune old until we are back at or under our limit.
while self._cache_size > self._cache_size_limit:
popped = self._cache.popleft()
self._cache_size -= popped[0]
self._cache_index_offset += 1
# Pass to callbacks.
with self._callbacks_lock:
for call in self._callbacks:
try:
call(entry)
except Exception:
# Only print one callback error to avoid insanity.
if not self._printed_callback_error:
import traceback
traceback.print_exc(file=self._echofile)
self._printed_callback_error = True
# Dump to our structured log file.
# TODO: set a timer for flushing; don't flush every line.
if self._file is not None:
entry_s = dataclass_to_json(entry)
            assert '\n' not in entry_s  # Make sure it's a single line.
print(entry_s, file=self._file, flush=True)
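# A standalone sketch of the rotate-and-islice slicing trick used by
# LogHandler._cache_slice() above (illustrative only; not used by the
# handler itself). It relies on 'deque' and 'itertools' already being
# imported at the top of this module.
def _deque_slice_sketch() -> None:
    vals = deque(range(10))
    start, end = 2, 6
    # Rotate the desired start element to the front, islice out the
    # desired count, then rotate back so the deque is left untouched.
    vals.rotate(-start)
    sliced = list(itertools.islice(vals, 0, end - start))
    vals.rotate(start)
    assert sliced == [2, 3, 4, 5]
    assert list(vals) == list(range(10))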
class FileLogEcho:
"""A file-like object for forwarding stdout/stderr to a LogHandler."""
def __init__(
self, original: TextIO, name: str, handler: LogHandler
) -> None:
assert name in ('stdout', 'stderr')
self._original = original
self._name = name
self._handler = handler
def write(self, output: Any) -> None:
"""Override standard write call."""
self._original.write(output)
self._handler.file_write(self._name, output)
def flush(self) -> None:
"""Flush the file."""
self._original.flush()
# We also use this as a hint to ship whatever file chunks
# we've accumulated (we have to try and be smart about breaking
# our arbitrary file output into discrete entries).
self._handler.file_flush(self._name)
def isatty(self) -> bool:
"""Are we a terminal?"""
return self._original.isatty()
def setup_logging(
log_path: str | Path | None,
level: LogLevel,
suppress_non_root_debug: bool = False,
log_stdout_stderr: bool = False,
echo_to_stderr: bool = True,
cache_size_limit: int = 0,
cache_time_limit: datetime.timedelta | None = None,
) -> LogHandler:
"""Set up our logging environment.
    Returns the custom handler, which can be used to fetch information
    about logs that have passed through it (worst log-levels, caches, etc.).
"""
lmap = {
LogLevel.DEBUG: logging.DEBUG,
LogLevel.INFO: logging.INFO,
LogLevel.WARNING: logging.WARNING,
LogLevel.ERROR: logging.ERROR,
LogLevel.CRITICAL: logging.CRITICAL,
}
# Wire logger output to go to a structured log file.
# Also echo it to stderr IF we're running in a terminal.
# UPDATE: Actually gonna always go to stderr. Is there a
# reason we shouldn't? This makes debugging possible if all
# we have is access to a non-interactive terminal or file dump.
# We could add a '--quiet' arg or whatnot to change this behavior.
# Note: by passing in the *original* stderr here before we
# (potentially) replace it, we ensure that our log echos
# won't themselves be intercepted and sent to the logger
# which would create an infinite loop.
loghandler = LogHandler(
log_path,
echofile=sys.stderr if echo_to_stderr else None,
suppress_non_root_debug=suppress_non_root_debug,
cache_size_limit=cache_size_limit,
cache_time_limit=cache_time_limit,
)
# Note: going ahead with force=True here so that we replace any
# existing logger. Though we warn if it looks like we are doing
# that so we can try to avoid creating the first one.
had_previous_handlers = bool(logging.root.handlers)
logging.basicConfig(
level=lmap[level],
format='%(message)s',
handlers=[loghandler],
force=True,
)
if had_previous_handlers:
logging.warning('setup_logging: force-replacing previous handlers.')
# Optionally intercept Python's stdout/stderr output and generate
# log entries from it.
if log_stdout_stderr:
sys.stdout = FileLogEcho( # type: ignore
sys.stdout, 'stdout', loghandler
)
sys.stderr = FileLogEcho( # type: ignore
sys.stderr, 'stderr', loghandler
)
return loghandler
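# A minimal usage sketch for setup_logging() (illustrative only; the
# values used here are arbitrary examples, not recommended defaults).
def _setup_logging_sketch() -> LogHandler:
    handler = setup_logging(
        log_path=None,  # No structured log file; just echoing/caching.
        level=LogLevel.INFO,
        cache_size_limit=1024 * 1024,  # Keep roughly 1MB of entries.
        cache_time_limit=datetime.timedelta(minutes=5),
    )
    # Callbacks run for each entry on the handler's background thread.
    handler.add_callback(lambda entry: None)
    logging.getLogger(__name__).info('Logging is up.')
    # Entries already processed by the background thread can be fetched.
    _archive = handler.get_cached()
    return handler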

991
dist/ba_data/python/efro/message.py vendored Normal file
View file

@ -0,0 +1,991 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for sending and responding to messages.
Supports static typing for message types and possible return types.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, TypeVar, Annotated
from dataclasses import dataclass
from enum import Enum
import inspect
import logging
import json
import traceback
from efro.error import CleanError, RemoteError
from efro.dataclassio import (ioprepped, is_ioprepped_dataclass, IOAttrs,
dataclass_to_dict, dataclass_from_dict)
if TYPE_CHECKING:
from typing import Any, Callable, Optional, Sequence, Union, Awaitable
TM = TypeVar('TM', bound='MessageSender')
class Message:
"""Base class for messages."""
@classmethod
def get_response_types(cls) -> list[type[Response]]:
"""Return all message types this Message can result in when sent.
The default implementation specifies EmptyResponse, so messages with
no particular response needs can leave this untouched.
Note that ErrorMessage is handled as a special case and does not
need to be specified here.
"""
return [EmptyResponse]
class Response:
"""Base class for responses to messages."""
# Some standard response types:
class ErrorType(Enum):
"""Type of error that occurred in remote message handling."""
OTHER = 0
CLEAN = 1
@ioprepped
@dataclass
class ErrorResponse(Response):
"""Message saying some error has occurred on the other end.
This type is unique in that it is not returned to the user; it
instead results in a local exception being raised.
"""
error_message: Annotated[str, IOAttrs('m')]
error_type: Annotated[ErrorType, IOAttrs('e')] = ErrorType.OTHER
@ioprepped
@dataclass
class EmptyResponse(Response):
"""The response equivalent of None."""
# TODO: could allow handlers to deal in raw values for these
# types similar to how we allow None in place of EmptyResponse.
# Though not sure if they are widely used enough to warrant the
# extra code complexity.
@ioprepped
@dataclass
class BoolResponse(Response):
"""A simple bool value response."""
value: Annotated[bool, IOAttrs('v')]
@ioprepped
@dataclass
class StringResponse(Response):
"""A simple string value response."""
value: Annotated[str, IOAttrs('v')]
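# A hypothetical message type illustrating get_response_types(); the class
# name and field here are examples only, not part of the library.
@ioprepped
@dataclass
class _ExamplePingMessage(Message):
    """Ask the other end whether it is alive."""
    payload: Annotated[str, IOAttrs('p')] = ''
    @classmethod
    def get_response_types(cls) -> list[type[Response]]:
        # A ping can only come back as a BoolResponse.
        return [BoolResponse]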
class MessageProtocol:
"""Wrangles a set of message types, formats, and response types.
Both endpoints must be using a compatible Protocol for communication
to succeed. To maintain Protocol compatibility between revisions,
all message types must retain the same id, message attr storage names must
not change, newly added attrs must have default values, etc.
"""
def __init__(self,
message_types: dict[int, type[Message]],
response_types: dict[int, type[Response]],
type_key: Optional[str] = None,
preserve_clean_errors: bool = True,
log_remote_exceptions: bool = True,
trusted_sender: bool = False) -> None:
"""Create a protocol with a given configuration.
Note that common response types are automatically registered
with (unchanging negative ids) so they don't need to be passed
explicitly (but can be if a different id is desired).
If 'type_key' is provided, the message type ID is stored as the
provided key in the message dict; otherwise it will be stored as
part of a top level dict with the message payload appearing as a
child dict. This is mainly for backwards compatibility.
If 'preserve_clean_errors' is True, efro.error.CleanError errors
on the remote end will result in the same error raised locally.
All other Exception types come across as efro.error.RemoteError.
If 'trusted_sender' is True, stringified remote stack traces will
be included in the responses if errors occur.
"""
# pylint: disable=too-many-locals
self.message_types_by_id: dict[int, type[Message]] = {}
self.message_ids_by_type: dict[type[Message], int] = {}
self.response_types_by_id: dict[int, type[Response]] = {}
self.response_ids_by_type: dict[type[Response], int] = {}
for m_id, m_type in message_types.items():
# Make sure only valid message types were passed and each
# id was assigned only once.
assert isinstance(m_id, int)
assert m_id >= 0
assert (is_ioprepped_dataclass(m_type)
and issubclass(m_type, Message))
assert self.message_types_by_id.get(m_id) is None
self.message_types_by_id[m_id] = m_type
self.message_ids_by_type[m_type] = m_id
for r_id, r_type in response_types.items():
assert isinstance(r_id, int)
assert r_id >= 0
assert (is_ioprepped_dataclass(r_type)
and issubclass(r_type, Response))
assert self.response_types_by_id.get(r_id) is None
self.response_types_by_id[r_id] = r_type
self.response_ids_by_type[r_type] = r_id
# Go ahead and auto-register a few common response types
# if the user has not done so explicitly. Use unique IDs which
# will never change or overlap with user ids.
def _reg_if_not(reg_tp: type[Response], reg_id: int) -> None:
if reg_tp in self.response_ids_by_type:
return
assert self.response_types_by_id.get(reg_id) is None
self.response_types_by_id[reg_id] = reg_tp
self.response_ids_by_type[reg_tp] = reg_id
_reg_if_not(ErrorResponse, -1)
_reg_if_not(EmptyResponse, -2)
# _reg_if_not(BoolResponse, -3)
# Some extra-thorough validation in debug mode.
if __debug__:
# Make sure all Message types' return types are valid
# and have been assigned an ID as well.
all_response_types: set[type[Response]] = set()
for m_id, m_type in message_types.items():
m_rtypes = m_type.get_response_types()
assert isinstance(m_rtypes, list)
assert m_rtypes, (
f'Message type {m_type} specifies no return types.')
assert len(set(m_rtypes)) == len(m_rtypes) # check dups
all_response_types.update(m_rtypes)
for cls in all_response_types:
assert is_ioprepped_dataclass(cls)
assert issubclass(cls, Response)
if cls not in self.response_ids_by_type:
raise ValueError(f'Possible response type {cls}'
f' needs to be included in response_types'
f' for this protocol.')
# Make sure all registered types have unique base names.
# We can take advantage of this to generate cleaner looking
# protocol modules. Can revisit if this is ever a problem.
mtypenames = set(tp.__name__ for tp in self.message_ids_by_type)
if len(mtypenames) != len(message_types):
raise ValueError(
'message_types contains duplicate __name__s;'
' all types are required to have unique names.')
self._type_key = type_key
self.preserve_clean_errors = preserve_clean_errors
self.log_remote_exceptions = log_remote_exceptions
self.trusted_sender = trusted_sender
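    # Example (hypothetical names): a protocol whose messages embed their
    # type id under a legacy 'type' key, preserves CleanErrors, and trusts
    # the sender enough to include remote stack traces:
    #
    #   protocol = MessageProtocol(
    #       message_types={0: PingMessage},
    #       response_types={0: BoolResponse},
    #       type_key='type',
    #       preserve_clean_errors=True,
    #       trusted_sender=True,
    #   )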
def encode_message(self, message: Message) -> str:
"""Encode a message to a json string for transport."""
return self._encode(message, self.message_ids_by_type, 'message')
def encode_response(self, response: Response) -> str:
"""Encode a response to a json string for transport."""
return self._encode(response, self.response_ids_by_type, 'response')
def _encode(self, message: Any, ids_by_type: dict[type, int],
opname: str) -> str:
"""Encode a message to a json string for transport."""
m_id: Optional[int] = ids_by_type.get(type(message))
if m_id is None:
raise TypeError(f'{opname} type is not registered in protocol:'
f' {type(message)}')
msgdict = dataclass_to_dict(message)
# Encode type as part of the message/response dict if desired
# (for legacy compatibility).
if self._type_key is not None:
if self._type_key in msgdict:
raise RuntimeError(f'Type-key {self._type_key}'
f' found in msg of type {type(message)}')
msgdict[self._type_key] = m_id
out = msgdict
else:
out = {'m': msgdict, 't': m_id}
return json.dumps(out, separators=(',', ':'))
def decode_message(self, data: str) -> Message:
"""Decode a message from a json string."""
out = self._decode(data, self.message_types_by_id, 'message')
assert isinstance(out, Message)
return out
def decode_response(self, data: str) -> Optional[Response]:
"""Decode a response from a json string."""
out = self._decode(data, self.response_types_by_id, 'response')
assert isinstance(out, (Response, type(None)))
return out
# Weeeird; we get mypy errors returning dict[int, type] but
# dict[int, typing.Type] or dict[int, type[Any]] works..
def _decode(self, data: str, types_by_id: dict[int, type[Any]],
opname: str) -> Any:
"""Decode a message from a json string."""
msgfull = json.loads(data)
assert isinstance(msgfull, dict)
msgdict: Optional[dict]
if self._type_key is not None:
m_id = msgfull.pop(self._type_key)
msgdict = msgfull
assert isinstance(m_id, int)
else:
m_id = msgfull.get('t')
msgdict = msgfull.get('m')
assert isinstance(m_id, int)
assert isinstance(msgdict, dict)
# Decode this particular type.
msgtype = types_by_id.get(m_id)
if msgtype is None:
raise TypeError(f'Got unregistered {opname} type id of {m_id}.')
out = dataclass_from_dict(msgtype, msgdict)
# Special case: if we get EmptyResponse, we simply return None.
if isinstance(out, EmptyResponse):
return None
# Special case: a remote error occurred. Raise a local Exception
# instead of returning the message.
if isinstance(out, ErrorResponse):
assert opname == 'response'
if (self.preserve_clean_errors
and out.error_type is ErrorType.CLEAN):
raise CleanError(out.error_message)
raise RemoteError(out.error_message)
return out
def _get_module_header(self, part: str) -> str:
"""Return common parts of generated modules."""
# pylint: disable=too-many-locals, too-many-branches
import textwrap
tpimports: dict[str, list[str]] = {}
imports: dict[str, list[str]] = {}
single_message_type = len(self.message_ids_by_type) == 1
# Always import messages
for msgtype in list(self.message_ids_by_type) + [Message]:
tpimports.setdefault(msgtype.__module__,
[]).append(msgtype.__name__)
for rsp_tp in list(self.response_ids_by_type) + [Response]:
# Skip these as they don't actually show up in code.
if rsp_tp is EmptyResponse or rsp_tp is ErrorResponse:
continue
if (single_message_type and part == 'sender'
and rsp_tp is not Response):
# We need to cast to the single supported response type
# in this case so need response types at runtime.
imports.setdefault(rsp_tp.__module__,
[]).append(rsp_tp.__name__)
else:
tpimports.setdefault(rsp_tp.__module__,
[]).append(rsp_tp.__name__)
import_lines = ''
tpimport_lines = ''
for module, names in sorted(imports.items()):
jnames = ', '.join(names)
line = f'from {module} import {jnames}'
if len(line) > 79:
# Recreate in a wrapping-friendly form.
line = f'from {module} import ({jnames})'
import_lines += f'{line}\n'
for module, names in sorted(tpimports.items()):
jnames = ', '.join(names)
line = f'from {module} import {jnames}'
if len(line) > 75: # Account for indent
# Recreate in a wrapping-friendly form.
line = f'from {module} import ({jnames})'
tpimport_lines += f'{line}\n'
if part == 'sender':
import_lines += ('from efro.message import MessageSender,'
' BoundMessageSender')
tpimport_typing_extras = ''
else:
if single_message_type:
import_lines += ('from efro.message import (MessageReceiver,'
' BoundMessageReceiver, Message, Response)')
else:
import_lines += ('from efro.message import MessageReceiver,'
' BoundMessageReceiver')
tpimport_typing_extras = ', Awaitable'
ovld = ', overload' if not single_message_type else ''
tpimport_lines = textwrap.indent(tpimport_lines, ' ')
out = ('# Released under the MIT License. See LICENSE for details.\n'
f'#\n'
f'"""Auto-generated {part} module. Do not edit by hand."""\n'
f'\n'
f'from __future__ import annotations\n'
f'\n'
f'from typing import TYPE_CHECKING{ovld}\n'
f'\n'
f'{import_lines}\n'
f'\n'
f'if TYPE_CHECKING:\n'
f' from typing import Union, Any, Optional, Callable'
f'{tpimport_typing_extras}\n'
f'{tpimport_lines}'
f'\n'
f'\n')
return out
def do_create_sender_module(self,
basename: str,
protocol_create_code: str,
enable_sync_sends: bool,
enable_async_sends: bool,
private: bool = False) -> str:
"""Used by create_sender_module(); do not call directly."""
# pylint: disable=too-many-locals
import textwrap
msgtypes = list(self.message_ids_by_type.keys())
ppre = '_' if private else ''
out = self._get_module_header('sender')
ccind = textwrap.indent(protocol_create_code, ' ')
out += (f'class {ppre}{basename}(MessageSender):\n'
f' """Protocol-specific sender."""\n'
f'\n'
f' def __init__(self) -> None:\n'
f'{ccind}\n'
f' super().__init__(protocol)\n'
f'\n'
f' def __get__(self,\n'
f' obj: Any,\n'
f' type_in: Any = None)'
f' -> {ppre}Bound{basename}:\n'
f' return {ppre}Bound{basename}'
f'(obj, self)\n'
f'\n'
f'\n'
f'class {ppre}Bound{basename}(BoundMessageSender):\n'
f' """Protocol-specific bound sender."""\n')
def _filt_tp_name(rtype: type[Response]) -> str:
# We accept None to equal EmptyResponse so reflect that
# in the type annotation.
return 'None' if rtype is EmptyResponse else rtype.__name__
# Define handler() overloads for all registered message types.
if msgtypes:
for async_pass in False, True:
if async_pass and not enable_async_sends:
continue
if not async_pass and not enable_sync_sends:
continue
pfx = 'async ' if async_pass else ''
sfx = '_async' if async_pass else ''
awt = 'await ' if async_pass else ''
how = 'asynchronously' if async_pass else 'synchronously'
if len(msgtypes) == 1:
# Special case: with a single message types we don't
# use overloads.
msgtype = msgtypes[0]
msgtypevar = msgtype.__name__
rtypes = msgtype.get_response_types()
if len(rtypes) > 1:
tps = ', '.join(_filt_tp_name(t) for t in rtypes)
rtypevar = f'Union[{tps}]'
else:
rtypevar = _filt_tp_name(rtypes[0])
out += (f'\n'
f' {pfx}def send{sfx}(self,'
f' message: {msgtypevar})'
f' -> {rtypevar}:\n'
f' """Send a message {how}."""\n'
f' out = {awt}self._sender.'
f'send{sfx}(self._obj, message)\n'
f' assert isinstance(out, {rtypevar})\n'
f' return out\n')
else:
for msgtype in msgtypes:
msgtypevar = msgtype.__name__
rtypes = msgtype.get_response_types()
if len(rtypes) > 1:
tps = ', '.join(_filt_tp_name(t) for t in rtypes)
rtypevar = f'Union[{tps}]'
else:
rtypevar = _filt_tp_name(rtypes[0])
out += (f'\n'
f' @overload\n'
f' {pfx}def send{sfx}(self,'
f' message: {msgtypevar})'
f' -> {rtypevar}:\n'
f' ...\n')
out += (f'\n'
f' {pfx}def send{sfx}(self, message: Message)'
f' -> Optional[Response]:\n'
f' """Send a message {how}."""\n'
f' return {awt}self._sender.'
f'send{sfx}(self._obj, message)\n')
return out
def do_create_receiver_module(self,
basename: str,
protocol_create_code: str,
is_async: bool,
private: bool = False) -> str:
"""Used by create_receiver_module(); do not call directly."""
# pylint: disable=too-many-locals
import textwrap
desc = 'asynchronous' if is_async else 'synchronous'
ppre = '_' if private else ''
msgtypes = list(self.message_ids_by_type.keys())
out = self._get_module_header('receiver')
ccind = textwrap.indent(protocol_create_code, ' ')
out += (f'class {ppre}{basename}(MessageReceiver):\n'
f' """Protocol-specific {desc} receiver."""\n'
f'\n'
f' is_async = {is_async}\n'
f'\n'
f' def __init__(self) -> None:\n'
f'{ccind}\n'
f' super().__init__(protocol)\n'
f'\n'
f' def __get__(\n'
f' self,\n'
f' obj: Any,\n'
f' type_in: Any = None,\n'
f' ) -> {ppre}Bound{basename}:\n'
f' return {ppre}Bound{basename}('
f'obj, self)\n')
# Define handler() overloads for all registered message types.
def _filt_tp_name(rtype: type[Response]) -> str:
# We accept None to equal EmptyResponse so reflect that
# in the type annotation.
return 'None' if rtype is EmptyResponse else rtype.__name__
if msgtypes:
cbgn = 'Awaitable[' if is_async else ''
cend = ']' if is_async else ''
if len(msgtypes) == 1:
# Special case: when we have a single message type we don't
# use overloads.
msgtype = msgtypes[0]
msgtypevar = msgtype.__name__
rtypes = msgtype.get_response_types()
if len(rtypes) > 1:
tps = ', '.join(_filt_tp_name(t) for t in rtypes)
rtypevar = f'Union[{tps}]'
else:
rtypevar = _filt_tp_name(rtypes[0])
rtypevar = f'{cbgn}{rtypevar}{cend}'
out += (
f'\n'
f' def handler(\n'
f' self,\n'
f' call: Callable[[Any, {msgtypevar}], '
f'{rtypevar}],\n'
f' )'
f' -> Callable[[Any, {msgtypevar}], {rtypevar}]:\n'
f' """Decorator to register message handlers."""\n'
f' from typing import cast, Callable, Any\n'
f' self.register_handler(cast(Callable'
f'[[Any, Message], Response], call))\n'
f' return call\n')
else:
for msgtype in msgtypes:
msgtypevar = msgtype.__name__
rtypes = msgtype.get_response_types()
if len(rtypes) > 1:
tps = ', '.join(_filt_tp_name(t) for t in rtypes)
rtypevar = f'Union[{tps}]'
else:
rtypevar = _filt_tp_name(rtypes[0])
rtypevar = f'{cbgn}{rtypevar}{cend}'
out += (f'\n'
f' @overload\n'
f' def handler(\n'
f' self,\n'
f' call: Callable[[Any, {msgtypevar}], '
f'{rtypevar}],\n'
f' )'
f' -> Callable[[Any, {msgtypevar}], {rtypevar}]:\n'
f' ...\n')
out += (
'\n'
' def handler(self, call: Callable) -> Callable:\n'
' """Decorator to register message handlers."""\n'
' self.register_handler(call)\n'
' return call\n')
out += (f'\n'
f'\n'
f'class {ppre}Bound{basename}(BoundMessageReceiver):\n'
f' """Protocol-specific bound receiver."""\n')
if is_async:
out += (
'\n'
' async def handle_raw_message(self, message: str)'
' -> str:\n'
' """Asynchronously handle a raw incoming message."""\n'
' return await'
' self._receiver.handle_raw_message_async(\n'
' self._obj, message)\n')
else:
out += (
'\n'
' def handle_raw_message(self, message: str) -> str:\n'
' """Synchronously handle a raw incoming message."""\n'
' return self._receiver.handle_raw_message'
'(self._obj, message)\n')
return out
class MessageSender:
"""Facilitates sending messages to a target and receiving responses.
This is instantiated at the class level and used to register unbound
class methods to handle raw message sending.
Example:
class MyClass:
msg = MyMessageSender(some_protocol)
@msg.send_method
def send_raw_message(self, message: str) -> str:
# Actually send the message here.
# MyMessageSender class should provide overloads for send(), send_bg(),
# etc. to ensure all sending happens with valid types.
obj = MyClass()
obj.msg.send(SomeMessageType())
"""
def __init__(self, protocol: MessageProtocol) -> None:
self.protocol = protocol
self._send_raw_message_call: Optional[Callable[[Any, str], str]] = None
self._send_async_raw_message_call: Optional[Callable[
[Any, str], Awaitable[str]]] = None
def send_method(
self, call: Callable[[Any, str],
str]) -> Callable[[Any, str], str]:
"""Function decorator for setting raw send method."""
assert self._send_raw_message_call is None
self._send_raw_message_call = call
return call
def send_async_method(
self, call: Callable[[Any, str], Awaitable[str]]
) -> Callable[[Any, str], Awaitable[str]]:
"""Function decorator for setting raw send-async method."""
assert self._send_async_raw_message_call is None
self._send_async_raw_message_call = call
return call
def send(self, bound_obj: Any, message: Message) -> Optional[Response]:
"""Send a message and receive a response.
Will encode the message for transport and call dispatch_raw_message()
"""
if self._send_raw_message_call is None:
raise RuntimeError('send() is unimplemented for this type.')
msg_encoded = self.protocol.encode_message(message)
response_encoded = self._send_raw_message_call(bound_obj, msg_encoded)
response = self.protocol.decode_response(response_encoded)
assert isinstance(response, (Response, type(None)))
assert (response is None
or type(response) in type(message).get_response_types())
return response
async def send_async(self, bound_obj: Any,
message: Message) -> Optional[Response]:
"""Send a message asynchronously using asyncio.
The message will be encoded for transport and passed to
dispatch_raw_message_async.
"""
if self._send_async_raw_message_call is None:
raise RuntimeError('send_async() is unimplemented for this type.')
msg_encoded = self.protocol.encode_message(message)
response_encoded = await self._send_async_raw_message_call(
bound_obj, msg_encoded)
response = self.protocol.decode_response(response_encoded)
assert isinstance(response, (Response, type(None)))
assert (response is None
or type(response) in type(message).get_response_types())
return response
class BoundMessageSender:
"""Base class for bound senders."""
def __init__(self, obj: Any, sender: MessageSender) -> None:
assert obj is not None
self._obj = obj
self._sender = sender
@property
def protocol(self) -> MessageProtocol:
"""Protocol associated with this sender."""
return self._sender.protocol
def send_untyped(self, message: Message) -> Optional[Response]:
"""Send a message synchronously.
Whenever possible, use the send() call provided by generated
subclasses instead of this; it will provide better type safety.
"""
return self._sender.send(self._obj, message)
async def send_async_untyped(self, message: Message) -> Optional[Response]:
"""Send a message asynchronously.
Whenever possible, use the send_async() call provided by generated
subclasses instead of this; it will provide better type safety.
"""
return await self._sender.send_async(self._obj, message)
class MessageReceiver:
"""Facilitates receiving & responding to messages from a remote source.
This is instantiated at the class level with unbound methods registered
as handlers for different message types in the protocol.
Example:
class MyClass:
receiver = MyMessageReceiver()
# MyMessageReceiver fills out handler() overloads to ensure all
# registered handlers have valid types/return-types.
@receiver.handler
def handle_some_message_type(self, message: SomeMsg) -> SomeResponse:
# Deal with this message type here.
# This will trigger the registered handler being called.
obj = MyClass()
obj.receiver.handle_raw_message(some_raw_data)
Any unhandled Exception occurring during message handling will result in
an Exception being raised on the sending end.
"""
is_async = False
def __init__(self, protocol: MessageProtocol) -> None:
self.protocol = protocol
self._handlers: dict[type[Message], Callable] = {}
# noinspection PyProtectedMember
def register_handler(
self, call: Callable[[Any, Message], Optional[Response]]) -> None:
"""Register a handler call.
The message type handled by the call is determined by its
type annotation.
"""
# TODO: can use types.GenericAlias in 3.9.
from typing import _GenericAlias # type: ignore
from typing import get_type_hints, get_args
sig = inspect.getfullargspec(call)
# The provided callable should be a method taking one 'msg' arg.
expectedsig = ['self', 'msg']
if sig.args != expectedsig:
raise ValueError(f'Expected callable signature of {expectedsig};'
f' got {sig.args}')
# Make sure we are only given async methods if we are an async handler
# and sync ones otherwise.
is_async = inspect.iscoroutinefunction(call)
if self.is_async != is_async:
msg = ('Expected a sync method; found an async one.' if is_async
else 'Expected an async method; found a sync one.')
raise ValueError(msg)
# Check annotation types to determine what message types we handle.
# Return-type annotation can be a Union, but we probably don't
# have it available at runtime. Explicitly pull it in.
# UPDATE: we've updated our pylint filter to where we should
# have all annotations available.
# anns = get_type_hints(call, localns={'Union': Union})
anns = get_type_hints(call)
msgtype = anns.get('msg')
if not isinstance(msgtype, type):
raise TypeError(
f'expected a type for "msg" annotation; got {type(msgtype)}.')
assert issubclass(msgtype, Message)
ret = anns.get('return')
responsetypes: tuple[Union[type[Any], type[None]], ...]
# Return types can be a single type or a union of types.
if isinstance(ret, _GenericAlias):
targs = get_args(ret)
if not all(isinstance(a, type) for a in targs):
raise TypeError(f'expected only types for "return" annotation;'
f' got {targs}.')
responsetypes = targs
else:
if not isinstance(ret, type):
raise TypeError(f'expected one or more types for'
f' "return" annotation; got a {type(ret)}.')
responsetypes = (ret, )
# Return type of None translates to EmptyResponse.
responsetypes = tuple(EmptyResponse if r is type(None) else r
for r in responsetypes) # noqa
# Make sure our protocol has this message type registered and our
# return types exactly match. (Technically we could return a subset
# of the supported types; can allow this in the future if it makes
# sense).
registered_types = self.protocol.message_ids_by_type.keys()
if msgtype not in registered_types:
raise TypeError(f'Message type {msgtype} is not registered'
f' in this Protocol.')
if msgtype in self._handlers:
raise TypeError(f'Message type {msgtype} already has a registered'
f' handler.')
        # Make sure the responses exactly match what the message expects.
if set(responsetypes) != set(msgtype.get_response_types()):
raise TypeError(
f'Provided response types {responsetypes} do not'
f' match the set expected by message type {msgtype}: '
f'({msgtype.get_response_types()})')
# Ok; we're good!
self._handlers[msgtype] = call
def validate(self, warn_only: bool = False) -> None:
"""Check for handler completeness, valid types, etc."""
for msgtype in self.protocol.message_ids_by_type.keys():
if issubclass(msgtype, Response):
continue
if msgtype not in self._handlers:
msg = (f'Protocol message type {msgtype} is not handled'
f' by receiver type {type(self)}.')
if warn_only:
logging.warning(msg)
else:
raise TypeError(msg)
def _decode_incoming_message(self,
msg: str) -> tuple[Message, type[Message]]:
# Decode the incoming message.
msg_decoded = self.protocol.decode_message(msg)
msgtype = type(msg_decoded)
assert issubclass(msgtype, Message)
return msg_decoded, msgtype
def _encode_response(self, response: Optional[Response],
msgtype: type[Message]) -> str:
# A return value of None equals EmptyResponse.
if response is None:
response = EmptyResponse()
# Re-encode the response.
assert isinstance(response, Response)
# (user should never explicitly return these)
assert not isinstance(response, ErrorResponse)
assert type(response) in msgtype.get_response_types()
return self.protocol.encode_response(response)
def raw_response_for_error(self, exc: Exception) -> str:
"""Return a raw response for an error that occurred during handling."""
if self.protocol.log_remote_exceptions:
logging.exception('Error handling message.')
        # If anything goes wrong, return an ErrorResponse instead.
if (isinstance(exc, CleanError)
and self.protocol.preserve_clean_errors):
err_response = ErrorResponse(error_message=str(exc),
error_type=ErrorType.CLEAN)
else:
err_response = ErrorResponse(
error_message=(traceback.format_exc()
if self.protocol.trusted_sender else
'An unknown error has occurred.'),
error_type=ErrorType.OTHER)
return self.protocol.encode_response(err_response)
def handle_raw_message(self, bound_obj: Any, msg: str) -> str:
"""Decode, handle, and return an response for a message."""
assert not self.is_async, "can't call sync handler on async receiver"
try:
msg_decoded, msgtype = self._decode_incoming_message(msg)
handler = self._handlers.get(msgtype)
if handler is None:
raise RuntimeError(f'Got unhandled message type: {msgtype}.')
result = handler(bound_obj, msg_decoded)
return self._encode_response(result, msgtype)
except Exception as exc:
return self.raw_response_for_error(exc)
async def handle_raw_message_async(self, bound_obj: Any, msg: str) -> str:
"""Should be called when the receiver gets a message.
The return value is the raw response to the message.
"""
assert self.is_async, "can't call async handler on sync receiver"
try:
msg_decoded, msgtype = self._decode_incoming_message(msg)
handler = self._handlers.get(msgtype)
if handler is None:
raise RuntimeError(f'Got unhandled message type: {msgtype}.')
result = await handler(bound_obj, msg_decoded)
return self._encode_response(result, msgtype)
except Exception as exc:
return self.raw_response_for_error(exc)
class BoundMessageReceiver:
"""Base bound receiver class."""
def __init__(
self,
obj: Any,
receiver: MessageReceiver,
) -> None:
assert obj is not None
self._obj = obj
self._receiver = receiver
@property
def protocol(self) -> MessageProtocol:
"""Protocol associated with this receiver."""
return self._receiver.protocol
def raw_response_for_error(self, exc: Exception) -> str:
"""Return a raw response for an error that occurred during handling.
This is automatically called from standard handle_raw_message_x()
calls but can be manually invoked if errors occur outside of there.
        This gives clients a better idea of what went wrong versus simply
        returning invalid data, which they might dismiss as a
        connection-related error.
"""
return self._receiver.raw_response_for_error(exc)
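# A hypothetical same-process round trip tying the above pieces together
# (illustrative only; real code would use generated sender/receiver modules
# to get fully typed send()/handler() overloads). All names here are
# examples, not library API.
@ioprepped
@dataclass
class _SketchPingMessage(Message):
    """Hypothetical 'are you there?' message."""
    note: Annotated[str, IOAttrs('n')] = ''
    @classmethod
    def get_response_types(cls) -> list[type[Response]]:
        return [BoolResponse]
_sketch_protocol = MessageProtocol(
    message_types={0: _SketchPingMessage},
    response_types={0: BoolResponse},
)
class _SketchServer:
    """Receives pings and answers them."""
    receiver = MessageReceiver(_sketch_protocol)
    def handle_ping(self, msg: _SketchPingMessage) -> BoolResponse:
        return BoolResponse(value=True)
class _SketchClient:
    """Sends pings straight into a local _SketchServer."""
    msg = MessageSender(_sketch_protocol)
    def __init__(self, server: _SketchServer) -> None:
        self._server = server
    @msg.send_method
    def _send_raw(self, data: str) -> str:
        # 'Transport' here is just a direct call into the local receiver.
        return _SketchServer.receiver.handle_raw_message(self._server, data)
def _run_message_sketch() -> None:
    """Exercise the round trip above (never called; illustration only)."""
    _SketchServer.receiver.register_handler(_SketchServer.handle_ping)
    client = _SketchClient(_SketchServer())
    response = _SketchClient.msg.send(client, _SketchPingMessage(note='hi'))
    assert isinstance(response, BoolResponse) and response.value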
def create_sender_module(basename: str,
protocol_create_code: str,
enable_sync_sends: bool,
enable_async_sends: bool,
private: bool = False) -> str:
"""Create a Python module defining a MessageSender subclass.
This class is primarily for type checking and will contain overrides
for the varieties of send calls for message/response types defined
in the protocol.
Code passed for 'protocol_create_code' should import necessary
modules and assign an instance of the Protocol to a 'protocol'
variable.
Class names are based on basename; a basename 'FooSender' will
result in classes FooSender and BoundFooSender.
If 'private' is True, class-names will be prefixed with an '_'.
Note that line lengths are not clipped, so output may need to be
run through a formatter to prevent lint warnings about excessive
line lengths.
"""
# Exec the passed code to get a protocol which we then use to
# generate module code. The user could simply call
# MessageProtocol.do_create_sender_module() directly, but this allows
# us to verify that the create code works and yields the protocol used
# to generate the code.
protocol = _protocol_from_code(protocol_create_code)
return protocol.do_create_sender_module(
basename=basename,
protocol_create_code=protocol_create_code,
enable_sync_sends=enable_sync_sends,
enable_async_sends=enable_async_sends,
private=private)
def create_receiver_module(basename: str,
protocol_create_code: str,
is_async: bool,
private: bool = False) -> str:
""""Create a Python module defining a MessageReceiver subclass.
This class is primarily for type checking and will contain overrides
for the register method for message/response types defined in
the protocol.
Class names are based on basename; a basename 'FooReceiver' will
result in FooReceiver and BoundFooReceiver.
If 'is_async' is True, handle_raw_message() will be an async method
and the @handler decorator will expect async methods.
If 'private' is True, class-names will be prefixed with an '_'.
Note that line lengths are not clipped, so output may need to be
run through a formatter to prevent lint warnings about excessive
line lengths.
"""
# Exec the passed code to get a protocol which we then use to
# generate module code. The user could simply call
    # MessageProtocol.do_create_receiver_module() directly, but this allows
# us to verify that the create code works and yields the protocol used
# to generate the code.
protocol = _protocol_from_code(protocol_create_code)
return protocol.do_create_receiver_module(
basename=basename,
protocol_create_code=protocol_create_code,
is_async=is_async,
private=private)
def _protocol_from_code(protocol_create_code: str) -> MessageProtocol:
env: dict = {}
exec(protocol_create_code, env) # pylint: disable=exec-used
protocol = env.get('protocol')
if not isinstance(protocol, MessageProtocol):
raise RuntimeError(
f'protocol_create_code yielded'
f' a {type(protocol)}; expected a MessageProtocol instance.')
return protocol
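# A hypothetical build-time sketch of generating a typed sender module.
# 'myapp.messages' and 'PingMessage' are placeholders for real, importable
# message definitions; the create-code string is exec'd above and must
# assign a MessageProtocol instance to a 'protocol' variable.
_SKETCH_CREATE_CODE = (
    'from myapp.messages import PingMessage\n'
    'from efro.message import MessageProtocol, BoolResponse\n'
    'protocol = MessageProtocol(\n'
    '    message_types={0: PingMessage},\n'
    '    response_types={0: BoolResponse},\n'
    ')\n'
)
# module_source = create_sender_module(
#     basename='PingSender',
#     protocol_create_code=_SKETCH_CREATE_CODE,
#     enable_sync_sends=True,
#     enable_async_sends=True,
# )
# ...then write module_source out to a .py file in the app's source tree.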

45
dist/ba_data/python/efro/message/__init__.py vendored Normal file
View file

@ -0,0 +1,45 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for sending and responding to messages.
Supports static typing for message types and possible return types.
"""
from __future__ import annotations
from efro.util import set_canonical_module
from efro.message._protocol import MessageProtocol
from efro.message._sender import MessageSender, BoundMessageSender
from efro.message._receiver import MessageReceiver, BoundMessageReceiver
from efro.message._module import create_sender_module, create_receiver_module
from efro.message._message import (
Message,
Response,
SysResponse,
EmptySysResponse,
ErrorSysResponse,
StringResponse,
BoolResponse,
UnregisteredMessageIDError,
)
__all__ = [
'Message',
'Response',
'SysResponse',
'EmptySysResponse',
'ErrorSysResponse',
'StringResponse',
'BoolResponse',
'MessageProtocol',
'MessageSender',
'BoundMessageSender',
'MessageReceiver',
'BoundMessageReceiver',
'create_sender_module',
'create_receiver_module',
'UnregisteredMessageIDError',
]
# Have these things present themselves cleanly as 'thismodule.SomeClass'
# instead of 'thismodule._internalmodule.SomeClass'
set_canonical_module(module_globals=globals(), names=__all__)

108
dist/ba_data/python/efro/message/_message.py vendored Normal file
View file

@ -0,0 +1,108 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for sending and responding to messages.
Supports static typing for message types and possible return types.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Annotated
from dataclasses import dataclass
from enum import Enum
from efro.dataclassio import ioprepped, IOAttrs
if TYPE_CHECKING:
pass
class UnregisteredMessageIDError(Exception):
"""A message or response id is not covered by our protocol."""
class Message:
"""Base class for messages."""
@classmethod
def get_response_types(cls) -> list[type[Response] | None]:
"""Return all Response types this Message can return when sent.
The default implementation specifies a None return type.
"""
return [None]
class Response:
"""Base class for responses to messages."""
class SysResponse:
"""Base class for system-responses to messages.
These are only sent/handled by the messaging system itself;
users of the api never see them.
"""
def set_local_exception(self, exc: Exception) -> None:
"""Attach a local exception to facilitate better logging/handling.
Be aware that this data does not get serialized and only
exists on the local object.
"""
setattr(self, '_sr_local_exception', exc)
def get_local_exception(self) -> Exception | None:
"""Fetch a local attached exception."""
value = getattr(self, '_sr_local_exception', None)
assert isinstance(value, Exception | None)
return value
# Some standard response types:
@ioprepped
@dataclass
class ErrorSysResponse(SysResponse):
"""SysResponse saying some error has occurred for the send.
This generally results in an Exception being raised for the caller.
"""
class ErrorType(Enum):
"""Type of error that occurred while sending a message."""
REMOTE = 0
REMOTE_CLEAN = 1
LOCAL = 2
COMMUNICATION = 3
REMOTE_COMMUNICATION = 4
error_message: Annotated[str, IOAttrs('m')]
error_type: Annotated[ErrorType, IOAttrs('e')] = ErrorType.REMOTE
@ioprepped
@dataclass
class EmptySysResponse(SysResponse):
"""The response equivalent of None."""
# TODO: could allow handlers to deal in raw values for these
# types similar to how we allow None in place of EmptySysResponse.
# Though not sure if they are widely used enough to warrant the
# extra code complexity.
@ioprepped
@dataclass
class BoolResponse(Response):
"""A simple bool value response."""
value: Annotated[bool, IOAttrs('v')]
@ioprepped
@dataclass
class StringResponse(Response):
"""A simple string value response."""
value: Annotated[str, IOAttrs('v')]
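# A hypothetical message type for this newer API; a None entry in
# get_response_types() means the message expects no response payload
# (analogous to the old EmptyResponse). Names here are examples only.
@ioprepped
@dataclass
class _ExampleShutdownMessage(Message):
    """Ask the other end to shut down."""
    reason: Annotated[str, IOAttrs('r')] = ''
    @classmethod
    def get_response_types(cls) -> list[type[Response] | None]:
        return [None]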

109
dist/ba_data/python/efro/message/_module.py vendored Normal file
View file

@ -0,0 +1,109 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for sending and responding to messages.
Supports static typing for message types and possible return types.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from efro.message._protocol import MessageProtocol
if TYPE_CHECKING:
pass
def create_sender_module(
basename: str,
protocol_create_code: str,
enable_sync_sends: bool,
enable_async_sends: bool,
private: bool = False,
protocol_module_level_import_code: str | None = None,
build_time_protocol_create_code: str | None = None,
) -> str:
"""Create a Python module defining a MessageSender subclass.
This class is primarily for type checking and will contain overrides
for the varieties of send calls for message/response types defined
in the protocol.
Code passed for 'protocol_create_code' should import necessary
modules and assign an instance of the Protocol to a 'protocol'
variable.
Class names are based on basename; a basename 'FooSender' will
result in classes FooSender and BoundFooSender.
If 'private' is True, class-names will be prefixed with an '_'.
Note: output code may have long lines and should generally be run
through a formatter. We should perhaps move this functionality to
efrotools so we can include that functionality inline.
"""
protocol = _protocol_from_code(
build_time_protocol_create_code
if build_time_protocol_create_code is not None
else protocol_create_code
)
return protocol.do_create_sender_module(
basename=basename,
protocol_create_code=protocol_create_code,
enable_sync_sends=enable_sync_sends,
enable_async_sends=enable_async_sends,
private=private,
protocol_module_level_import_code=protocol_module_level_import_code,
)
def create_receiver_module(
basename: str,
protocol_create_code: str,
is_async: bool,
private: bool = False,
protocol_module_level_import_code: str | None = None,
build_time_protocol_create_code: str | None = None,
) -> str:
""" "Create a Python module defining a MessageReceiver subclass.
This class is primarily for type checking and will contain overrides
for the register method for message/response types defined in
the protocol.
Class names are based on basename; a basename 'FooReceiver' will
result in FooReceiver and BoundFooReceiver.
If 'is_async' is True, handle_raw_message() will be an async method
and the @handler decorator will expect async methods.
If 'private' is True, class-names will be prefixed with an '_'.
Note that line lengths are not clipped, so output may need to be
run through a formatter to prevent lint warnings about excessive
line lengths.
"""
protocol = _protocol_from_code(
build_time_protocol_create_code
if build_time_protocol_create_code is not None
else protocol_create_code
)
return protocol.do_create_receiver_module(
basename=basename,
protocol_create_code=protocol_create_code,
is_async=is_async,
private=private,
protocol_module_level_import_code=protocol_module_level_import_code,
)
def _protocol_from_code(protocol_create_code: str) -> MessageProtocol:
env: dict = {}
exec(protocol_create_code, env) # pylint: disable=exec-used
protocol = env.get('protocol')
if not isinstance(protocol, MessageProtocol):
raise RuntimeError(
f'protocol_create_code yielded'
f' a {type(protocol)}; expected a MessageProtocol instance.'
)
return protocol
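# A hypothetical build-script sketch for this newer API. Per the code above,
# 'build_time_protocol_create_code' (when provided) is what actually gets
# exec'd here to obtain the protocol; 'protocol_create_code' is then only
# embedded in the generated module. Module and class names are placeholders.
_SKETCH_CREATE_CODE = (
    'from myapp.messages import PingMessage\n'
    'from efro.message import MessageProtocol, BoolResponse\n'
    'protocol = MessageProtocol(\n'
    '    message_types={0: PingMessage},\n'
    '    response_types={0: BoolResponse},\n'
    ')\n'
)
# sender_source = create_sender_module(
#     basename='PingSender',
#     protocol_create_code=_SKETCH_CREATE_CODE,
#     enable_sync_sends=True,
#     enable_async_sends=True,
#     build_time_protocol_create_code=None,  # Fall back to the code above.
# )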

653
dist/ba_data/python/efro/message/_protocol.py vendored Normal file
View file

@ -0,0 +1,653 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for sending and responding to messages.
Supports static typing for message types and possible return types.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import traceback
import json
from efro.error import CleanError, CommunicationError
from efro.dataclassio import (
is_ioprepped_dataclass,
dataclass_to_dict,
dataclass_from_dict,
)
from efro.message._message import (
Message,
Response,
SysResponse,
ErrorSysResponse,
EmptySysResponse,
UnregisteredMessageIDError,
)
if TYPE_CHECKING:
from typing import Any, Literal
class MessageProtocol:
"""Wrangles a set of message types, formats, and response types.
Both endpoints must be using a compatible Protocol for communication
to succeed. To maintain Protocol compatibility between revisions,
all message types must retain the same id, message attr storage
names must not change, newly added attrs must have default values,
etc.
"""
def __init__(
self,
message_types: dict[int, type[Message]],
response_types: dict[int, type[Response]],
forward_communication_errors: bool = False,
forward_clean_errors: bool = False,
remote_errors_include_stack_traces: bool = False,
log_remote_errors: bool = True,
) -> None:
"""Create a protocol with a given configuration.
If 'forward_communication_errors' is True,
efro.error.CommunicationErrors raised on the receiver end will
result in a matching error raised back on the sender. This can
be useful if the receiver will be in some way forwarding
messages along and the sender doesn't need to know where
communication breakdowns occurred; only that they did.
If 'forward_clean_errors' is True, efro.error.CleanError
exceptions raised on the receiver end will result in a matching
CleanError raised back on the sender.
When an exception is not covered by the optional forwarding
mechanisms above, it will come across as efro.error.RemoteError
and the exception will be logged on the receiver
end - at least by default (see details below).
If 'remote_errors_include_stack_traces' is True, stringified
stack traces will be returned with efro.error.RemoteError
exceptions. This is useful for debugging but should only be
enabled in cases where the sender is trusted to see internal
details of the receiver.
By default, when a message-handling exception will result in an
efro.error.RemoteError being returned to the sender, the
exception will be logged on the receiver. This is because the
goal is usually to avoid returning opaque RemoteErrors and to
instead return something meaningful as part of the expected
response type (even if that value itself represents a logical
error state). If 'log_remote_errors' is False, however, such
exceptions will not be logged on the receiver. This can be
useful in combination with 'remote_errors_include_stack_traces'
and 'forward_clean_errors' in situations where all error
logging/management will be happening on the sender end. Be
aware, however, that in that case it may be possible for
communication errors to prevent such error messages from
ever being seen.
"""
# pylint: disable=too-many-locals
self.message_types_by_id: dict[int, type[Message]] = {}
self.message_ids_by_type: dict[type[Message], int] = {}
self.response_types_by_id: dict[
int, type[Response] | type[SysResponse]
] = {}
self.response_ids_by_type: dict[
type[Response] | type[SysResponse], int
] = {}
for m_id, m_type in message_types.items():
# Make sure only valid message types were passed and each
# id was assigned only once.
assert isinstance(m_id, int)
assert m_id >= 0
assert is_ioprepped_dataclass(m_type) and issubclass(
m_type, Message
)
assert self.message_types_by_id.get(m_id) is None
self.message_types_by_id[m_id] = m_type
self.message_ids_by_type[m_type] = m_id
for r_id, r_type in response_types.items():
assert isinstance(r_id, int)
assert r_id >= 0
assert is_ioprepped_dataclass(r_type) and issubclass(
r_type, Response
)
assert self.response_types_by_id.get(r_id) is None
self.response_types_by_id[r_id] = r_type
self.response_ids_by_type[r_type] = r_id
# Register our SysResponse types. These use negative
# IDs so as to never overlap with user Response types.
def _reg_sys(reg_tp: type[SysResponse], reg_id: int) -> None:
assert self.response_types_by_id.get(reg_id) is None
self.response_types_by_id[reg_id] = reg_tp
self.response_ids_by_type[reg_tp] = reg_id
_reg_sys(ErrorSysResponse, -1)
_reg_sys(EmptySysResponse, -2)
# Some extra-thorough validation in debug mode.
if __debug__:
# Make sure all Message types' return types are valid
# and have been assigned an ID as well.
all_response_types: set[type[Response] | None] = set()
for m_id, m_type in message_types.items():
m_rtypes = m_type.get_response_types()
assert isinstance(m_rtypes, list)
assert (
m_rtypes
), f'Message type {m_type} specifies no return types.'
assert len(set(m_rtypes)) == len(m_rtypes) # check dups
for m_rtype in m_rtypes:
all_response_types.add(m_rtype)
for cls in all_response_types:
if cls is None:
continue
assert is_ioprepped_dataclass(cls)
assert issubclass(cls, Response)
if cls not in self.response_ids_by_type:
raise ValueError(
f'Possible response type {cls} needs to be included'
f' in response_types for this protocol.'
)
# Make sure all registered types have unique base names.
# We can take advantage of this to generate cleaner looking
# protocol modules. Can revisit if this is ever a problem.
mtypenames = set(tp.__name__ for tp in self.message_ids_by_type)
if len(mtypenames) != len(message_types):
raise ValueError(
'message_types contains duplicate __name__s;'
' all types are required to have unique names.'
)
self.forward_clean_errors = forward_clean_errors
self.forward_communication_errors = forward_communication_errors
self.remote_errors_include_stack_traces = (
remote_errors_include_stack_traces
)
self.log_remote_errors = log_remote_errors
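    # Example (hypothetical names): a protocol that forwards CleanErrors and
    # CommunicationErrors back to the sender and, since the sender is
    # trusted, includes remote stack traces in RemoteErrors:
    #
    #   protocol = MessageProtocol(
    #       message_types={0: PingMessage},
    #       response_types={0: BoolResponse},
    #       forward_clean_errors=True,
    #       forward_communication_errors=True,
    #       remote_errors_include_stack_traces=True,
    #   )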
@staticmethod
def encode_dict(obj: dict) -> str:
"""Json-encode a provided dict."""
return json.dumps(obj, separators=(',', ':'))
def message_to_dict(self, message: Message) -> dict:
"""Encode a message to a json ready dict."""
return self._to_dict(message, self.message_ids_by_type, 'message')
def response_to_dict(self, response: Response | SysResponse) -> dict:
"""Encode a response to a json ready dict."""
return self._to_dict(response, self.response_ids_by_type, 'response')
def error_to_response(self, exc: Exception) -> tuple[SysResponse, bool]:
"""Translate an Exception to a SysResponse.
Also returns whether the error should be logged if this happened
within handle_raw_message().
"""
        # If anything goes wrong, return an ErrorSysResponse instead.
# (either CLEAN or generic REMOTE)
if self.forward_clean_errors and isinstance(exc, CleanError):
return (
ErrorSysResponse(
error_message=str(exc),
error_type=ErrorSysResponse.ErrorType.REMOTE_CLEAN,
),
False,
)
if self.forward_communication_errors and isinstance(
exc, CommunicationError
):
return (
ErrorSysResponse(
error_message=str(exc),
error_type=ErrorSysResponse.ErrorType.REMOTE_COMMUNICATION,
),
False,
)
return (
ErrorSysResponse(
error_message=(
traceback.format_exc()
if self.remote_errors_include_stack_traces
else 'An internal error has occurred.'
),
error_type=ErrorSysResponse.ErrorType.REMOTE,
),
self.log_remote_errors,
)
def _to_dict(
self, message: Any, ids_by_type: dict[type, int], opname: str
) -> dict:
"""Encode a message to a json string for transport."""
m_id: int | None = ids_by_type.get(type(message))
if m_id is None:
raise TypeError(
f'{opname} type is not registered in protocol:'
f' {type(message)}'
)
out = {'t': m_id, 'm': dataclass_to_dict(message)}
return out
@staticmethod
def decode_dict(data: str) -> dict:
"""Decode data to a dict."""
out = json.loads(data)
assert isinstance(out, dict)
return out
def message_from_dict(self, data: dict) -> Message:
"""Decode a message from a json string."""
out = self._from_dict(data, self.message_types_by_id, 'message')
assert isinstance(out, Message)
return out
def response_from_dict(self, data: dict) -> Response | SysResponse:
"""Decode a response from a json string."""
out = self._from_dict(data, self.response_types_by_id, 'response')
assert isinstance(out, Response | SysResponse)
return out
# Weeeird; we get mypy errors returning dict[int, type] but
# dict[int, typing.Type] or dict[int, type[Any]] works..
def _from_dict(
self, data: dict, types_by_id: dict[int, type[Any]], opname: str
) -> Any:
"""Decode a message from a json string."""
msgdict: dict | None
m_id = data.get('t')
        # Allow omitting the 'm' dict if it's empty.
msgdict = data.get('m', {})
assert isinstance(m_id, int)
assert isinstance(msgdict, dict)
# Decode this particular type.
msgtype = types_by_id.get(m_id)
if msgtype is None:
raise UnregisteredMessageIDError(
f'Got unregistered {opname} id of {m_id}.'
)
return dataclass_from_dict(msgtype, msgdict)
def _get_module_header(
self,
part: Literal['sender', 'receiver'],
extra_import_code: str | None,
enable_async_sends: bool,
) -> str:
"""Return common parts of generated modules."""
# pylint: disable=too-many-locals
# pylint: disable=too-many-branches
# pylint: disable=too-many-statements
import textwrap
tpimports: dict[str, list[str]] = {}
imports: dict[str, list[str]] = {}
single_message_type = len(self.message_ids_by_type) == 1
msgtypes = list(self.message_ids_by_type)
if part == 'sender':
msgtypes.append(Message)
for msgtype in msgtypes:
tpimports.setdefault(msgtype.__module__, []).append(
msgtype.__name__
)
rsptypes = list(self.response_ids_by_type)
if part == 'sender':
rsptypes.append(Response)
for rsp_tp in rsptypes:
# Skip these as they don't actually show up in code.
if rsp_tp is EmptySysResponse or rsp_tp is ErrorSysResponse:
continue
if (
single_message_type
and part == 'sender'
and rsp_tp is not Response
):
                # We need to cast to the single supported response type
                # in this case, so we need response types at runtime.
imports.setdefault(rsp_tp.__module__, []).append(
rsp_tp.__name__
)
else:
tpimports.setdefault(rsp_tp.__module__, []).append(
rsp_tp.__name__
)
import_lines = ''
tpimport_lines = ''
for module, names in sorted(imports.items()):
jnames = ', '.join(names)
line = f'from {module} import {jnames}'
if len(line) > 79:
# Recreate in a wrapping-friendly form.
line = f'from {module} import ({jnames})'
import_lines += f'{line}\n'
for module, names in sorted(tpimports.items()):
jnames = ', '.join(names)
line = f'from {module} import {jnames}'
if len(line) > 75: # Account for indent
# Recreate in a wrapping-friendly form.
line = f'from {module} import ({jnames})'
tpimport_lines += f'{line}\n'
if part == 'sender':
import_lines += (
'from efro.message import MessageSender, BoundMessageSender'
)
tpimport_typing_extras = ''
else:
if single_message_type:
import_lines += (
'from efro.message import (MessageReceiver,'
' BoundMessageReceiver, Message, Response)'
)
else:
import_lines += (
'from efro.message import MessageReceiver,'
' BoundMessageReceiver'
)
tpimport_typing_extras = ', Awaitable'
if extra_import_code is not None:
import_lines += f'\n{extra_import_code}\n'
ovld = ', overload' if not single_message_type else ''
ovld2 = (
', cast, Awaitable'
if (single_message_type and part == 'sender' and enable_async_sends)
else ''
)
tpimport_lines = textwrap.indent(tpimport_lines, ' ')
baseimps = ['Any']
if part == 'receiver':
baseimps.append('Callable')
if part == 'sender' and enable_async_sends:
baseimps.append('Awaitable')
baseimps_s = ', '.join(baseimps)
out = (
'# Released under the MIT License. See LICENSE for details.\n'
f'#\n'
f'"""Auto-generated {part} module. Do not edit by hand."""\n'
f'\n'
f'from __future__ import annotations\n'
f'\n'
f'from typing import TYPE_CHECKING{ovld}{ovld2}\n'
f'\n'
f'{import_lines}\n'
f'\n'
f'if TYPE_CHECKING:\n'
f' from typing import {baseimps_s}'
f'{tpimport_typing_extras}\n'
f'{tpimport_lines}'
f'\n'
f'\n'
)
return out
def do_create_sender_module(
self,
basename: str,
protocol_create_code: str,
enable_sync_sends: bool,
enable_async_sends: bool,
private: bool = False,
protocol_module_level_import_code: str | None = None,
) -> str:
"""Used by create_sender_module(); do not call directly."""
# pylint: disable=too-many-locals
# pylint: disable=too-many-branches
import textwrap
msgtypes = list(self.message_ids_by_type.keys())
ppre = '_' if private else ''
out = self._get_module_header(
'sender',
extra_import_code=protocol_module_level_import_code,
enable_async_sends=enable_async_sends,
)
ccind = textwrap.indent(protocol_create_code, ' ')
out += (
f'class {ppre}{basename}(MessageSender):\n'
f' """Protocol-specific sender."""\n'
f'\n'
f' def __init__(self) -> None:\n'
f'{ccind}\n'
f' super().__init__(protocol)\n'
f'\n'
f' def __get__(\n'
f' self, obj: Any, type_in: Any = None\n'
f' ) -> {ppre}Bound{basename}:\n'
f' return {ppre}Bound{basename}(obj, self)\n'
f'\n'
f'\n'
f'class {ppre}Bound{basename}(BoundMessageSender):\n'
f' """Protocol-specific bound sender."""\n'
)
def _filt_tp_name(rtype: type[Response] | None) -> str:
return 'None' if rtype is None else rtype.__name__
# Define handler() overloads for all registered message types.
if msgtypes:
for async_pass in False, True:
if async_pass and not enable_async_sends:
continue
if not async_pass and not enable_sync_sends:
continue
pfx = 'async ' if async_pass else ''
sfx = '_async' if async_pass else ''
# awt = 'await ' if async_pass else ''
awt = ''
how = 'asynchronously' if async_pass else 'synchronously'
if len(msgtypes) == 1:
                    # Special case: with a single message type we don't
                    # use overloads.
msgtype = msgtypes[0]
msgtypevar = msgtype.__name__
rtypes = msgtype.get_response_types()
if len(rtypes) > 1:
rtypevar = ' | '.join(_filt_tp_name(t) for t in rtypes)
else:
rtypevar = _filt_tp_name(rtypes[0])
if async_pass:
rtypevar = f'Awaitable[{rtypevar}]'
out += (
f'\n'
f' def send{sfx}(self,'
f' message: {msgtypevar})'
f' -> {rtypevar}:\n'
f' """Send a message {how}."""\n'
f' out = {awt}self._sender.'
f'send{sfx}(self._obj, message)\n'
)
if not async_pass:
out += (
f' assert isinstance(out, {rtypevar})\n'
' return out\n'
)
else:
out += f' return cast({rtypevar}, out)\n'
else:
for msgtype in msgtypes:
msgtypevar = msgtype.__name__
rtypes = msgtype.get_response_types()
if len(rtypes) > 1:
rtypevar = ' | '.join(
_filt_tp_name(t) for t in rtypes
)
else:
rtypevar = _filt_tp_name(rtypes[0])
out += (
f'\n'
f' @overload\n'
f' {pfx}def send{sfx}(self,'
f' message: {msgtypevar})'
f' -> {rtypevar}:\n'
f' ...\n'
)
rtypevar = 'Response | None'
if async_pass:
rtypevar = f'Awaitable[{rtypevar}]'
out += (
f'\n'
f' def send{sfx}(self, message: Message)'
f' -> {rtypevar}:\n'
f' """Send a message {how}."""\n'
f' return {awt}self._sender.'
f'send{sfx}(self._obj, message)\n'
)
return out
def do_create_receiver_module(
self,
basename: str,
protocol_create_code: str,
is_async: bool,
private: bool = False,
protocol_module_level_import_code: str | None = None,
) -> str:
"""Used by create_receiver_module(); do not call directly."""
# pylint: disable=too-many-locals
import textwrap
desc = 'asynchronous' if is_async else 'synchronous'
ppre = '_' if private else ''
msgtypes = list(self.message_ids_by_type.keys())
out = self._get_module_header(
'receiver',
extra_import_code=protocol_module_level_import_code,
enable_async_sends=False,
)
ccind = textwrap.indent(protocol_create_code, ' ')
out += (
f'class {ppre}{basename}(MessageReceiver):\n'
f' """Protocol-specific {desc} receiver."""\n'
f'\n'
f' is_async = {is_async}\n'
f'\n'
f' def __init__(self) -> None:\n'
f'{ccind}\n'
f' super().__init__(protocol)\n'
f'\n'
f' def __get__(\n'
f' self,\n'
f' obj: Any,\n'
f' type_in: Any = None,\n'
f' ) -> {ppre}Bound{basename}:\n'
f' return {ppre}Bound{basename}('
f'obj, self)\n'
)
# Define handler() overloads for all registered message types.
def _filt_tp_name(rtype: type[Response] | None) -> str:
return 'None' if rtype is None else rtype.__name__
if msgtypes:
cbgn = 'Awaitable[' if is_async else ''
cend = ']' if is_async else ''
if len(msgtypes) == 1:
# Special case: when we have a single message type we don't
# use overloads.
msgtype = msgtypes[0]
msgtypevar = msgtype.__name__
rtypes = msgtype.get_response_types()
if len(rtypes) > 1:
rtypevar = ' | '.join(_filt_tp_name(t) for t in rtypes)
else:
rtypevar = _filt_tp_name(rtypes[0])
rtypevar = f'{cbgn}{rtypevar}{cend}'
out += (
f'\n'
f' def handler(\n'
f' self,\n'
f' call: Callable[[Any, {msgtypevar}], '
f'{rtypevar}],\n'
f' )'
f' -> Callable[[Any, {msgtypevar}], {rtypevar}]:\n'
f' """Decorator to register message handlers."""\n'
f' from typing import cast, Callable, Any\n'
f'\n'
f' self.register_handler(cast(Callable'
f'[[Any, Message], Response], call))\n'
f' return call\n'
)
else:
for msgtype in msgtypes:
msgtypevar = msgtype.__name__
rtypes = msgtype.get_response_types()
if len(rtypes) > 1:
rtypevar = ' | '.join(_filt_tp_name(t) for t in rtypes)
else:
rtypevar = _filt_tp_name(rtypes[0])
rtypevar = f'{cbgn}{rtypevar}{cend}'
out += (
f'\n'
f' @overload\n'
f' def handler(\n'
f' self,\n'
f' call: Callable[[Any, {msgtypevar}], '
f'{rtypevar}],\n'
f' )'
f' -> Callable[[Any, {msgtypevar}], {rtypevar}]:\n'
f' ...\n'
)
out += (
'\n'
' def handler(self, call: Callable) -> Callable:\n'
' """Decorator to register message handlers."""\n'
' self.register_handler(call)\n'
' return call\n'
)
out += (
f'\n'
f'\n'
f'class {ppre}Bound{basename}(BoundMessageReceiver):\n'
f' """Protocol-specific bound receiver."""\n'
)
if is_async:
out += (
'\n'
' def handle_raw_message(\n'
' self, message: str, raise_unregistered: bool = False\n'
' ) -> Awaitable[str]:\n'
' """Asynchronously handle a raw incoming message."""\n'
' return self._receiver.'
'handle_raw_message_async(\n'
' self._obj, message, raise_unregistered\n'
' )\n'
)
else:
out += (
'\n'
' def handle_raw_message(\n'
' self, message: str, raise_unregistered: bool = False\n'
' ) -> str:\n'
' """Synchronously handle a raw incoming message."""\n'
' return self._receiver.handle_raw_message(\n'
' self._obj, message, raise_unregistered\n'
' )\n'
)
return out
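# Illustrative sketch of the wire format produced by message_to_dict() and
# encode_dict() above (the type id and payload here are hypothetical; real
# ids come from the protocol's registration tables):
if __name__ == '__main__':
    import json as _json
    _wire = {'t': 1, 'm': {'val': 3}}  # 't' = type id, 'm' = payload dict.
    _encoded = _json.dumps(_wire, separators=(',', ':'))
    assert _encoded == '{"t":1,"m":{"val":3}}'
    assert _json.loads(_encoded) == _wire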


@ -0,0 +1,420 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for sending and responding to messages.
Supports static typing for message types and possible return types.
"""
from __future__ import annotations
import types
import inspect
import logging
from typing import TYPE_CHECKING
from efro.message._message import (
Message,
Response,
EmptySysResponse,
UnregisteredMessageIDError,
)
if TYPE_CHECKING:
from typing import Any, Callable, Awaitable
from efro.message._protocol import MessageProtocol
from efro.message._message import SysResponse
class MessageReceiver:
"""Facilitates receiving & responding to messages from a remote source.
This is instantiated at the class level with unbound methods registered
as handlers for different message types in the protocol.
Example:
class MyClass:
receiver = MyMessageReceiver()
# MyMessageReceiver fills out handler() overloads to ensure all
# registered handlers have valid types/return-types.
@receiver.handler
def handle_some_message_type(self, message: SomeMsg) -> SomeResponse:
# Deal with this message type here.
# This will trigger the registered handler being called.
obj = MyClass()
obj.receiver.handle_raw_message(some_raw_data)
Any unhandled Exception occurring during message handling will result in
an Exception being raised on the sending end.
"""
is_async = False
def __init__(self, protocol: MessageProtocol) -> None:
self.protocol = protocol
self._handlers: dict[type[Message], Callable] = {}
self._decode_filter_call: Callable[
[Any, dict, Message], None
] | None = None
self._encode_filter_call: Callable[
[Any, Message | None, Response | SysResponse, dict], None
] | None = None
# noinspection PyProtectedMember
def register_handler(
self, call: Callable[[Any, Message], Response | None]
) -> None:
"""Register a handler call.
The message type handled by the call is determined by its
type annotation.
"""
# TODO: can use types.GenericAlias in 3.9.
# (hmm though now that we're there, it seems a drop-in
# replace gives us errors. Should re-test in 3.10 as it seems
# that typing_extensions handles it differently in that case)
from typing import _GenericAlias # type: ignore
from typing import get_type_hints, get_args
sig = inspect.getfullargspec(call)
# The provided callable should be a method taking one 'msg' arg.
expectedsig = ['self', 'msg']
if sig.args != expectedsig:
raise ValueError(
f'Expected callable signature of {expectedsig};'
f' got {sig.args}'
)
# Make sure we are only given async methods if we are an async handler
# and sync ones otherwise.
# UPDATE - can't do this anymore since we now sometimes use
# regular functions which return awaitables instead of having
# the entire function be async.
# is_async = inspect.iscoroutinefunction(call)
# if self.is_async != is_async:
# msg = (
# 'Expected a sync method; found an async one.'
# if is_async
# else 'Expected an async method; found a sync one.'
# )
# raise ValueError(msg)
# Check annotation types to determine what message types we handle.
# Return-type annotation can be a Union, but we probably don't
# have it available at runtime. Explicitly pull it in.
# UPDATE: we've updated our pylint filter to where we should
# have all annotations available.
# anns = get_type_hints(call, localns={'Union': Union})
anns = get_type_hints(call)
msgtype = anns.get('msg')
if not isinstance(msgtype, type):
raise TypeError(
f'expected a type for "msg" annotation; got {type(msgtype)}.'
)
assert issubclass(msgtype, Message)
ret = anns.get('return')
responsetypes: tuple[type[Any] | None, ...]
# Return types can be a single type or a union of types.
if isinstance(ret, (_GenericAlias, types.UnionType)):
targs = get_args(ret)
if not all(isinstance(a, (type, type(None))) for a in targs):
raise TypeError(
f'expected only types for "return" annotation;'
f' got {targs}.'
)
responsetypes = targs
else:
if not isinstance(ret, (type, type(None))):
raise TypeError(
f'expected one or more types for'
f' "return" annotation; got a {type(ret)}.'
)
# This seems like maybe a mypy bug. Appeared after adding
# types.UnionType above.
responsetypes = (ret,)
# This will contain NoneType for empty return cases, but
# we expect it to be None.
responsetypes = tuple(
None if r is type(None) else r for r in responsetypes
)
# Make sure our protocol has this message type registered and our
# return types exactly match. (Technically we could return a subset
# of the supported types; can allow this in the future if it makes
# sense).
registered_types = self.protocol.message_ids_by_type.keys()
if msgtype not in registered_types:
raise TypeError(
f'Message type {msgtype} is not registered'
f' in this Protocol.'
)
if msgtype in self._handlers:
raise TypeError(
f'Message type {msgtype} already has a registered' f' handler.'
)
        # Make sure the response types exactly match what the message expects.
if set(responsetypes) != set(msgtype.get_response_types()):
raise TypeError(
f'Provided response types {responsetypes} do not'
f' match the set expected by message type {msgtype}: '
f'({msgtype.get_response_types()})'
)
# Ok; we're good!
self._handlers[msgtype] = call
def decode_filter_method(
self, call: Callable[[Any, dict, Message], None]
) -> Callable[[Any, dict, Message], None]:
"""Function decorator for defining a decode filter.
Decode filters can be used to extract extra data from incoming
message dicts. This version will work for both handle_raw_message()
and handle_raw_message_async()
"""
assert self._decode_filter_call is None
self._decode_filter_call = call
return call
def encode_filter_method(
self,
call: Callable[
[Any, Message | None, Response | SysResponse, dict], None
],
) -> Callable[[Any, Message | None, Response, dict], None]:
"""Function decorator for defining an encode filter.
Encode filters can be used to add extra data to the message
        dict before it is encoded to a string and sent out.
"""
assert self._encode_filter_call is None
self._encode_filter_call = call
return call
def validate(self, log_only: bool = False) -> None:
"""Check for handler completeness, valid types, etc."""
for msgtype in self.protocol.message_ids_by_type.keys():
if issubclass(msgtype, Response):
continue
if msgtype not in self._handlers:
msg = (
f'Protocol message type {msgtype} is not handled'
f' by receiver type {type(self)}.'
)
if log_only:
logging.error(msg)
else:
raise TypeError(msg)
def _decode_incoming_message_base(
self, bound_obj: Any, msg: str
) -> tuple[Any, dict, Message]:
# Decode the incoming message.
msg_dict = self.protocol.decode_dict(msg)
msg_decoded = self.protocol.message_from_dict(msg_dict)
assert isinstance(msg_decoded, Message)
if self._decode_filter_call is not None:
self._decode_filter_call(bound_obj, msg_dict, msg_decoded)
return bound_obj, msg_dict, msg_decoded
def _decode_incoming_message(self, bound_obj: Any, msg: str) -> Message:
bound_obj, _msg_dict, msg_decoded = self._decode_incoming_message_base(
bound_obj=bound_obj, msg=msg
)
return msg_decoded
def encode_user_response(
self, bound_obj: Any, message: Message, response: Response | None
) -> str:
"""Encode a response provided by the user for sending."""
assert isinstance(response, Response | None)
# (user should never explicitly return error-responses)
assert (
response is None or type(response) in message.get_response_types()
)
# A return value of None equals EmptySysResponse.
out_response: Response | SysResponse
if response is None:
out_response = EmptySysResponse()
else:
out_response = response
response_dict = self.protocol.response_to_dict(out_response)
if self._encode_filter_call is not None:
self._encode_filter_call(
bound_obj, message, out_response, response_dict
)
return self.protocol.encode_dict(response_dict)
def encode_error_response(
self, bound_obj: Any, message: Message | None, exc: Exception
) -> tuple[str, bool]:
"""Given an error, return sysresponse str and whether to log."""
response, dolog = self.protocol.error_to_response(exc)
response_dict = self.protocol.response_to_dict(response)
if self._encode_filter_call is not None:
self._encode_filter_call(
bound_obj, message, response, response_dict
)
return self.protocol.encode_dict(response_dict), dolog
def handle_raw_message(
self, bound_obj: Any, msg: str, raise_unregistered: bool = False
) -> str:
"""Decode, handle, and return an response for a message.
if 'raise_unregistered' is True, will raise an
efro.message.UnregisteredMessageIDError for messages not handled by
the protocol. In all other cases local errors will translate to
error responses returned to the sender.
"""
assert not self.is_async, "can't call sync handler on async receiver"
msg_decoded: Message | None = None
msgtype: type[Message] | None = None
try:
msg_decoded = self._decode_incoming_message(bound_obj, msg)
msgtype = type(msg_decoded)
handler = self._handlers.get(msgtype)
if handler is None:
raise RuntimeError(f'Got unhandled message type: {msgtype}.')
response = handler(bound_obj, msg_decoded)
assert isinstance(response, Response | None)
return self.encode_user_response(bound_obj, msg_decoded, response)
except Exception as exc:
if raise_unregistered and isinstance(
exc, UnregisteredMessageIDError
):
raise
rstr, dolog = self.encode_error_response(
bound_obj, msg_decoded, exc
)
if dolog:
if msgtype is not None:
logging.exception(
'Error handling %s.%s message.',
msgtype.__module__,
msgtype.__qualname__,
)
else:
logging.exception('Error in efro.message handling.')
return rstr
def handle_raw_message_async(
self, bound_obj: Any, msg: str, raise_unregistered: bool = False
) -> Awaitable[str]:
"""Should be called when the receiver gets a message.
The return value is the raw response to the message.
"""
# Note: This call is synchronous so that the first part of it can
# happen synchronously. If the whole call were async we wouldn't be
        # able to guarantee that message handlers would be called in the
# order the messages were received.
assert self.is_async, "can't call async handler on sync receiver"
msg_decoded: Message | None = None
msgtype: type[Message] | None = None
try:
msg_decoded = self._decode_incoming_message(bound_obj, msg)
msgtype = type(msg_decoded)
handler = self._handlers.get(msgtype)
if handler is None:
raise RuntimeError(f'Got unhandled message type: {msgtype}.')
handler_awaitable = handler(bound_obj, msg_decoded)
except Exception as exc:
if raise_unregistered and isinstance(
exc, UnregisteredMessageIDError
):
raise
return self._handle_raw_message_async_error(
bound_obj, msg_decoded, msgtype, exc
)
# Return an awaitable to handle the rest asynchronously.
return self._handle_raw_message_async(
bound_obj, msg_decoded, msgtype, handler_awaitable
)
async def _handle_raw_message_async_error(
self,
bound_obj: Any,
msg_decoded: Message | None,
msgtype: type[Message] | None,
exc: Exception,
) -> str:
rstr, dolog = self.encode_error_response(bound_obj, msg_decoded, exc)
if dolog:
if msgtype is not None:
logging.exception(
'Error handling %s.%s message.',
msgtype.__module__,
msgtype.__qualname__,
)
else:
logging.exception('Error in efro.message handling.')
return rstr
async def _handle_raw_message_async(
self,
bound_obj: Any,
msg_decoded: Message,
msgtype: type[Message] | None,
handler_awaitable: Awaitable[Response | None],
) -> str:
"""Should be called when the receiver gets a message.
The return value is the raw response to the message.
"""
try:
response = await handler_awaitable
assert isinstance(response, Response | None)
return self.encode_user_response(bound_obj, msg_decoded, response)
except Exception as exc:
return await self._handle_raw_message_async_error(
bound_obj, msg_decoded, msgtype, exc
)
class BoundMessageReceiver:
"""Base bound receiver class."""
def __init__(
self,
obj: Any,
receiver: MessageReceiver,
) -> None:
assert obj is not None
self._obj = obj
self._receiver = receiver
@property
def protocol(self) -> MessageProtocol:
"""Protocol associated with this receiver."""
return self._receiver.protocol
def encode_error_response(self, exc: Exception) -> str:
"""Given an error, return a response ready to send.
This should be used for any errors that happen outside of
standard handle_raw_message calls. Any errors within those
calls will be automatically returned as encoded strings.
"""
# Passing None for Message here; we would only have that available
# for things going wrong in the handler (which this is not for).
return self._receiver.encode_error_response(self._obj, None, exc)[0]
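# Illustrative sketch of how register_handler() above infers the handled
# message type and response types from a handler's annotations (_DemoMsg and
# _demo_handler are hypothetical stand-ins, not real protocol types):
if __name__ == '__main__':
    from typing import get_type_hints as _get_type_hints

    class _DemoMsg:
        pass

    def _demo_handler(self: object, msg: _DemoMsg) -> None:
        del self, msg

    assert inspect.getfullargspec(_demo_handler).args == ['self', 'msg']
    _anns = _get_type_hints(_demo_handler)
    assert _anns['msg'] is _DemoMsg
    assert _anns['return'] is type(None)  # A None return shows up as NoneType.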


@ -0,0 +1,465 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for sending and responding to messages.
Supports static typing for message types and possible return types.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from efro.error import CleanError, RemoteError, CommunicationError
from efro.message._message import EmptySysResponse, ErrorSysResponse, Response
if TYPE_CHECKING:
from typing import Any, Callable, Awaitable
from efro.message._message import Message, SysResponse
from efro.message._protocol import MessageProtocol
class MessageSender:
"""Facilitates sending messages to a target and receiving responses.
This is instantiated at the class level and used to register unbound
class methods to handle raw message sending.
Example:
class MyClass:
msg = MyMessageSender(some_protocol)
@msg.send_method
def send_raw_message(self, message: str) -> str:
# Actually send the message here.
# MyMessageSender class should provide overloads for send(), send_async(),
# etc. to ensure all sending happens with valid types.
obj = MyClass()
obj.msg.send(SomeMessageType())
"""
def __init__(self, protocol: MessageProtocol) -> None:
self.protocol = protocol
self._send_raw_message_call: Callable[[Any, str], str] | None = None
self._send_async_raw_message_call: Callable[
[Any, str], Awaitable[str]
] | None = None
self._send_async_raw_message_ex_call: Callable[
[Any, str, Message], Awaitable[str]
] | None = None
self._encode_filter_call: Callable[
[Any, Message, dict], None
] | None = None
self._decode_filter_call: Callable[
[Any, Message, dict, Response | SysResponse], None
] | None = None
self._peer_desc_call: Callable[[Any], str] | None = None
def send_method(
self, call: Callable[[Any, str], str]
) -> Callable[[Any, str], str]:
"""Function decorator for setting raw send method.
Send methods take strings and should return strings.
CommunicationErrors raised here will be returned to the sender
as such; all other exceptions will result in a RuntimeError for
the sender.
"""
assert self._send_raw_message_call is None
self._send_raw_message_call = call
return call
def send_async_method(
self, call: Callable[[Any, str], Awaitable[str]]
) -> Callable[[Any, str], Awaitable[str]]:
"""Function decorator for setting raw send-async method.
Send methods take strings and should return strings.
CommunicationErrors raised here will be returned to the sender
as such; all other exceptions will result in a RuntimeError for
the sender.
IMPORTANT: Generally async send methods should not be implemented
as 'async' methods, but instead should be regular methods that
return awaitable objects. This way it can be guaranteed that
outgoing messages are synchronously enqueued in the correct
order, and then async calls can be returned which finish each
send. If the entire call is async, they may be enqueued out of
order in rare cases.
"""
assert self._send_async_raw_message_call is None
self._send_async_raw_message_call = call
return call
def send_async_ex_method(
self, call: Callable[[Any, str, Message], Awaitable[str]]
) -> Callable[[Any, str, Message], Awaitable[str]]:
"""Function decorator for extended send-async method.
        Version of send_async_method which is also passed the original
unencoded message; can be useful for cases where metadata is sent
along with messages referring to their payloads/etc.
"""
assert self._send_async_raw_message_ex_call is None
self._send_async_raw_message_ex_call = call
return call
def encode_filter_method(
self, call: Callable[[Any, Message, dict], None]
) -> Callable[[Any, Message, dict], None]:
"""Function decorator for defining an encode filter.
Encode filters can be used to add extra data to the message
        dict before it is encoded to a string and sent out.
"""
assert self._encode_filter_call is None
self._encode_filter_call = call
return call
def decode_filter_method(
self, call: Callable[[Any, Message, dict, Response | SysResponse], None]
) -> Callable[[Any, Message, dict, Response], None]:
"""Function decorator for defining a decode filter.
Decode filters can be used to extract extra data from incoming
message dicts.
"""
assert self._decode_filter_call is None
self._decode_filter_call = call
return call
def peer_desc_method(
self, call: Callable[[Any], str]
) -> Callable[[Any], str]:
"""Function decorator for defining peer descriptions.
These are included in error messages or other diagnostics.
"""
assert self._peer_desc_call is None
self._peer_desc_call = call
return call
def send(self, bound_obj: Any, message: Message) -> Response | None:
"""Send a message synchronously."""
return self.unpack_raw_response(
bound_obj=bound_obj,
message=message,
raw_response=self.fetch_raw_response(
bound_obj=bound_obj,
message=message,
),
)
def send_async(
self, bound_obj: Any, message: Message
) -> Awaitable[Response | None]:
"""Send a message asynchronously."""
# Note: This call is synchronous so that the first part of it can
# happen synchronously. If the whole call were async we wouldn't be
# able to guarantee that messages sent in order would actually go
# out in order.
raw_response_awaitable = self.fetch_raw_response_async(
bound_obj=bound_obj,
message=message,
)
# Now return an awaitable that will finish the send.
return self._send_async_awaitable(
bound_obj, message, raw_response_awaitable
)
async def _send_async_awaitable(
self,
bound_obj: Any,
message: Message,
raw_response_awaitable: Awaitable[Response | SysResponse],
) -> Response | None:
return self.unpack_raw_response(
bound_obj=bound_obj,
message=message,
raw_response=await raw_response_awaitable,
)
def fetch_raw_response(
self, bound_obj: Any, message: Message
) -> Response | SysResponse:
"""Send a message synchronously.
Generally you can just call send(); these split versions are
for when message sending and response handling need to happen
in different contexts/threads.
"""
if self._send_raw_message_call is None:
raise RuntimeError('send() is unimplemented for this type.')
msg_encoded = self._encode_message(bound_obj, message)
try:
response_encoded = self._send_raw_message_call(
bound_obj, msg_encoded
)
except Exception as exc:
response = ErrorSysResponse(
error_message='Error in MessageSender @send_method.',
error_type=(
ErrorSysResponse.ErrorType.COMMUNICATION
if isinstance(exc, CommunicationError)
else ErrorSysResponse.ErrorType.LOCAL
),
)
# Can include the actual exception since we'll be looking at
# this locally; might be helpful.
response.set_local_exception(exc)
return response
return self._decode_raw_response(bound_obj, message, response_encoded)
def fetch_raw_response_async(
self, bound_obj: Any, message: Message
) -> Awaitable[Response | SysResponse]:
"""Fetch a raw message response awaitable.
The result of this should be awaited and then passed to
unpack_raw_response() to produce the final message result.
Generally you can just call send(); calling fetch and unpack
manually is for when message sending and response handling need
to happen in different contexts/threads.
"""
# Note: This call is synchronous so that the first part of it can
# happen synchronously. If the whole call were async we wouldn't be
# able to guarantee that messages sent in order would actually go
# out in order.
if (
self._send_async_raw_message_call is None
and self._send_async_raw_message_ex_call is None
):
raise RuntimeError('send_async() is unimplemented for this type.')
msg_encoded = self._encode_message(bound_obj, message)
try:
if self._send_async_raw_message_ex_call is not None:
send_awaitable = self._send_async_raw_message_ex_call(
bound_obj, msg_encoded, message
)
else:
assert self._send_async_raw_message_call is not None
send_awaitable = self._send_async_raw_message_call(
bound_obj, msg_encoded
)
except Exception as exc:
return self._error_awaitable(exc)
# Now return an awaitable to finish the job.
return self._fetch_raw_response_awaitable(
bound_obj, message, send_awaitable
)
async def _error_awaitable(self, exc: Exception) -> SysResponse:
response = ErrorSysResponse(
error_message='Error in MessageSender @send_async_method.',
error_type=(
ErrorSysResponse.ErrorType.COMMUNICATION
if isinstance(exc, CommunicationError)
else ErrorSysResponse.ErrorType.LOCAL
),
)
# Can include the actual exception since we'll be looking at
# this locally; might be helpful.
response.set_local_exception(exc)
return response
async def _fetch_raw_response_awaitable(
self, bound_obj: Any, message: Message, send_awaitable: Awaitable[str]
) -> Response | SysResponse:
try:
response_encoded = await send_awaitable
except Exception as exc:
response = ErrorSysResponse(
error_message='Error in MessageSender @send_async_method.',
error_type=(
ErrorSysResponse.ErrorType.COMMUNICATION
if isinstance(exc, CommunicationError)
else ErrorSysResponse.ErrorType.LOCAL
),
)
# Can include the actual exception since we'll be looking at
# this locally; might be helpful.
response.set_local_exception(exc)
return response
return self._decode_raw_response(bound_obj, message, response_encoded)
def unpack_raw_response(
self,
bound_obj: Any,
message: Message,
raw_response: Response | SysResponse,
) -> Response | None:
"""Convert a raw fetched response into a final response/error/etc.
Generally you can just call send(); calling fetch and unpack
manually is for when message sending and response handling need
to happen in different contexts/threads.
"""
response = self._unpack_raw_response(bound_obj, raw_response)
assert (
response is None
or type(response) in type(message).get_response_types()
)
return response
def _encode_message(self, bound_obj: Any, message: Message) -> str:
"""Encode a message for sending."""
msg_dict = self.protocol.message_to_dict(message)
if self._encode_filter_call is not None:
self._encode_filter_call(bound_obj, message, msg_dict)
return self.protocol.encode_dict(msg_dict)
def _decode_raw_response(
self, bound_obj: Any, message: Message, response_encoded: str
) -> Response | SysResponse:
"""Create a Response from returned data.
These Responses may encapsulate things like remote errors and
should not be handed directly to users. _unpack_raw_response()
should be used to translate to special values like None or raise
Exceptions. This function itself should never raise Exceptions.
"""
response: Response | SysResponse
try:
response_dict = self.protocol.decode_dict(response_encoded)
response = self.protocol.response_from_dict(response_dict)
if self._decode_filter_call is not None:
self._decode_filter_call(
bound_obj, message, response_dict, response
)
except Exception as exc:
response = ErrorSysResponse(
error_message='Error decoding raw response.',
error_type=ErrorSysResponse.ErrorType.LOCAL,
)
# Since we'll be looking at this locally, we can include
# extra info for logging/etc.
response.set_local_exception(exc)
return response
def _unpack_raw_response(
self, bound_obj: Any, raw_response: Response | SysResponse
) -> Response | None:
"""Given a raw Response, unpacks to special values or Exceptions.
The result of this call is what should be passed to users.
For complex messaging situations such as response callbacks
operating across different threads, this last stage should be
run such that any raised Exception is active when the callback
fires; not on the thread where the message was sent.
"""
# EmptySysResponse translates to None
if isinstance(raw_response, EmptySysResponse):
return None
# Some error occurred. Raise a local Exception for it.
if isinstance(raw_response, ErrorSysResponse):
# Errors that happened locally can attach their exceptions
# here for extra logging goodness.
local_exception = raw_response.get_local_exception()
if (
raw_response.error_type
is ErrorSysResponse.ErrorType.COMMUNICATION
):
raise CommunicationError(
raw_response.error_message
) from local_exception
# If something went wrong on *our* end of the connection,
# don't say it was a remote error.
if raw_response.error_type is ErrorSysResponse.ErrorType.LOCAL:
raise RuntimeError(
raw_response.error_message
) from local_exception
# If they want to support clean errors, do those.
if (
self.protocol.forward_clean_errors
and raw_response.error_type
is ErrorSysResponse.ErrorType.REMOTE_CLEAN
):
raise CleanError(
raw_response.error_message
) from local_exception
if (
self.protocol.forward_communication_errors
and raw_response.error_type
is ErrorSysResponse.ErrorType.REMOTE_COMMUNICATION
):
raise CommunicationError(
raw_response.error_message
) from local_exception
# Everything else gets lumped in as a remote error.
raise RemoteError(
raw_response.error_message,
peer_desc=(
'peer'
if self._peer_desc_call is None
else self._peer_desc_call(bound_obj)
),
) from local_exception
assert isinstance(raw_response, Response)
return raw_response
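    # For reference, _unpack_raw_response() above surfaces results as follows
    # (errors are raised 'from' the attached local exception when present):
    #   EmptySysResponse               -> None
    #   ErrorType.COMMUNICATION        -> CommunicationError
    #   ErrorType.LOCAL                -> RuntimeError
    #   ErrorType.REMOTE_CLEAN         -> CleanError (when the protocol
    #                                     forwards clean errors)
    #   ErrorType.REMOTE_COMMUNICATION -> CommunicationError (when forwarded)
    #   any other ErrorSysResponse     -> RemoteError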
class BoundMessageSender:
"""Base class for bound senders."""
def __init__(self, obj: Any, sender: MessageSender) -> None:
# Note: not checking obj here since we want to support
# at least our protocol property when accessed via type.
self._obj = obj
self._sender = sender
@property
def protocol(self) -> MessageProtocol:
"""Protocol associated with this sender."""
return self._sender.protocol
def send_untyped(self, message: Message) -> Response | None:
"""Send a message synchronously.
Whenever possible, use the send() call provided by generated
subclasses instead of this; it will provide better type safety.
"""
assert self._obj is not None
return self._sender.send(bound_obj=self._obj, message=message)
def send_async_untyped(
self, message: Message
) -> Awaitable[Response | None]:
"""Send a message asynchronously.
Whenever possible, use the send_async() call provided by generated
subclasses instead of this; it will provide better type safety.
"""
assert self._obj is not None
return self._sender.send_async(bound_obj=self._obj, message=message)
def fetch_raw_response_async_untyped(
self, message: Message
) -> Awaitable[Response | SysResponse]:
"""Split send (part 1 of 2)."""
assert self._obj is not None
return self._sender.fetch_raw_response_async(
bound_obj=self._obj, message=message
)
def unpack_raw_response_untyped(
self, message: Message, raw_response: Response | SysResponse
) -> Response | None:
"""Split send (part 2 of 2)."""
return self._sender.unpack_raw_response(
bound_obj=self._obj, message=message, raw_response=raw_response
)

75
dist/ba_data/python/efro/net.py vendored Normal file

@ -0,0 +1,75 @@
# Released under the MIT License. See LICENSE for details.
#
"""Network related functionality."""
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
pass
def is_urllib_network_error(exc: BaseException) -> bool:
"""Is the provided exception a network-related error?
This should be passed an exception which resulted from opening or
reading a urllib Request. It should return True for any errors that
could conceivably arise due to unavailable/poor network connections,
firewall/connectivity issues, etc. These issues can often be safely
ignored or presented to the user as general 'network-unavailable'
states.
"""
import urllib.request
import urllib.error
import http.client
import errno
import socket
if isinstance(
exc,
(urllib.error.URLError, ConnectionError, http.client.IncompleteRead,
http.client.BadStatusLine, socket.timeout)):
return True
if isinstance(exc, OSError):
if exc.errno == 10051: # Windows unreachable network error.
return True
if exc.errno in {
errno.ETIMEDOUT,
errno.EHOSTUNREACH,
errno.ENETUNREACH,
}:
return True
return False
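# Illustrative usage sketch for is_urllib_network_error() (this helper and
# its URL argument are hypothetical, not part of this module's API):
def _example_fetch_or_none(url: str) -> bytes | None:
    """Return response bytes, or None for ignorable network trouble."""
    import urllib.request
    try:
        with urllib.request.urlopen(url, timeout=10) as resp:
            return resp.read()
    except Exception as exc:
        if is_urllib_network_error(exc):
            return None  # Present as a generic 'network-unavailable' state.
        raise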
def is_udp_network_error(exc: BaseException) -> bool:
"""Is the provided exception a network-related error?
This should be passed an exception which resulted from creating and
using a socket.SOCK_DGRAM type socket. It should return True for any
errors that could conceivably arise due to unavailable/poor network
connections, firewall/connectivity issues, etc. These issues can often
be safely ignored or presented to the user as general
'network-unavailable' states.
"""
import errno
if isinstance(exc, ConnectionRefusedError):
return True
if isinstance(exc, OSError):
if exc.errno == 10051: # Windows unreachable network error.
return True
if exc.errno in {
errno.EADDRNOTAVAIL,
errno.ETIMEDOUT,
errno.EHOSTUNREACH,
errno.ENETUNREACH,
errno.EINVAL,
errno.EPERM,
errno.EACCES,
# Windows 'invalid argument' error.
10022,
# Windows 'a socket operation was attempted to'
# 'an unreachable network' error.
10051,
}:
return True
return False
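# Illustrative usage sketch for is_udp_network_error() (hypothetical helper,
# not part of this module's API):
def _example_udp_ping(addr: tuple[str, int]) -> bool:
    """Fire off a throwaway datagram; False means an ignorable network issue."""
    import socket
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        sock.sendto(b'ping', addr)
        return True
    except Exception as exc:
        if is_udp_network_error(exc):
            return False
        raise
    finally:
        sock.close()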

945
dist/ba_data/python/efro/rpc.py vendored Normal file

@ -0,0 +1,945 @@
# Released under the MIT License. See LICENSE for details.
#
"""Remote procedure call related functionality."""
from __future__ import annotations
import time
import asyncio
import logging
import weakref
from enum import Enum
from collections import deque
from dataclasses import dataclass
from threading import current_thread
from typing import TYPE_CHECKING, Annotated
from efro.util import assert_never
from efro.error import (
CommunicationError,
is_asyncio_streams_communication_error,
)
from efro.dataclassio import (
dataclass_to_json,
dataclass_from_json,
ioprepped,
IOAttrs,
)
if TYPE_CHECKING:
from typing import Literal, Awaitable, Callable
# Terminology:
# Packet: A chunk of data consisting of a type and some type-dependent
# payload. Even though we use streams we organize our transmission
# into 'packets'.
# Message: User data which we transmit using one or more packets.
class _PacketType(Enum):
HANDSHAKE = 0
KEEPALIVE = 1
MESSAGE = 2
RESPONSE = 3
MESSAGE_BIG = 4
RESPONSE_BIG = 5
_BYTE_ORDER: Literal['big'] = 'big'
@ioprepped
@dataclass
class _PeerInfo:
# So we can gracefully evolve how we communicate in the future.
protocol: Annotated[int, IOAttrs('p')]
# How often we'll be sending out keepalives (in seconds).
keepalive_interval: Annotated[float, IOAttrs('k')]
# Note: we are expected to be forward and backward compatible; we can
# increment protocol freely and expect everyone else to still talk to us.
# Likewise we should retain logic to communicate with older protocols.
# Protocol history:
# 1 - initial release
# 2 - gained big (32-bit len val) message/response packets
OUR_PROTOCOL = 2
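# Illustrative sketches of the two frame layouts used below (these helpers
# are hypothetical and not used by the module itself): _run_write_task()
# first sends a handshake framed as a 4-byte big-endian length prefix plus
# the JSON-encoded _PeerInfo, and send_message() frames each payload as a
# 1-byte packet type, a 2-byte message id, a 2-byte length (4 bytes for the
# *_BIG variants), and then the data.
def _demo_handshake_frame(keepalive_interval: float = 10.73) -> bytes:
    data = dataclass_to_json(
        _PeerInfo(protocol=OUR_PROTOCOL, keepalive_interval=keepalive_interval)
    ).encode()
    return len(data).to_bytes(4, _BYTE_ORDER) + data
def _demo_message_frame(message_id: int, message: bytes) -> bytes:
    assert len(message) <= 65535  # MESSAGE_BIG covers larger payloads.
    return (
        _PacketType.MESSAGE.value.to_bytes(1, _BYTE_ORDER)
        + message_id.to_bytes(2, _BYTE_ORDER)
        + len(message).to_bytes(2, _BYTE_ORDER)
        + message
    )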
def ssl_stream_writer_underlying_transport_info(
writer: asyncio.StreamWriter,
) -> str:
"""For debugging SSL Stream connections; returns raw transport info."""
# Note: accessing internals here so just returning info and not
# actual objs to reduce potential for breakage.
transport = getattr(writer, '_transport', None)
if transport is not None:
sslproto = getattr(transport, '_ssl_protocol', None)
if sslproto is not None:
raw_transport = getattr(sslproto, '_transport', None)
if raw_transport is not None:
return str(raw_transport)
return '(not found)'
def ssl_stream_writer_force_close_check(writer: asyncio.StreamWriter) -> None:
"""Ensure a writer is closed; hacky workaround for odd hang."""
from efro.call import tpartial
from threading import Thread
# Disabling for now..
if bool(True):
return
# Hopefully can remove this in Python 3.11?...
# see issue with is_closing() below for more details.
transport = getattr(writer, '_transport', None)
if transport is not None:
sslproto = getattr(transport, '_ssl_protocol', None)
if sslproto is not None:
raw_transport = getattr(sslproto, '_transport', None)
if raw_transport is not None:
Thread(
target=tpartial(
_do_writer_force_close_check,
weakref.ref(raw_transport),
),
daemon=True,
).start()
def _do_writer_force_close_check(transport_weak: weakref.ref) -> None:
try:
# Attempt to bail as soon as the obj dies.
# If it hasn't done so by our timeout, force-kill it.
starttime = time.monotonic()
while time.monotonic() - starttime < 10.0:
time.sleep(0.1)
if transport_weak() is None:
return
transport = transport_weak()
if transport is not None:
logging.info('Forcing abort on stuck transport %s.', transport)
transport.abort()
except Exception:
logging.warning('Error in writer-force-close-check', exc_info=True)
class _InFlightMessage:
"""Represents a message that is out on the wire."""
def __init__(self) -> None:
self._response: bytes | None = None
self._got_response = asyncio.Event()
self.wait_task = asyncio.create_task(
self._wait(), name='rpc in flight msg wait'
)
async def _wait(self) -> bytes:
await self._got_response.wait()
assert self._response is not None
return self._response
def set_response(self, data: bytes) -> None:
"""Set response data."""
assert self._response is None
self._response = data
self._got_response.set()
class _KeepaliveTimeoutError(Exception):
"""Raised if we time out due to not receiving keepalives."""
class RPCEndpoint:
"""Facilitates asynchronous multiplexed remote procedure calls.
Be aware that, while multiple calls can be in flight in either direction
simultaneously, packets are still sent serially in a single
stream. So excessively long messages/responses will delay all other
communication. If/when this becomes an issue we can look into breaking up
long messages into multiple packets.
"""
# Set to True on an instance to test keepalive failures.
test_suppress_keepalives: bool = False
# How long we should wait before giving up on a message by default.
# Note this includes processing time on the other end.
DEFAULT_MESSAGE_TIMEOUT = 60.0
# How often we send out keepalive packets by default.
    DEFAULT_KEEPALIVE_INTERVAL = 10.73  # (avoid overly regular values)
# How long we can go without receiving a keepalive packet before we
# disconnect.
DEFAULT_KEEPALIVE_TIMEOUT = 30.0
def __init__(
self,
handle_raw_message_call: Callable[[bytes], Awaitable[bytes]],
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
label: str,
debug_print: bool = False,
debug_print_io: bool = False,
debug_print_call: Callable[[str], None] | None = None,
keepalive_interval: float = DEFAULT_KEEPALIVE_INTERVAL,
keepalive_timeout: float = DEFAULT_KEEPALIVE_TIMEOUT,
) -> None:
self._handle_raw_message_call = handle_raw_message_call
self._reader = reader
self._writer = writer
self.debug_print = debug_print
self.debug_print_io = debug_print_io
if debug_print_call is None:
debug_print_call = print
self.debug_print_call: Callable[[str], None] = debug_print_call
self._label = label
self._thread = current_thread()
self._closing = False
self._did_wait_closed = False
self._event_loop = asyncio.get_running_loop()
self._out_packets = deque[bytes]()
self._have_out_packets = asyncio.Event()
self._run_called = False
self._peer_info: _PeerInfo | None = None
self._keepalive_interval = keepalive_interval
self._keepalive_timeout = keepalive_timeout
self._did_close_writer = False
self._did_wait_closed_writer = False
self._did_out_packets_buildup_warning = False
self._total_bytes_read = 0
self._create_time = time.monotonic()
        # Need to hold weak-refs to these; otherwise it creates dep-loops
        # which keep us alive.
self._tasks: list[asyncio.Task] = []
# When we last got a keepalive or equivalent (time.monotonic value)
self._last_keepalive_receive_time: float | None = None
# (Start near the end to make sure our looping logic is sound).
self._next_message_id = 65530
self._in_flight_messages: dict[int, _InFlightMessage] = {}
if self.debug_print:
peername = self._writer.get_extra_info('peername')
self.debug_print_call(
f'{self._label}: connected to {peername} at {self._tm()}.'
)
def __del__(self) -> None:
if self._run_called:
if not self._did_close_writer:
logging.warning(
'RPCEndpoint %d dying with run'
' called but writer not closed (transport=%s).',
id(self),
ssl_stream_writer_underlying_transport_info(self._writer),
)
elif not self._did_wait_closed_writer:
logging.warning(
'RPCEndpoint %d dying with run called'
' but writer not wait-closed (transport=%s).',
id(self),
ssl_stream_writer_underlying_transport_info(self._writer),
)
# Currently seeing rare issue where sockets don't go down;
# let's add a timer to force the issue until we can figure it out.
ssl_stream_writer_force_close_check(self._writer)
async def run(self) -> None:
"""Run the endpoint until the connection is lost or closed.
Handles closing the provided reader/writer on close.
"""
try:
await self._do_run()
except asyncio.CancelledError:
# We aren't really designed to be cancelled so let's warn
# if it happens.
logging.warning(
'RPCEndpoint.run got CancelledError;'
' want to try and avoid this.'
)
raise
async def _do_run(self) -> None:
self._check_env()
if self._run_called:
raise RuntimeError('Run can be called only once per endpoint.')
self._run_called = True
core_tasks = [
asyncio.create_task(
self._run_core_task('keepalive', self._run_keepalive_task()),
name='rpc keepalive',
),
asyncio.create_task(
self._run_core_task('read', self._run_read_task()),
name='rpc read',
),
asyncio.create_task(
self._run_core_task('write', self._run_write_task()),
name='rpc write',
),
]
self._tasks += core_tasks
# Run our core tasks until they all complete.
results = await asyncio.gather(*core_tasks, return_exceptions=True)
# Core tasks should handle their own errors; the only ones
# we expect to bubble up are CancelledError.
for result in results:
# We want to know if any errors happened aside from CancelledError
# (which are BaseExceptions, not Exception).
if isinstance(result, Exception):
logging.warning(
'Got unexpected error from %s core task: %s',
self._label,
result,
)
if not all(task.done() for task in core_tasks):
logging.warning(
'RPCEndpoint %d: not all core tasks marked done after gather.',
id(self),
)
# Shut ourself down.
try:
self.close()
await self.wait_closed()
except Exception:
logging.exception('Error closing %s.', self._label)
if self.debug_print:
self.debug_print_call(f'{self._label}: finished.')
def send_message(
self,
message: bytes,
timeout: float | None = None,
close_on_error: bool = True,
) -> Awaitable[bytes]:
"""Send a message to the peer and return a response.
If timeout is not provided, the default will be used.
Raises a CommunicationError if the round trip is not completed
for any reason.
By default, the entire endpoint will go down in the case of
errors. This allows messages to be treated as 'reliable' with
respect to a given endpoint. Pass close_on_error=False to
override this for a particular message.
"""
# Note: This call is synchronous so that the first part of it
# (enqueueing outgoing messages) happens synchronously. If it were
# a pure async call it could be possible for send order to vary
# based on how the async tasks get processed.
if self.debug_print_io:
self.debug_print_call(
f'{self._label}: sending message of size {len(message)}'
f' at {self._tm()}.'
)
self._check_env()
if self._closing:
raise CommunicationError('Endpoint is closed.')
if self.debug_print_io:
self.debug_print_call(
f'{self._label}: have peerinfo? {self._peer_info is not None}.'
)
# message_id is a 16 bit looping value.
message_id = self._next_message_id
self._next_message_id = (self._next_message_id + 1) % 65536
if self.debug_print_io:
self.debug_print_call(
f'{self._label}: will enqueue at {self._tm()}.'
)
# FIXME - should handle backpressure (waiting here if there are
# enough packets already enqueued).
if len(message) > 65535:
# Payload consists of type (1b), message_id (2b),
# len (4b), and data.
self._enqueue_outgoing_packet(
_PacketType.MESSAGE_BIG.value.to_bytes(1, _BYTE_ORDER)
+ message_id.to_bytes(2, _BYTE_ORDER)
+ len(message).to_bytes(4, _BYTE_ORDER)
+ message
)
else:
# Payload consists of type (1b), message_id (2b),
# len (2b), and data.
self._enqueue_outgoing_packet(
_PacketType.MESSAGE.value.to_bytes(1, _BYTE_ORDER)
+ message_id.to_bytes(2, _BYTE_ORDER)
+ len(message).to_bytes(2, _BYTE_ORDER)
+ message
)
if self.debug_print_io:
self.debug_print_call(
f'{self._label}: enqueued message of size {len(message)}'
f' at {self._tm()}.'
)
# Make an entry so we know this message is out there.
assert message_id not in self._in_flight_messages
msgobj = self._in_flight_messages[message_id] = _InFlightMessage()
# Also add its task to our list so we properly cancel it if we die.
self._prune_tasks() # Keep our list from filling with dead tasks.
self._tasks.append(msgobj.wait_task)
# Note: we always want to incorporate a timeout. Individual
# messages may hang or error on the other end and this ensures
# we won't build up lots of zombie tasks waiting around for
# responses that will never arrive.
if timeout is None:
timeout = self.DEFAULT_MESSAGE_TIMEOUT
assert timeout is not None
bytes_awaitable = msgobj.wait_task
# Now complete the send asynchronously.
return self._send_message(
message, timeout, close_on_error, bytes_awaitable, message_id
)
async def _send_message(
self,
message: bytes,
timeout: float | None,
close_on_error: bool,
bytes_awaitable: asyncio.Task[bytes],
message_id: int,
) -> bytes:
# We need to know their protocol, so if we haven't gotten a handshake
# from them yet, just wait.
while self._peer_info is None:
await asyncio.sleep(0.01)
assert self._peer_info is not None
if self._peer_info.protocol == 1:
if len(message) > 65535:
raise RuntimeError('Message cannot be larger than 65535 bytes')
try:
return await asyncio.wait_for(bytes_awaitable, timeout=timeout)
except asyncio.CancelledError as exc:
# Question: we assume this means the above wait_for() was
# cancelled; how do we distinguish between this and *us* being
# cancelled though?
if self.debug_print:
self.debug_print_call(
f'{self._label}: message {message_id} was cancelled.'
)
if close_on_error:
self.close()
raise CommunicationError() from exc
except Exception as exc:
# If our timer timed-out or anything else went wrong with
# the stream, lump it in as a communication error.
if isinstance(
exc, asyncio.TimeoutError
) or is_asyncio_streams_communication_error(exc):
if self.debug_print:
self.debug_print_call(
f'{self._label}: got {type(exc)} sending message'
f' {message_id}; raising CommunicationError.'
)
# Stop waiting on the response.
bytes_awaitable.cancel()
# Remove the record of this message.
del self._in_flight_messages[message_id]
if close_on_error:
self.close()
# Let the user know something went wrong.
raise CommunicationError() from exc
# Some unexpected error; let it bubble up.
raise
def close(self) -> None:
"""I said seagulls; mmmm; stop it now."""
self._check_env()
if self._closing:
return
if self.debug_print:
self.debug_print_call(f'{self._label}: closing...')
self._closing = True
# Kill all of our in-flight tasks.
if self.debug_print:
self.debug_print_call(f'{self._label}: cancelling tasks...')
for task in self._get_live_tasks():
task.cancel()
# Close our writer.
assert not self._did_close_writer
if self.debug_print:
self.debug_print_call(f'{self._label}: closing writer...')
self._writer.close()
self._did_close_writer = True
# We don't need this anymore and it is likely to be creating a
# dependency loop.
del self._handle_raw_message_call
def is_closing(self) -> bool:
"""Have we begun the process of closing?"""
return self._closing
async def wait_closed(self) -> None:
"""I said seagulls; mmmm; stop it now.
Wait for the endpoint to finish closing. This is called by run()
so generally does not need to be explicitly called.
"""
# pylint: disable=too-many-branches
self._check_env()
# Make sure we only *enter* this call once.
if self._did_wait_closed:
return
self._did_wait_closed = True
if not self._closing:
raise RuntimeError('Must be called after close()')
if not self._did_close_writer:
logging.warning(
'RPCEndpoint wait_closed() called but never'
' explicitly closed writer.'
)
live_tasks = self._get_live_tasks()
# Don't need our task list anymore; this should
# break any cyclical refs from tasks referring to us.
self._tasks = []
if self.debug_print:
self.debug_print_call(
f'{self._label}: waiting for tasks to finish: '
f' ({live_tasks=})...'
)
# Wait for all of our in-flight tasks to wrap up.
results = await asyncio.gather(*live_tasks, return_exceptions=True)
for result in results:
# We want to know if any errors happened aside from CancelledError
# (which are BaseExceptions, not Exception).
if isinstance(result, Exception):
logging.warning(
'Got unexpected error cleaning up %s task: %s',
self._label,
result,
)
if not all(task.done() for task in live_tasks):
logging.warning(
'RPCEndpoint %d: not all live tasks marked done after gather.',
id(self),
)
if self.debug_print:
self.debug_print_call(
f'{self._label}: tasks finished; waiting for writer close...'
)
# Now wait for our writer to finish going down.
# When we close our writer it generally triggers errors
# in our current blocked read/writes. However that same
# error is also sometimes returned from _writer.wait_closed().
# See connection_lost() in asyncio/streams.py to see why.
# So let's silently ignore it when that happens.
assert self._writer.is_closing()
try:
# It seems that as of Python 3.9.x it is possible for this to hang
# indefinitely. See https://github.com/python/cpython/issues/83939
# It sounds like this should be fixed in 3.11 but for now just
# forcing the issue with a timeout here.
await asyncio.wait_for(
self._writer.wait_closed(),
# timeout=60.0 * 6.0,
timeout=30.0,
)
except asyncio.TimeoutError:
logging.info(
'Timeout on _writer.wait_closed() for %s rpc (transport=%s).',
self._label,
ssl_stream_writer_underlying_transport_info(self._writer),
)
if self.debug_print:
self.debug_print_call(
f'{self._label}: got timeout in _writer.wait_closed();'
' This should be fixed in future Python versions.'
)
except Exception as exc:
if not self._is_expected_connection_error(exc):
logging.exception('Error closing _writer for %s.', self._label)
else:
if self.debug_print:
self.debug_print_call(
f'{self._label}: silently ignoring error in'
f' _writer.wait_closed(): {exc}.'
)
except asyncio.CancelledError:
logging.warning(
'RPCEndpoint.wait_closed()'
' got asyncio.CancelledError; not expected.'
)
raise
assert not self._did_wait_closed_writer
self._did_wait_closed_writer = True
def _tm(self) -> str:
"""Simple readable time value for debugging."""
tval = time.monotonic() % 100.0
return f'{tval:.2f}'
async def _run_read_task(self) -> None:
"""Read from the peer."""
self._check_env()
assert self._peer_info is None
# Bug fix: if we don't have this set we will never time out
# if we never receive any data from the other end.
self._last_keepalive_receive_time = time.monotonic()
# The first thing they should send us is their handshake; then
# we'll know if/how we can talk to them.
mlen = await self._read_int_32()
message = await self._reader.readexactly(mlen)
self._total_bytes_read += mlen
self._peer_info = dataclass_from_json(_PeerInfo, message.decode())
self._last_keepalive_receive_time = time.monotonic()
if self.debug_print:
self.debug_print_call(
f'{self._label}: received handshake at {self._tm()}.'
)
# Now just sit and handle stuff as it comes in.
while True:
if self._closing:
return
# Read message type.
mtype = _PacketType(await self._read_int_8())
if mtype is _PacketType.HANDSHAKE:
raise RuntimeError('Got multiple handshakes')
if mtype is _PacketType.KEEPALIVE:
if self.debug_print_io:
self.debug_print_call(
f'{self._label}: received keepalive'
f' at {self._tm()}.'
)
self._last_keepalive_receive_time = time.monotonic()
elif mtype is _PacketType.MESSAGE:
await self._handle_message_packet(big=False)
elif mtype is _PacketType.MESSAGE_BIG:
await self._handle_message_packet(big=True)
elif mtype is _PacketType.RESPONSE:
await self._handle_response_packet(big=False)
elif mtype is _PacketType.RESPONSE_BIG:
await self._handle_response_packet(big=True)
else:
assert_never(mtype)
async def _handle_message_packet(self, big: bool) -> None:
assert self._peer_info is not None
msgid = await self._read_int_16()
if big:
msglen = await self._read_int_32()
else:
msglen = await self._read_int_16()
msg = await self._reader.readexactly(msglen)
self._total_bytes_read += msglen
if self.debug_print_io:
self.debug_print_call(
f'{self._label}: received message {msgid}'
f' of size {msglen} at {self._tm()}.'
)
# Create a message-task to handle this message and return
# a response (we don't want to block while that happens).
assert not self._closing
self._prune_tasks() # Keep from filling with dead tasks.
self._tasks.append(
asyncio.create_task(
self._handle_raw_message(message_id=msgid, message=msg),
name='efro rpc message handle',
)
)
if self.debug_print:
self.debug_print_call(
f'{self._label}: done handling message at {self._tm()}.'
)
async def _handle_response_packet(self, big: bool) -> None:
assert self._peer_info is not None
msgid = await self._read_int_16()
# Protocol 2 gained 32 bit data lengths.
if big:
rsplen = await self._read_int_32()
else:
rsplen = await self._read_int_16()
if self.debug_print_io:
self.debug_print_call(
f'{self._label}: received response {msgid}'
f' of size {rsplen} at {self._tm()}.'
)
rsp = await self._reader.readexactly(rsplen)
self._total_bytes_read += rsplen
msgobj = self._in_flight_messages.get(msgid)
if msgobj is None:
# It's possible for us to get a response to a message
# that has timed out. In this case we will have no local
# record of it.
if self.debug_print:
self.debug_print_call(
f'{self._label}: got response for nonexistent'
f' message id {msgid}; perhaps it timed out?'
)
else:
msgobj.set_response(rsp)
async def _run_write_task(self) -> None:
"""Write to the peer."""
self._check_env()
# Introduce ourself so our peer knows how it can talk to us.
data = dataclass_to_json(
_PeerInfo(
protocol=OUR_PROTOCOL,
keepalive_interval=self._keepalive_interval,
)
).encode()
self._writer.write(len(data).to_bytes(4, _BYTE_ORDER) + data)
# Now just write out-messages as they come in.
while True:
# Wait until some data comes in.
await self._have_out_packets.wait()
assert self._out_packets
data = self._out_packets.popleft()
# Important: only clear this once all packets are sent.
if not self._out_packets:
self._have_out_packets.clear()
self._writer.write(data)
# This should keep our writer from buffering huge amounts
# of outgoing data. We must remember though that we also
# need to prevent _out_packets from growing too large and
# that part's on us.
await self._writer.drain()
# For now we're not applying backpressure, but let's make
# noise if this gets out of hand.
if len(self._out_packets) > 200:
if not self._did_out_packets_buildup_warning:
logging.warning(
'_out_packets building up too'
' much on RPCEndpoint %s.',
id(self),
)
self._did_out_packets_buildup_warning = True
async def _run_keepalive_task(self) -> None:
"""Send periodic keepalive packets."""
self._check_env()
# We explicitly send our own keepalive packets so we can stay
# more on top of the connection state and possibly decide to
# kill it when contact is lost more quickly than the OS would
# do itself (or at least keep the user informed that the
# connection is lagging). It sounds like we could have the TCP
        # layer do this sort of thing itself, but that might be
        # OS-specific, so we'll go this way for now.
while True:
assert not self._closing
await asyncio.sleep(self._keepalive_interval)
if not self.test_suppress_keepalives:
self._enqueue_outgoing_packet(
_PacketType.KEEPALIVE.value.to_bytes(1, _BYTE_ORDER)
)
# Also go ahead and handle dropping the connection if we
# haven't heard from the peer in a while.
            # NOTE: we may want something more precise than this check,
            # which only runs once per keepalive-interval.
now = time.monotonic()
if (
self._last_keepalive_receive_time is not None
and now - self._last_keepalive_receive_time
> self._keepalive_timeout
):
if self.debug_print:
since = now - self._last_keepalive_receive_time
self.debug_print_call(
f'{self._label}: reached keepalive time-out'
f' ({since:.1f}s).'
)
raise _KeepaliveTimeoutError()
async def _run_core_task(self, tasklabel: str, call: Awaitable) -> None:
try:
await call
except Exception as exc:
# We expect connection errors to put us here, but make noise
# if something else does.
if not self._is_expected_connection_error(exc):
logging.exception(
'Unexpected error in rpc %s %s task'
' (age=%.1f, total_bytes_read=%d).',
self._label,
tasklabel,
time.monotonic() - self._create_time,
self._total_bytes_read,
)
else:
if self.debug_print:
self.debug_print_call(
f'{self._label}: {tasklabel} task will exit cleanly'
f' due to {exc!r}.'
)
finally:
# Any core task exiting triggers shutdown.
if self.debug_print:
self.debug_print_call(
f'{self._label}: {tasklabel} task exiting...'
)
self.close()
async def _handle_raw_message(
self, message_id: int, message: bytes
) -> None:
try:
response = await self._handle_raw_message_call(message)
except Exception:
# We expect local message handler to always succeed.
# If that doesn't happen, make a fuss so we know to fix it.
# The other end will simply never get a response to this
# message.
logging.exception('Error handling raw rpc message')
return
assert self._peer_info is not None
if self._peer_info.protocol == 1:
if len(response) > 65535:
raise RuntimeError('Response cannot be larger than 65535 bytes')
        # Now send back our response.
        # Payload consists of type (1b), msgid (2b), len (2b, or 4b for
        # big responses), and data.
if len(response) > 65535:
self._enqueue_outgoing_packet(
_PacketType.RESPONSE_BIG.value.to_bytes(1, _BYTE_ORDER)
+ message_id.to_bytes(2, _BYTE_ORDER)
+ len(response).to_bytes(4, _BYTE_ORDER)
+ response
)
else:
self._enqueue_outgoing_packet(
_PacketType.RESPONSE.value.to_bytes(1, _BYTE_ORDER)
+ message_id.to_bytes(2, _BYTE_ORDER)
+ len(response).to_bytes(2, _BYTE_ORDER)
+ response
)
async def _read_int_8(self) -> int:
out = int.from_bytes(await self._reader.readexactly(1), _BYTE_ORDER)
self._total_bytes_read += 1
return out
async def _read_int_16(self) -> int:
out = int.from_bytes(await self._reader.readexactly(2), _BYTE_ORDER)
self._total_bytes_read += 2
return out
async def _read_int_32(self) -> int:
out = int.from_bytes(await self._reader.readexactly(4), _BYTE_ORDER)
self._total_bytes_read += 4
return out
@classmethod
def _is_expected_connection_error(cls, exc: Exception) -> bool:
"""Stuff we expect to end our connection in normal circumstances."""
if isinstance(exc, _KeepaliveTimeoutError):
return True
return is_asyncio_streams_communication_error(exc)
def _check_env(self) -> None:
# I was seeing that asyncio stuff wasn't working as expected if
# created in one thread and used in another (and have verified
# that this is part of the design), so let's enforce a single
# thread for all use of an instance.
if current_thread() is not self._thread:
raise RuntimeError(
'This must be called from the same thread'
' that the endpoint was created in.'
)
# This should always be the case if thread is the same.
assert asyncio.get_running_loop() is self._event_loop
def _enqueue_outgoing_packet(self, data: bytes) -> None:
"""Enqueue a raw packet to be sent. Must be called from our loop."""
self._check_env()
if self.debug_print_io:
self.debug_print_call(
f'{self._label}: enqueueing outgoing packet'
f' {data[:50]!r} at {self._tm()}.'
)
# Add the data and let our write task know about it.
self._out_packets.append(data)
self._have_out_packets.set()
def _prune_tasks(self) -> None:
self._tasks = self._get_live_tasks()
def _get_live_tasks(self) -> list[asyncio.Task]:
return [t for t in self._tasks if not t.done()]
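# An illustrative sketch (not part of the original module) of how the
# response packets built in _handle_raw_message() above are laid out on the
# wire. The helper name is hypothetical and exists only to document the
# format in one place; the module itself never calls it.
def _example_encode_response_packet(message_id: int, response: bytes) -> bytes:
    """Pack a response the same way RPCEndpoint._handle_raw_message() does."""
    if len(response) > 65535:
        # 'Big' responses (protocol 2+): type (1b), msgid (2b), len (4b), data.
        return (
            _PacketType.RESPONSE_BIG.value.to_bytes(1, _BYTE_ORDER)
            + message_id.to_bytes(2, _BYTE_ORDER)
            + len(response).to_bytes(4, _BYTE_ORDER)
            + response
        )
    # Small responses: type (1b), msgid (2b), len (2b), data.
    return (
        _PacketType.RESPONSE.value.to_bytes(1, _BYTE_ORDER)
        + message_id.to_bytes(2, _BYTE_ORDER)
        + len(response).to_bytes(2, _BYTE_ORDER)
        + response
    )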

321
dist/ba_data/python/efro/terminal.py vendored Normal file
View file

@ -0,0 +1,321 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality related to terminal IO."""
from __future__ import annotations
import sys
import os
from enum import Enum, unique
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Any, ClassVar
@unique
class TerminalColor(Enum):
"""Color codes for printing to terminals.
Generally the Clr class should be used when incorporating color into
terminal output, as it handles non-color-supporting terminals/etc.
"""
# Styles
RESET = '\033[0m'
BOLD = '\033[1m'
UNDERLINE = '\033[4m'
INVERSE = '\033[7m'
# Normal foreground colors
BLACK = '\033[30m'
RED = '\033[31m'
GREEN = '\033[32m'
YELLOW = '\033[33m'
BLUE = '\033[34m'
MAGENTA = '\033[35m'
CYAN = '\033[36m'
WHITE = '\033[37m'
# Normal background colors.
BG_BLACK = '\033[40m'
BG_RED = '\033[41m'
BG_GREEN = '\033[42m'
BG_YELLOW = '\033[43m'
BG_BLUE = '\033[44m'
BG_MAGENTA = '\033[45m'
BG_CYAN = '\033[46m'
BG_WHITE = '\033[47m'
# Strong foreground colors
STRONG_BLACK = '\033[90m'
STRONG_RED = '\033[91m'
STRONG_GREEN = '\033[92m'
STRONG_YELLOW = '\033[93m'
STRONG_BLUE = '\033[94m'
STRONG_MAGENTA = '\033[95m'
STRONG_CYAN = '\033[96m'
STRONG_WHITE = '\033[97m'
# Strong background colors.
STRONG_BG_BLACK = '\033[100m'
STRONG_BG_RED = '\033[101m'
STRONG_BG_GREEN = '\033[102m'
STRONG_BG_YELLOW = '\033[103m'
STRONG_BG_BLUE = '\033[104m'
STRONG_BG_MAGENTA = '\033[105m'
STRONG_BG_CYAN = '\033[106m'
STRONG_BG_WHITE = '\033[107m'
def _default_color_enabled() -> bool:
"""Return whether we should enable ANSI color codes by default."""
import platform
# If we're not attached to a terminal, go with no-color.
if not sys.__stdout__.isatty():
return False
# Another common way to say the terminal can't do fancy stuff like color:
if os.environ.get('TERM') == 'dumb':
return False
# On windows, try to enable ANSI color mode.
if platform.system() == 'Windows':
return _windows_enable_color()
# We seem to be a terminal with color support; let's do it!
return True
# noinspection PyPep8Naming
def _windows_enable_color() -> bool:
"""Attempt to enable ANSI color on windows terminal; return success."""
# pylint: disable=invalid-name, import-error, undefined-variable
# Pulled from: https://bugs.python.org/issue30075
import msvcrt
import ctypes
from ctypes import wintypes
kernel32 = ctypes.WinDLL('kernel32', use_last_error=True) # type: ignore
ERROR_INVALID_PARAMETER = 0x0057
ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x0004
def _check_bool(result: Any, _func: Any, args: Any) -> Any:
if not result:
raise ctypes.WinError(ctypes.get_last_error()) # type: ignore
return args
LPDWORD = ctypes.POINTER(wintypes.DWORD)
kernel32.GetConsoleMode.errcheck = _check_bool
kernel32.GetConsoleMode.argtypes = (wintypes.HANDLE, LPDWORD)
kernel32.SetConsoleMode.errcheck = _check_bool
kernel32.SetConsoleMode.argtypes = (wintypes.HANDLE, wintypes.DWORD)
def set_conout_mode(new_mode: int, mask: int = 0xFFFFFFFF) -> int:
# don't assume StandardOutput is a console.
# open CONOUT$ instead
fdout = os.open('CONOUT$', os.O_RDWR)
try:
hout = msvcrt.get_osfhandle(fdout) # type: ignore
# pylint: disable=useless-suppression
# pylint: disable=no-value-for-parameter
old_mode = wintypes.DWORD()
# pylint: enable=useless-suppression
kernel32.GetConsoleMode(hout, ctypes.byref(old_mode))
mode = (new_mode & mask) | (old_mode.value & ~mask)
kernel32.SetConsoleMode(hout, mode)
return old_mode.value
finally:
os.close(fdout)
def enable_vt_mode() -> int:
mode = mask = ENABLE_VIRTUAL_TERMINAL_PROCESSING
try:
return set_conout_mode(mode, mask)
except WindowsError as exc: # type: ignore
if exc.winerror == ERROR_INVALID_PARAMETER:
raise NotImplementedError from exc
raise
try:
enable_vt_mode()
return True
except NotImplementedError:
return False
class ClrBase:
"""Base class for color convenience class."""
RST: ClassVar[str]
BLD: ClassVar[str]
UND: ClassVar[str]
INV: ClassVar[str]
# Normal foreground colors
BLK: ClassVar[str]
RED: ClassVar[str]
GRN: ClassVar[str]
YLW: ClassVar[str]
BLU: ClassVar[str]
MAG: ClassVar[str]
CYN: ClassVar[str]
WHT: ClassVar[str]
# Normal background colors.
BBLK: ClassVar[str]
BRED: ClassVar[str]
BGRN: ClassVar[str]
BYLW: ClassVar[str]
BBLU: ClassVar[str]
BMAG: ClassVar[str]
BCYN: ClassVar[str]
BWHT: ClassVar[str]
# Strong foreground colors
SBLK: ClassVar[str]
SRED: ClassVar[str]
SGRN: ClassVar[str]
SYLW: ClassVar[str]
SBLU: ClassVar[str]
SMAG: ClassVar[str]
SCYN: ClassVar[str]
SWHT: ClassVar[str]
# Strong background colors.
SBBLK: ClassVar[str]
SBRED: ClassVar[str]
SBGRN: ClassVar[str]
SBYLW: ClassVar[str]
SBBLU: ClassVar[str]
SBMAG: ClassVar[str]
SBCYN: ClassVar[str]
SBWHT: ClassVar[str]
class ClrAlways(ClrBase):
"""Convenience class for color terminal output.
This version has colors always enabled. Generally you should use Clr which
points to the correct enabled/disabled class depending on the environment.
"""
color_enabled = True
# Styles
RST = TerminalColor.RESET.value
BLD = TerminalColor.BOLD.value
UND = TerminalColor.UNDERLINE.value
INV = TerminalColor.INVERSE.value
# Normal foreground colors
BLK = TerminalColor.BLACK.value
RED = TerminalColor.RED.value
GRN = TerminalColor.GREEN.value
YLW = TerminalColor.YELLOW.value
BLU = TerminalColor.BLUE.value
MAG = TerminalColor.MAGENTA.value
CYN = TerminalColor.CYAN.value
WHT = TerminalColor.WHITE.value
# Normal background colors.
BBLK = TerminalColor.BG_BLACK.value
BRED = TerminalColor.BG_RED.value
BGRN = TerminalColor.BG_GREEN.value
BYLW = TerminalColor.BG_YELLOW.value
BBLU = TerminalColor.BG_BLUE.value
BMAG = TerminalColor.BG_MAGENTA.value
BCYN = TerminalColor.BG_CYAN.value
BWHT = TerminalColor.BG_WHITE.value
# Strong foreground colors
SBLK = TerminalColor.STRONG_BLACK.value
SRED = TerminalColor.STRONG_RED.value
SGRN = TerminalColor.STRONG_GREEN.value
SYLW = TerminalColor.STRONG_YELLOW.value
SBLU = TerminalColor.STRONG_BLUE.value
SMAG = TerminalColor.STRONG_MAGENTA.value
SCYN = TerminalColor.STRONG_CYAN.value
SWHT = TerminalColor.STRONG_WHITE.value
# Strong background colors.
SBBLK = TerminalColor.STRONG_BG_BLACK.value
SBRED = TerminalColor.STRONG_BG_RED.value
SBGRN = TerminalColor.STRONG_BG_GREEN.value
SBYLW = TerminalColor.STRONG_BG_YELLOW.value
SBBLU = TerminalColor.STRONG_BG_BLUE.value
SBMAG = TerminalColor.STRONG_BG_MAGENTA.value
SBCYN = TerminalColor.STRONG_BG_CYAN.value
SBWHT = TerminalColor.STRONG_BG_WHITE.value
class ClrNever(ClrBase):
"""Convenience class for color terminal output.
This version has colors disabled. Generally you should use Clr which
points to the correct enabled/disabled class depending on the environment.
"""
color_enabled = False
# Styles
RST = ''
BLD = ''
UND = ''
INV = ''
# Normal foreground colors
BLK = ''
RED = ''
GRN = ''
YLW = ''
BLU = ''
MAG = ''
CYN = ''
WHT = ''
# Normal background colors.
BBLK = ''
BRED = ''
BGRN = ''
BYLW = ''
BBLU = ''
BMAG = ''
BCYN = ''
BWHT = ''
# Strong foreground colors
SBLK = ''
SRED = ''
SGRN = ''
SYLW = ''
SBLU = ''
SMAG = ''
SCYN = ''
SWHT = ''
# Strong background colors.
SBBLK = ''
SBRED = ''
SBGRN = ''
SBYLW = ''
SBBLU = ''
SBMAG = ''
SBCYN = ''
SBWHT = ''
_envval = os.environ.get('EFRO_TERMCOLORS')
_color_enabled: bool = (
True
if _envval == '1'
else False
if _envval == '0'
else _default_color_enabled()
)
Clr: type[ClrBase]
if _color_enabled:
Clr = ClrAlways
else:
Clr = ClrNever
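# A minimal usage sketch (not part of the original module): since Clr resolves
# to ClrAlways or ClrNever at import time, callers can embed its attributes in
# strings unconditionally and the codes simply become empty strings when color
# is disabled. The function below is hypothetical and purely illustrative.
def _example_print_status(label: str, ok: bool) -> None:
    """Print a pass/fail line, colorized when the terminal supports it."""
    if ok:
        print(f'{Clr.GRN}PASS{Clr.RST} {label}')
    else:
        print(f'{Clr.SRED}{Clr.BLD}FAIL{Clr.RST} {label}')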

735
dist/ba_data/python/efro/util.py vendored Normal file
View file

@ -0,0 +1,735 @@
# Released under the MIT License. See LICENSE for details.
#
"""Small handy bits of functionality."""
from __future__ import annotations
import os
import time
import weakref
import datetime
import functools
from enum import Enum
from typing import TYPE_CHECKING, cast, TypeVar, Generic
_pytz_utc: Any
# We don't *require* pytz, but we want to support it for tzinfos if available.
try:
import pytz
_pytz_utc = pytz.utc
except ModuleNotFoundError:
_pytz_utc = None # pylint: disable=invalid-name
if TYPE_CHECKING:
import asyncio
from efro.call import Call as Call # 'as Call' so we re-export.
from typing import Any, Callable, NoReturn
T = TypeVar('T')
ValT = TypeVar('ValT')
ArgT = TypeVar('ArgT')
SelfT = TypeVar('SelfT')
RetT = TypeVar('RetT')
EnumT = TypeVar('EnumT', bound=Enum)
class _EmptyObj:
pass
# TODO: kill this and just use efro.call.tpartial
if TYPE_CHECKING:
Call = Call
else:
Call = functools.partial
def enum_by_value(cls: type[EnumT], value: Any) -> EnumT:
"""Create an enum from a value.
This is basically the same as doing 'obj = EnumType(value)' except
that it works around an issue where a reference loop is created
if an exception is thrown due to an invalid value. Since we disable
the cyclic garbage collector for most of the time, such loops can lead
to our objects sticking around longer than we want.
This issue has been submitted to Python as a bug so hopefully we can
remove this eventually if it gets fixed: https://bugs.python.org/issue42248
UPDATE: This has been fixed as of later 3.8 builds, so we can kill this
off once we are 3.9+ across the board.
"""
# Note: we don't recreate *ALL* the functionality of the Enum constructor
# such as the _missing_ hook; but this should cover our basic needs.
value2member_map = getattr(cls, '_value2member_map_')
assert value2member_map is not None
try:
out = value2member_map[value]
assert isinstance(out, cls)
return out
except KeyError:
# pylint: disable=consider-using-f-string
raise ValueError(
'%r is not a valid %s' % (value, cls.__name__)
) from None
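# A small illustrative sketch (not part of the original module) of
# enum_by_value(); the _ExampleColor enum is hypothetical.
class _ExampleColor(Enum):
    RED = 'red'
    GREEN = 'green'
def _example_enum_by_value_usage() -> None:
    """Show enum_by_value() behaving like the normal Enum constructor."""
    assert enum_by_value(_ExampleColor, 'red') is _ExampleColor.RED
    try:
        enum_by_value(_ExampleColor, 'purple')
    except ValueError:
        # Invalid values raise ValueError just as _ExampleColor('purple')
        # would, but without creating a reference loop via the exception.
        pass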
def check_utc(value: datetime.datetime) -> None:
"""Ensure a datetime value is timezone-aware utc."""
if value.tzinfo is not datetime.timezone.utc and (
_pytz_utc is None or value.tzinfo is not _pytz_utc
):
raise ValueError(
'datetime value does not have timezone set as'
' datetime.timezone.utc'
)
def utc_now() -> datetime.datetime:
"""Get offset-aware current utc time.
This should be used for all datetimes getting sent over the network,
used with the entity system, etc.
(datetime.utcnow() gives a utc time value, but it is not timezone-aware
which makes it less safe to use)
"""
return datetime.datetime.now(datetime.timezone.utc)
def utc_today() -> datetime.datetime:
"""Get offset-aware midnight in the utc time zone."""
now = datetime.datetime.now(datetime.timezone.utc)
return datetime.datetime(
year=now.year, month=now.month, day=now.day, tzinfo=now.tzinfo
)
def utc_this_hour() -> datetime.datetime:
"""Get offset-aware beginning of the current hour in the utc time zone."""
now = datetime.datetime.now(datetime.timezone.utc)
return datetime.datetime(
year=now.year,
month=now.month,
day=now.day,
hour=now.hour,
tzinfo=now.tzinfo,
)
def utc_this_minute() -> datetime.datetime:
"""Get offset-aware beginning of current minute in the utc time zone."""
now = datetime.datetime.now(datetime.timezone.utc)
return datetime.datetime(
year=now.year,
month=now.month,
day=now.day,
hour=now.hour,
minute=now.minute,
tzinfo=now.tzinfo,
)
def empty_weakref(objtype: type[T]) -> weakref.ref[T]:
"""Return an invalidated weak-reference for the specified type."""
# At runtime, all weakrefs are the same; our type arg is just
# for the static type checker.
del objtype # Unused.
# Just create an object and let it die. Is there a cleaner way to do this?
return weakref.ref(_EmptyObj()) # type: ignore
def data_size_str(bytecount: int) -> str:
"""Given a size in bytes, returns a short human readable string.
    This should be 6 or fewer chars for most sane file sizes.
"""
# pylint: disable=too-many-return-statements
if bytecount <= 999:
return f'{bytecount} B'
kbytecount = bytecount / 1024
if round(kbytecount, 1) < 10.0:
return f'{kbytecount:.1f} KB'
if round(kbytecount, 0) < 999:
return f'{kbytecount:.0f} KB'
mbytecount = bytecount / (1024 * 1024)
if round(mbytecount, 1) < 10.0:
return f'{mbytecount:.1f} MB'
if round(mbytecount, 0) < 999:
return f'{mbytecount:.0f} MB'
gbytecount = bytecount / (1024 * 1024 * 1024)
if round(gbytecount, 1) < 10.0:
        return f'{gbytecount:.1f} GB'
return f'{gbytecount:.0f} GB'
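# Illustrative examples (not part of the original module) of data_size_str()
# output at a few representative sizes.
def _example_data_size_str_usage() -> None:
    """Demonstrate typical data_size_str() results."""
    assert data_size_str(500) == '500 B'
    assert data_size_str(1536) == '1.5 KB'  # 1.5 * 1024
    assert data_size_str(5 * 1024 * 1024) == '5.0 MB'
    assert data_size_str(5 * 1024 ** 3) == '5.0 GB'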
class DirtyBit:
"""Manages whether a thing is dirty and regulates attempts to clean it.
To use, simply set the 'dirty' value on this object to True when some
action is needed, and then check the 'should_update' value to regulate
when attempts to clean it should be made. Set 'dirty' back to False after
a successful update.
If 'use_lock' is True, an asyncio Lock will be created and incorporated
into update attempts to prevent simultaneous updates (should_update will
    only return True when the lock is unlocked). Note that it is up to the user
to lock/unlock the lock during the actual update attempt.
If a value is passed for 'auto_dirty_seconds', the dirtybit will flip
itself back to dirty after being clean for the given amount of time.
'min_update_interval' can be used to enforce a minimum update
interval even when updates are successful (retry_interval only applies
when updates fail)
"""
def __init__(
self,
dirty: bool = False,
retry_interval: float = 5.0,
use_lock: bool = False,
auto_dirty_seconds: float | None = None,
min_update_interval: float | None = None,
):
curtime = time.time()
self._retry_interval = retry_interval
self._auto_dirty_seconds = auto_dirty_seconds
self._min_update_interval = min_update_interval
self._dirty = dirty
self._next_update_time: float | None = curtime if dirty else None
self._last_update_time: float | None = None
self._next_auto_dirty_time: float | None = (
(curtime + self._auto_dirty_seconds)
if (not dirty and self._auto_dirty_seconds is not None)
else None
)
self._use_lock = use_lock
self.lock: asyncio.Lock
if self._use_lock:
import asyncio
self.lock = asyncio.Lock()
@property
def dirty(self) -> bool:
"""Whether the target is currently dirty.
This should be set to False once an update is successful.
"""
return self._dirty
@dirty.setter
def dirty(self, value: bool) -> None:
# If we're freshly clean, set our next auto-dirty time (if we have
# one).
if self._dirty and not value and self._auto_dirty_seconds is not None:
self._next_auto_dirty_time = time.time() + self._auto_dirty_seconds
# If we're freshly dirty, schedule an immediate update.
if not self._dirty and value:
self._next_update_time = time.time()
# If they want to enforce a minimum update interval,
# push out the next update time if it hasn't been long enough.
if (
self._min_update_interval is not None
and self._last_update_time is not None
):
self._next_update_time = max(
self._next_update_time,
self._last_update_time + self._min_update_interval,
)
self._dirty = value
@property
def should_update(self) -> bool:
"""Whether an attempt should be made to clean the target now.
Always returns False if the target is not dirty.
Takes into account the amount of time passed since the target
was marked dirty or since should_update last returned True.
"""
curtime = time.time()
# Auto-dirty ourself if we're into that.
if (
self._next_auto_dirty_time is not None
and curtime > self._next_auto_dirty_time
):
self.dirty = True
self._next_auto_dirty_time = None
if not self._dirty:
return False
if self._use_lock and self.lock.locked():
return False
assert self._next_update_time is not None
if curtime > self._next_update_time:
self._next_update_time = curtime + self._retry_interval
self._last_update_time = curtime
return True
return False
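# A minimal usage sketch (not part of the original module) of the DirtyBit
# pattern described in its docstring; _example_push_state() is a hypothetical
# stand-in for an actual state-sync operation.
def _example_push_state() -> None:
    """Hypothetical stand-in for syncing state to some external target."""
def _example_dirtybit_usage() -> None:
    """Poll a DirtyBit and attempt updates only when it asks for them."""
    state_dirty = DirtyBit(dirty=True, retry_interval=5.0)
    for _ in range(3):
        if state_dirty.should_update:
            try:
                _example_push_state()
                state_dirty.dirty = False  # Success; mark clean.
            except OSError:
                pass  # Failed; should_update goes True again after 5 seconds.
        time.sleep(1.0)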
class DispatchMethodWrapper(Generic[ArgT, RetT]):
"""Type-aware standin for the dispatch func returned by dispatchmethod."""
def __call__(self, arg: ArgT) -> RetT:
raise RuntimeError('Should not get here')
@staticmethod
def register(
func: Callable[[Any, Any], RetT]
) -> Callable[[Any, Any], RetT]:
"""Register a new dispatch handler for this dispatch-method."""
raise RuntimeError('Should not get here')
registry: dict[Any, Callable]
# noinspection PyProtectedMember,PyTypeHints
def dispatchmethod(
func: Callable[[Any, ArgT], RetT]
) -> DispatchMethodWrapper[ArgT, RetT]:
"""A variation of functools.singledispatch for methods.
    Note: as of Python 3.8 there is now functools.singledispatchmethod,
but it currently (as of Jan 2021) is not type-aware (at least in mypy),
which gives us a reason to keep this one around for now.
"""
from functools import singledispatch, update_wrapper
origwrapper: Any = singledispatch(func)
# Pull this out so hopefully origwrapper can die,
# otherwise we reference origwrapper in our wrapper.
dispatch = origwrapper.dispatch
# All we do here is recreate the end of functools.singledispatch
# where it returns a wrapper except instead of the wrapper using the
# first arg to the function ours uses the second (to skip 'self').
# This was made against Python 3.7; we should probably check up on
# this in later versions in case anything has changed.
# (or hopefully they'll add this functionality to their version)
# NOTE: sounds like we can use functools singledispatchmethod in 3.8
def wrapper(*args: Any, **kw: Any) -> Any:
if not args or len(args) < 2:
raise TypeError(
f'{funcname} requires at least ' '2 positional arguments'
)
return dispatch(args[1].__class__)(*args, **kw)
funcname = getattr(func, '__name__', 'dispatchmethod method')
wrapper.register = origwrapper.register # type: ignore
wrapper.dispatch = dispatch # type: ignore
wrapper.registry = origwrapper.registry # type: ignore
# pylint: disable=protected-access
wrapper._clear_cache = origwrapper._clear_cache # type: ignore
update_wrapper(wrapper, func)
# pylint: enable=protected-access
return cast(DispatchMethodWrapper, wrapper)
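# An illustrative sketch (not part of the original module) of dispatchmethod();
# the _ExampleRenderer class is hypothetical. As with functools.singledispatch,
# register() infers the dispatch type from the first annotated argument (the
# 'value' annotation here, since 'self' is unannotated).
class _ExampleRenderer:
    """Render values differently depending on their type."""
    @dispatchmethod
    def render(self, value: object) -> str:
        return f'<{value!r}>'  # Fallback for unregistered types.
    @render.register
    def _render_int(self, value: int) -> str:
        return f'int:{value}'
    @render.register
    def _render_str(self, value: str) -> str:
        return f'str:{value}'
# _ExampleRenderer().render(5) -> 'int:5'; .render([]) -> '<[]>'.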
def valuedispatch(call: Callable[[ValT], RetT]) -> ValueDispatcher[ValT, RetT]:
"""Decorator for functions to allow dispatching based on a value.
This differs from functools.singledispatch in that it dispatches based
on the value of an argument, not based on its type.
The 'register' method of a value-dispatch function can be used
to assign new functions to handle particular values.
Unhandled values wind up in the original dispatch function."""
return ValueDispatcher(call)
class ValueDispatcher(Generic[ValT, RetT]):
"""Used by the valuedispatch decorator"""
def __init__(self, call: Callable[[ValT], RetT]) -> None:
self._base_call = call
self._handlers: dict[ValT, Callable[[], RetT]] = {}
def __call__(self, value: ValT) -> RetT:
handler = self._handlers.get(value)
if handler is not None:
return handler()
return self._base_call(value)
def _add_handler(
self, value: ValT, call: Callable[[], RetT]
) -> Callable[[], RetT]:
if value in self._handlers:
raise RuntimeError(f'Duplicate handlers added for {value}')
self._handlers[value] = call
return call
def register(
self, value: ValT
) -> Callable[[Callable[[], RetT]], Callable[[], RetT]]:
"""Add a handler to the dispatcher."""
from functools import partial
return partial(self._add_handler, value)
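# A small illustrative sketch (not part of the original module) of the
# valuedispatch decorator; _ExampleMode and the handler functions are
# hypothetical.
class _ExampleMode(Enum):
    NORMAL = 'normal'
    TURBO = 'turbo'
    OFF = 'off'
@valuedispatch
def _example_mode_description(mode: _ExampleMode) -> str:
    # Unregistered values wind up here in the original function.
    return 'unknown mode'
@_example_mode_description.register(_ExampleMode.NORMAL)
def _describe_normal() -> str:
    return 'running normally'
@_example_mode_description.register(_ExampleMode.TURBO)
def _describe_turbo() -> str:
    return 'running fast'
# _example_mode_description(_ExampleMode.TURBO) -> 'running fast'
# _example_mode_description(_ExampleMode.OFF) -> 'unknown mode'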
def valuedispatch1arg(
call: Callable[[ValT, ArgT], RetT]
) -> ValueDispatcher1Arg[ValT, ArgT, RetT]:
"""Like valuedispatch but for functions taking an extra argument."""
return ValueDispatcher1Arg(call)
class ValueDispatcher1Arg(Generic[ValT, ArgT, RetT]):
"""Used by the valuedispatch1arg decorator"""
def __init__(self, call: Callable[[ValT, ArgT], RetT]) -> None:
self._base_call = call
self._handlers: dict[ValT, Callable[[ArgT], RetT]] = {}
def __call__(self, value: ValT, arg: ArgT) -> RetT:
handler = self._handlers.get(value)
if handler is not None:
return handler(arg)
return self._base_call(value, arg)
def _add_handler(
self, value: ValT, call: Callable[[ArgT], RetT]
) -> Callable[[ArgT], RetT]:
if value in self._handlers:
raise RuntimeError(f'Duplicate handlers added for {value}')
self._handlers[value] = call
return call
def register(
self, value: ValT
) -> Callable[[Callable[[ArgT], RetT]], Callable[[ArgT], RetT]]:
"""Add a handler to the dispatcher."""
from functools import partial
return partial(self._add_handler, value)
if TYPE_CHECKING:
class ValueDispatcherMethod(Generic[ValT, RetT]):
"""Used by the valuedispatchmethod decorator."""
def __call__(self, value: ValT) -> RetT:
...
def register(
self, value: ValT
) -> Callable[[Callable[[SelfT], RetT]], Callable[[SelfT], RetT]]:
"""Add a handler to the dispatcher."""
...
def valuedispatchmethod(
call: Callable[[SelfT, ValT], RetT]
) -> ValueDispatcherMethod[ValT, RetT]:
"""Like valuedispatch but works with methods instead of functions."""
# NOTE: It seems that to wrap a method with a decorator and have self
# dispatching do the right thing, we must return a function and not
# an executable object. So for this version we store our data here
# in the function call dict and simply return a call.
_base_call = call
_handlers: dict[ValT, Callable[[SelfT], RetT]] = {}
def _add_handler(value: ValT, addcall: Callable[[SelfT], RetT]) -> None:
if value in _handlers:
raise RuntimeError(f'Duplicate handlers added for {value}')
_handlers[value] = addcall
def _register(value: ValT) -> Callable[[Callable[[SelfT], RetT]], None]:
from functools import partial
return partial(_add_handler, value)
def _call_wrapper(self: SelfT, value: ValT) -> RetT:
handler = _handlers.get(value)
if handler is not None:
return handler(self)
return _base_call(self, value)
# We still want to use our returned object to register handlers, but we're
# actually just returning a function. So manually stuff the call onto it.
setattr(_call_wrapper, 'register', _register)
# To the type checker's eyes we return a ValueDispatchMethod instance;
# this lets it know about our register func and type-check its usage.
# In reality we just return a raw function call (for reasons listed above).
# pylint: disable=undefined-variable, no-else-return
if TYPE_CHECKING:
return ValueDispatcherMethod[ValT, RetT]()
else:
return _call_wrapper
def make_hash(obj: Any) -> int:
"""Makes a hash from a dictionary, list, tuple or set to any level,
that contains only other hashable types (including any lists, tuples,
sets, and dictionaries).
Note that this uses Python's hash() function internally so collisions/etc.
may be more common than with fancy cryptographic hashes.
Also be aware that Python's hash() output varies across processes, so
this should only be used for values that will remain in a single process.
"""
import copy
if isinstance(obj, (set, tuple, list)):
return hash(tuple(make_hash(e) for e in obj))
if not isinstance(obj, dict):
return hash(obj)
new_obj = copy.deepcopy(obj)
for k, v in new_obj.items():
new_obj[k] = make_hash(v)
    # NOTE: sorted() works correctly here because it only ever compares
    # the unique first elements of each pair (i.e. the dict keys).
return hash(tuple(frozenset(sorted(new_obj.items()))))
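# Illustrative usage (not part of the original module): make_hash() lets
# nested containers that are not normally hashable serve as cache keys
# within a single process.
def _example_make_hash_usage() -> None:
    """Hash a nested structure and use it as a simple cache key."""
    config = {'name': 'test', 'tags': ['a', 'b'], 'limits': {'max': 10}}
    cache: dict[int, str] = {}
    cache[make_hash(config)] = 'expensive result'
    # An equal (but distinct) structure hashes to the same value.
    assert make_hash(dict(config)) in cache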
def asserttype(obj: Any, typ: type[T]) -> T:
"""Return an object typed as a given type.
Assert is used to check its actual type, so only use this when
failures are not expected. Otherwise use checktype.
"""
assert isinstance(typ, type), 'only actual types accepted'
assert isinstance(obj, typ)
return obj
def asserttype_o(obj: Any, typ: type[T]) -> T | None:
"""Return an object typed as a given optional type.
Assert is used to check its actual type, so only use this when
failures are not expected. Otherwise use checktype.
"""
assert isinstance(typ, type), 'only actual types accepted'
assert isinstance(obj, (typ, type(None)))
return obj
def checktype(obj: Any, typ: type[T]) -> T:
"""Return an object typed as a given type.
Always checks the type at runtime with isinstance and throws a TypeError
on failure. Use asserttype for more efficient (but less safe) equivalent.
"""
assert isinstance(typ, type), 'only actual types accepted'
if not isinstance(obj, typ):
raise TypeError(f'Expected a {typ}; got a {type(obj)}.')
return obj
def checktype_o(obj: Any, typ: type[T]) -> T | None:
"""Return an object typed as a given optional type.
Always checks the type at runtime with isinstance and throws a TypeError
on failure. Use asserttype for more efficient (but less safe) equivalent.
"""
assert isinstance(typ, type), 'only actual types accepted'
if not isinstance(obj, (typ, type(None))):
raise TypeError(f'Expected a {typ} or None; got a {type(obj)}.')
return obj
def warntype(obj: Any, typ: type[T]) -> T:
"""Return an object typed as a given type.
Always checks the type at runtime and simply logs a warning if it is
not what is expected.
"""
assert isinstance(typ, type), 'only actual types accepted'
if not isinstance(obj, typ):
import logging
logging.warning('warntype: expected a %s, got a %s', typ, type(obj))
return obj # type: ignore
def warntype_o(obj: Any, typ: type[T]) -> T | None:
"""Return an object typed as a given type.
Always checks the type at runtime and simply logs a warning if it is
not what is expected.
"""
assert isinstance(typ, type), 'only actual types accepted'
if not isinstance(obj, (typ, type(None))):
import logging
logging.warning(
'warntype: expected a %s or None, got a %s', typ, type(obj)
)
return obj # type: ignore
def assert_non_optional(obj: T | None) -> T:
"""Return an object with Optional typing removed.
Assert is used to check its actual type, so only use this when
failures are not expected. Use check_non_optional otherwise.
"""
assert obj is not None
return obj
def check_non_optional(obj: T | None) -> T:
"""Return an object with Optional typing removed.
Always checks the actual type and throws a TypeError on failure.
Use assert_non_optional for a more efficient (but less safe) equivalent.
"""
if obj is None:
raise TypeError('Got None value in check_non_optional.')
return obj
def smoothstep(edge0: float, edge1: float, x: float) -> float:
"""A smooth transition function.
Returns a value that smoothly moves from 0 to 1 as we go between edges.
Values outside of the range return 0 or 1.
"""
y = min(1.0, max(0.0, (x - edge0) / (edge1 - edge0)))
return y * y * (3.0 - 2.0 * y)
def linearstep(edge0: float, edge1: float, x: float) -> float:
"""A linear transition function.
Returns a value that linearly moves from 0 to 1 as we go between edges.
Values outside of the range return 0 or 1.
"""
return max(0.0, min(1.0, (x - edge0) / (edge1 - edge0)))
def _compact_id(num: int, chars: str) -> str:
if num < 0:
raise ValueError('Negative integers not allowed.')
# Chars must be in sorted order for sorting to work correctly
# on our output.
assert ''.join(sorted(list(chars))) == chars
base = len(chars)
out = ''
while num:
out += chars[num % base]
num //= base
return out[::-1] or '0'
def human_readable_compact_id(num: int) -> str:
"""Given a positive int, return a compact string representation for it.
Handy for visualizing unique numeric ids using as few as possible chars.
This representation uses only lowercase letters and numbers (minus the
following letters for readability):
's' is excluded due to similarity to '5'.
'l' is excluded due to similarity to '1'.
'i' is excluded due to similarity to '1'.
'o' is excluded due to similarity to '0'.
'z' is excluded due to similarity to '2'.
    Therefore for n chars this can store 31^n values.
When reading human input consisting of these IDs, it may be desirable
to map the disallowed chars to their corresponding allowed ones
('o' -> '0', etc).
Sort order for these ids is the same as the original numbers.
If more compactness is desired at the expense of readability,
use compact_id() instead.
"""
return _compact_id(num, '0123456789abcdefghjkmnpqrtuvwxy')
def compact_id(num: int) -> str:
"""Given a positive int, return a compact string representation for it.
Handy for visualizing unique numeric ids using as few as possible chars.
This version is more compact than human_readable_compact_id() but less
friendly to humans due to using both capital and lowercase letters,
both 'O' and '0', etc.
Therefore for n chars this can store values of 62^n.
Sort order for these ids is the same as the original numbers.
"""
return _compact_id(
num, '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
)
# NOTE: Even though this is available as part of typing_extensions, keeping
# it in here for now so we don't require typing_extensions as a dependency.
# Once 3.11 rolls around we can kill this and use typing.assert_never.
def assert_never(value: NoReturn) -> NoReturn:
"""Trick for checking exhaustive handling of Enums, etc.
See https://github.com/python/typing/issues/735
"""
assert False, f'Unhandled value: {value} ({type(value).__name__})'
def unchanging_hostname() -> str:
"""Return an unchanging name for the local device.
Similar to the `hostname` call (or os.uname().nodename in Python)
except attempts to give a name that doesn't change depending on
network conditions. (A Mac will tend to go from Foo to Foo.local,
Foo.lan etc. throughout its various adventures)
"""
import platform
import subprocess
# On Mac, this should give the computer name assigned in System Prefs.
if platform.system() == 'Darwin':
return (
subprocess.run(
['scutil', '--get', 'ComputerName'],
check=True,
capture_output=True,
)
.stdout.decode()
.strip()
.replace(' ', '-')
)
return os.uname().nodename
def set_canonical_module(
module_globals: dict[str, Any], names: list[str]
) -> None:
"""Override any __module__ attrs on passed classes/etc.
This allows classes to present themselves using clean paths such as
mymodule.MyClass instead of possibly ugly internal ones such as
mymodule._internal._stuff.MyClass.
"""
modulename = module_globals.get('__name__')
if not isinstance(modulename, str):
raise RuntimeError('Unable to get module name.')
for name in names:
obj = module_globals[name]
existing = getattr(obj, '__module__', None)
try:
if existing is not None and existing != modulename:
obj.__module__ = modulename
except Exception:
import logging
logging.warning(
'set_canonical_module: unable to change __module__'
" from '%s' to '%s' on %s object at '%s'.",
existing,
modulename,
type(obj),
name,
)
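# A minimal usage sketch (not part of the original module) showing the effect
# of set_canonical_module(); the 'mypackage' name and _ExampleWidget class are
# hypothetical. A package __init__.py re-exporting classes from private
# submodules would typically call this with its own globals().
def _example_set_canonical_module_usage() -> None:
    """Show __module__ being rewritten to a module's public name."""
    class _ExampleWidget:
        pass
    fake_globals = {'__name__': 'mypackage', '_ExampleWidget': _ExampleWidget}
    set_canonical_module(fake_globals, ['_ExampleWidget'])
    assert _ExampleWidget.__module__ == 'mypackage'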