mirror of https://github.com/hypervortex/VH-Bombsquad-Modded-Server-Files
synced 2025-11-07 17:36:08 +00:00

commit 44d606cce7 (parent bc49523c99): Initial commit
1929 changed files with 612166 additions and 0 deletions
7  dist/ba_data/python/efro/__init__.py  vendored  Normal file
@@ -0,0 +1,7 @@
# Released under the MIT License. See LICENSE for details.
#
"""Common bits of functionality shared between all efro projects.

Things in here should be hardened, highly type-safe, and well-covered by unit
tests since they are widely used in live client and server code.
"""
BIN  dist/ba_data/python/efro/__pycache__/__init__.cpython-310.opt-1.pyc  vendored  Normal file
Binary file not shown.
BIN  dist/ba_data/python/efro/__pycache__/call.cpython-310.opt-1.pyc  vendored  Normal file
Binary file not shown.
BIN  dist/ba_data/python/efro/__pycache__/error.cpython-310.opt-1.pyc  vendored  Normal file
Binary file not shown.
BIN  dist/ba_data/python/efro/__pycache__/log.cpython-310.opt-1.pyc  vendored  Normal file
Binary file not shown.
BIN  dist/ba_data/python/efro/__pycache__/rpc.cpython-310.opt-1.pyc  vendored  Normal file
Binary file not shown.
BIN  dist/ba_data/python/efro/__pycache__/terminal.cpython-310.opt-1.pyc  vendored  Normal file
Binary file not shown.
BIN  dist/ba_data/python/efro/__pycache__/util.cpython-310.opt-1.pyc  vendored  Normal file
Binary file not shown.
366  dist/ba_data/python/efro/call.py  vendored  Normal file
@@ -0,0 +1,366 @@
# Released under the MIT License. See LICENSE for details.
#
"""Call related functionality shared between all efro components."""

from __future__ import annotations

from typing import TYPE_CHECKING, TypeVar, Generic, Callable, cast
import functools

if TYPE_CHECKING:
    from typing import Any, overload

CT = TypeVar('CT', bound=Callable)


class _CallbackCall(Generic[CT]):
    """Descriptor for exposing a call with a type defined by a TypeVar."""

    def __get__(self, obj: Any, type_in: Any = None) -> CT:
        return cast(CT, None)


class CallbackSet(Generic[CT]):
    """Wrangles callbacks for a particular event in a type-safe manner."""

    # In the type-checker's eyes, our 'run' attr is a _CallbackCall which
    # returns a callable with the type we were created with. This lets us
    # type-check our run calls. (Is there another way to expose a function
    # with a signature defined by a generic?..)
    # At runtime, run() simply passes its args verbatim to its registered
    # callbacks; no types are checked.
    if TYPE_CHECKING:
        run: _CallbackCall[CT] = _CallbackCall()
    else:

        def run(self, *args, **keywds):
            """Run all callbacks."""
            print('HELLO FROM RUN', *args, **keywds)

    def __init__(self) -> None:
        print('CallbackSet()')

    def __del__(self) -> None:
        print('~CallbackSet()')

    def add(self, call: CT) -> None:
        """Add a callback to be run."""
        print('Would add call', call)


# Define Call() which can be used in type-checking call-wrappers that behave
# similarly to functools.partial (in that they take a callable and some
# positional arguments to be passed to it).

# In type-checking land, we define several different _CallXArg classes
# corresponding to different argument counts and define Call() as an
# overloaded function which returns one of them based on how many args are
# passed.

# To use this, simply assign your call type to this Call for type checking:
# Example:
# class _MyCallWrapper:
#     <runtime class defined here>
# if TYPE_CHECKING:
#     MyCallWrapper = efro.call.Call
# else:
#     MyCallWrapper = _MyCallWrapper

# Note that this setup currently only works with positional arguments; if you
# would like to pass args via keyword you can wrap a lambda or local function
# which takes keyword args and converts to a call containing keywords.

if TYPE_CHECKING:
    In1T = TypeVar('In1T')
    In2T = TypeVar('In2T')
    In3T = TypeVar('In3T')
    In4T = TypeVar('In4T')
    In5T = TypeVar('In5T')
    In6T = TypeVar('In6T')
    In7T = TypeVar('In7T')
    OutT = TypeVar('OutT')

    class _CallNoArgs(Generic[OutT]):
        """No-argument variant of call wrapper."""

        def __init__(self, _call: Callable[[], OutT]):
            ...

        def __call__(self) -> OutT:
            ...

    class _Call1Arg(Generic[In1T, OutT]):
        """Single-argument variant of call wrapper."""

        def __init__(self, _call: Callable[[In1T], OutT]):
            ...

        def __call__(self, _arg1: In1T) -> OutT:
            ...

    class _Call2Args(Generic[In1T, In2T, OutT]):
        """Two-argument variant of call wrapper."""

        def __init__(self, _call: Callable[[In1T, In2T], OutT]):
            ...

        def __call__(self, _arg1: In1T, _arg2: In2T) -> OutT:
            ...

    class _Call3Args(Generic[In1T, In2T, In3T, OutT]):
        """Three-argument variant of call wrapper."""

        def __init__(self, _call: Callable[[In1T, In2T, In3T], OutT]):
            ...

        def __call__(self, _arg1: In1T, _arg2: In2T, _arg3: In3T) -> OutT:
            ...

    class _Call4Args(Generic[In1T, In2T, In3T, In4T, OutT]):
        """Four-argument variant of call wrapper."""

        def __init__(self, _call: Callable[[In1T, In2T, In3T, In4T], OutT]):
            ...

        def __call__(
            self, _arg1: In1T, _arg2: In2T, _arg3: In3T, _arg4: In4T
        ) -> OutT:
            ...

    class _Call5Args(Generic[In1T, In2T, In3T, In4T, In5T, OutT]):
        """Five-argument variant of call wrapper."""

        def __init__(
            self, _call: Callable[[In1T, In2T, In3T, In4T, In5T], OutT]
        ):
            ...

        def __call__(
            self,
            _arg1: In1T,
            _arg2: In2T,
            _arg3: In3T,
            _arg4: In4T,
            _arg5: In5T,
        ) -> OutT:
            ...

    class _Call6Args(Generic[In1T, In2T, In3T, In4T, In5T, In6T, OutT]):
        """Six-argument variant of call wrapper."""

        def __init__(
            self, _call: Callable[[In1T, In2T, In3T, In4T, In5T, In6T], OutT]
        ):
            ...

        def __call__(
            self,
            _arg1: In1T,
            _arg2: In2T,
            _arg3: In3T,
            _arg4: In4T,
            _arg5: In5T,
            _arg6: In6T,
        ) -> OutT:
            ...

    class _Call7Args(Generic[In1T, In2T, In3T, In4T, In5T, In6T, In7T, OutT]):
        """Seven-argument variant of call wrapper."""

        def __init__(
            self,
            _call: Callable[[In1T, In2T, In3T, In4T, In5T, In6T, In7T], OutT],
        ):
            ...

        def __call__(
            self,
            _arg1: In1T,
            _arg2: In2T,
            _arg3: In3T,
            _arg4: In4T,
            _arg5: In5T,
            _arg6: In6T,
            _arg7: In7T,
        ) -> OutT:
            ...

    # No arg call; no args bundled.
    # noinspection PyPep8Naming
    @overload
    def Call(call: Callable[[], OutT]) -> _CallNoArgs[OutT]:
        ...

    # 1 arg call; 1 arg bundled.
    # noinspection PyPep8Naming
    @overload
    def Call(call: Callable[[In1T], OutT], arg1: In1T) -> _CallNoArgs[OutT]:
        ...

    # 1 arg call; no args bundled.
    # noinspection PyPep8Naming
    @overload
    def Call(call: Callable[[In1T], OutT]) -> _Call1Arg[In1T, OutT]:
        ...

    # 2 arg call; 2 args bundled.
    # noinspection PyPep8Naming
    @overload
    def Call(
        call: Callable[[In1T, In2T], OutT], arg1: In1T, arg2: In2T
    ) -> _CallNoArgs[OutT]:
        ...

    # 2 arg call; 1 arg bundled.
    # noinspection PyPep8Naming
    @overload
    def Call(
        call: Callable[[In1T, In2T], OutT], arg1: In1T
    ) -> _Call1Arg[In2T, OutT]:
        ...

    # 2 arg call; no args bundled.
    # noinspection PyPep8Naming
    @overload
    def Call(
        call: Callable[[In1T, In2T], OutT]
    ) -> _Call2Args[In1T, In2T, OutT]:
        ...

    # 3 arg call; 3 args bundled.
    # noinspection PyPep8Naming
    @overload
    def Call(
        call: Callable[[In1T, In2T, In3T], OutT],
        arg1: In1T,
        arg2: In2T,
        arg3: In3T,
    ) -> _CallNoArgs[OutT]:
        ...

    # 3 arg call; 2 args bundled.
    # noinspection PyPep8Naming
    @overload
    def Call(
        call: Callable[[In1T, In2T, In3T], OutT], arg1: In1T, arg2: In2T
    ) -> _Call1Arg[In3T, OutT]:
        ...

    # 3 arg call; 1 arg bundled.
    # noinspection PyPep8Naming
    @overload
    def Call(
        call: Callable[[In1T, In2T, In3T], OutT], arg1: In1T
    ) -> _Call2Args[In2T, In3T, OutT]:
        ...

    # 3 arg call; no args bundled.
    # noinspection PyPep8Naming
    @overload
    def Call(
        call: Callable[[In1T, In2T, In3T], OutT]
    ) -> _Call3Args[In1T, In2T, In3T, OutT]:
        ...

    # 4 arg call; 4 args bundled.
    # noinspection PyPep8Naming
    @overload
    def Call(
        call: Callable[[In1T, In2T, In3T, In4T], OutT],
        arg1: In1T,
        arg2: In2T,
        arg3: In3T,
        arg4: In4T,
    ) -> _CallNoArgs[OutT]:
        ...

    # 4 arg call; 3 args bundled.
    # noinspection PyPep8Naming
    @overload
    def Call(
        call: Callable[[In1T, In2T, In3T, In4T], OutT],
        arg1: In1T,
        arg2: In2T,
        arg3: In3T,
    ) -> _Call1Arg[In4T, OutT]:
        ...

    # 4 arg call; 2 args bundled.
    # noinspection PyPep8Naming
    @overload
    def Call(
        call: Callable[[In1T, In2T, In3T, In4T], OutT],
        arg1: In1T,
        arg2: In2T,
    ) -> _Call2Args[In3T, In4T, OutT]:
        ...

    # 4 arg call; 1 arg bundled.
    # noinspection PyPep8Naming
    @overload
    def Call(
        call: Callable[[In1T, In2T, In3T, In4T], OutT],
        arg1: In1T,
    ) -> _Call3Args[In2T, In3T, In4T, OutT]:
        ...

    # 4 arg call; no args bundled.
    # noinspection PyPep8Naming
    @overload
    def Call(
        call: Callable[[In1T, In2T, In3T, In4T], OutT],
    ) -> _Call4Args[In1T, In2T, In3T, In4T, OutT]:
        ...

    # 5 arg call; 5 args bundled.
    # noinspection PyPep8Naming
    @overload
    def Call(
        call: Callable[[In1T, In2T, In3T, In4T, In5T], OutT],
        arg1: In1T,
        arg2: In2T,
        arg3: In3T,
        arg4: In4T,
        arg5: In5T,
    ) -> _CallNoArgs[OutT]:
        ...

    # 6 arg call; 6 args bundled.
    # noinspection PyPep8Naming
    @overload
    def Call(
        call: Callable[[In1T, In2T, In3T, In4T, In5T, In6T], OutT],
        arg1: In1T,
        arg2: In2T,
        arg3: In3T,
        arg4: In4T,
        arg5: In5T,
        arg6: In6T,
    ) -> _CallNoArgs[OutT]:
        ...

    # 7 arg call; 7 args bundled.
    # noinspection PyPep8Naming
    @overload
    def Call(
        call: Callable[[In1T, In2T, In3T, In4T, In5T, In6T, In7T], OutT],
        arg1: In1T,
        arg2: In2T,
        arg3: In3T,
        arg4: In4T,
        arg5: In5T,
        arg6: In6T,
        arg7: In7T,
    ) -> _CallNoArgs[OutT]:
        ...

    # noinspection PyPep8Naming
    def Call(*_args: Any, **_keywds: Any) -> Any:
        ...

    # (Type-safe Partial)
    # A convenient wrapper around functools.partial which adds type-safety
    # (though it does not support keyword arguments).
    tpartial = Call
else:
    tpartial = functools.partial
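
For context, a minimal usage sketch (not part of the commit): at runtime tpartial is simply functools.partial, while the type-checker sees the overloaded Call() defined above, so bundled positional arguments are type-checked. The add function here is hypothetical.

from efro.call import tpartial

def add(val1: int, val2: int) -> int:
    return val1 + val2

# Bundles one of add's two args; type-checks as a one-arg call returning int.
add_five = tpartial(add, 5)
assert add_five(3) == 8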
49  dist/ba_data/python/efro/cloudshell.py  vendored  Normal file
@@ -0,0 +1,49 @@
# Released under the MIT License. See LICENSE for details.
#
"""My nifty ssh/mosh/rsync mishmash."""

from __future__ import annotations

from enum import Enum
from dataclasses import dataclass

from efro.dataclassio import ioprepped


class LockType(Enum):
    """Types of locks that can be acquired on a host."""

    HOST = 'host'
    WORKSPACE = 'workspace'
    PYCHARM = 'pycharm'
    CLION = 'clion'


@ioprepped
@dataclass
class HostConfig:
    """Config for a cloud machine to run commands on.

    precommand, if set, will be run before the passed commands.
    Note that it is not run in interactive mode (when no command is given).
    """

    address: str | None = None
    user: str = 'ubuntu'
    port: int = 22
    mosh_port: int | None = None
    mosh_server_path: str | None = None
    mosh_shell: str = 'sh'
    workspaces_root: str = '/home/${USER}/cloudshell_workspaces'
    sync_perms: bool = True
    precommand: str | None = None
    managed: bool = False
    idle_minutes: int = 5
    can_sudo_reboot: bool = False
    max_sessions: int = 3
    reboot_wait_seconds: int = 20
    reboot_attempts: int = 1

    def resolved_workspaces_root(self) -> str:
        """Returns workspaces_root with standard substitutions."""
        return self.workspaces_root.replace('${USER}', self.user)
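
A minimal sketch (not part of the commit) of HostConfig's ${USER} substitution; the address shown is a documentation-reserved example value.

from efro.cloudshell import HostConfig

cfg = HostConfig(address='203.0.113.10', user='deploy')
# '${USER}' in workspaces_root resolves to the configured user.
assert cfg.resolved_workspaces_root() == '/home/deploy/cloudshell_workspaces'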
295  dist/ba_data/python/efro/dataclasses.py  vendored  Normal file
@@ -0,0 +1,295 @@
# Released under the MIT License. See LICENSE for details.
#
"""Custom functionality for dealing with dataclasses."""

# Note: We do lots of comparing of exact types here which is normally
# frowned upon (stuff like isinstance() is usually encouraged).
# pylint: disable=unidiomatic-typecheck

from __future__ import annotations

import dataclasses
import inspect
from enum import Enum
from typing import TYPE_CHECKING, TypeVar, Generic

from efro.util import enum_by_value

if TYPE_CHECKING:
    from typing import Any, Dict, Type, Tuple, Optional

T = TypeVar('T')

SIMPLE_NAMES_TO_TYPES: Dict[str, Type] = {
    'int': int,
    'bool': bool,
    'str': str,
    'float': float,
}
SIMPLE_TYPES_TO_NAMES = {tp: nm for nm, tp in SIMPLE_NAMES_TO_TYPES.items()}


def dataclass_to_dict(obj: Any, coerce_to_float: bool = True) -> dict:
    """Given a dataclass object, emit a json-friendly dict.

    All values will be checked to ensure they match the types specified
    on fields. Note that only a limited set of types is supported.

    If coerce_to_float is True, integer values present on float typed fields
    will be converted to floats in the dict output. If False, a TypeError
    will be triggered.
    """
    out = _Outputter(obj, create=True, coerce_to_float=coerce_to_float).run()
    assert isinstance(out, dict)
    return out


def dataclass_from_dict(cls: Type[T],
                        values: dict,
                        coerce_to_float: bool = True) -> T:
    """Given a dict, instantiate a dataclass of the given type.

    The dict must be in the json-friendly format as emitted from
    dataclass_to_dict. This means that sequence values such as tuples or
    sets should be passed as lists, enums should be passed as their
    associated values, and nested dataclasses should be passed as dicts.

    If coerce_to_float is True, int values passed for float typed fields
    will be converted to float values. Otherwise a TypeError is raised.
    """
    return _Inputter(cls, coerce_to_float=coerce_to_float).run(values)


def dataclass_validate(obj: Any, coerce_to_float: bool = True) -> None:
    """Ensure that current values in a dataclass are the correct types."""
    _Outputter(obj, create=False, coerce_to_float=coerce_to_float).run()


def _field_type_str(cls: Type, field: dataclasses.Field) -> str:
    # We expect to be operating under 'from __future__ import annotations'
    # so field types should always be strings for us; not actual types.
    # (Can pull this check out once we get to Python 3.10)
    typestr: str = field.type  # type: ignore

    if not isinstance(typestr, str):
        raise RuntimeError(
            f'Dataclass {cls.__name__} seems to have'
            f' been created without "from __future__ import annotations";'
            f' those dataclasses are unsupported here.')
    return typestr


def _raise_type_error(fieldpath: str, valuetype: Type,
                      expected: Tuple[Type, ...]) -> None:
    """Raise an error when a field value's type does not match expected."""
    assert isinstance(expected, tuple)
    assert all(isinstance(e, type) for e in expected)
    if len(expected) == 1:
        expected_str = expected[0].__name__
    else:
        names = ', '.join(t.__name__ for t in expected)
        expected_str = f'Union[{names}]'
    raise TypeError(f'Invalid value type for "{fieldpath}";'
                    f' expected "{expected_str}", got'
                    f' "{valuetype.__name__}".')


class _Outputter:

    def __init__(self, obj: Any, create: bool, coerce_to_float: bool) -> None:
        self._obj = obj
        self._create = create
        self._coerce_to_float = coerce_to_float

    def run(self) -> Any:
        """Do the thing."""
        return self._dataclass_to_output(self._obj, '')

    def _value_to_output(self, fieldpath: str, typestr: str,
                         value: Any) -> Any:
        # pylint: disable=too-many-return-statements
        # pylint: disable=too-many-branches

        # For simple flat types, look for exact matches:
        simpletype = SIMPLE_NAMES_TO_TYPES.get(typestr)
        if simpletype is not None:
            if type(value) is not simpletype:
                # Special case: if they want to coerce ints to floats, do so.
                if (self._coerce_to_float and simpletype is float
                        and type(value) is int):
                    return float(value) if self._create else None
                _raise_type_error(fieldpath, type(value), (simpletype, ))
            return value

        if typestr.startswith('Optional[') and typestr.endswith(']'):
            subtypestr = typestr[9:-1]
            # Handle the 'None' case special and do the default otherwise.
            if value is None:
                return None
            return self._value_to_output(fieldpath, subtypestr, value)

        if typestr.startswith('List[') and typestr.endswith(']'):
            subtypestr = typestr[5:-1]
            if not isinstance(value, list):
                raise TypeError(f'Expected a list for {fieldpath};'
                                f' found a {type(value)}')
            if self._create:
                return [
                    self._value_to_output(fieldpath, subtypestr, x)
                    for x in value
                ]
            for x in value:
                self._value_to_output(fieldpath, subtypestr, x)
            return None

        if typestr.startswith('Set[') and typestr.endswith(']'):
            subtypestr = typestr[4:-1]
            if not isinstance(value, set):
                raise TypeError(f'Expected a set for {fieldpath};'
                                f' found a {type(value)}')
            if self._create:
                # Note: we output json-friendly values so this becomes a list.
                return [
                    self._value_to_output(fieldpath, subtypestr, x)
                    for x in value
                ]
            for x in value:
                self._value_to_output(fieldpath, subtypestr, x)
            return None

        if dataclasses.is_dataclass(value):
            return self._dataclass_to_output(value, fieldpath)

        if isinstance(value, Enum):
            enumvalue = value.value
            if type(enumvalue) not in SIMPLE_TYPES_TO_NAMES:
                raise TypeError(f'Invalid enum value type {type(enumvalue)}'
                                f' for "{fieldpath}".')
            return enumvalue

        raise TypeError(
            f"Field '{fieldpath}' of type '{typestr}' is unsupported here.")

    def _dataclass_to_output(self, obj: Any, fieldpath: str) -> Any:
        if not dataclasses.is_dataclass(obj):
            raise TypeError(f'Passed obj {obj} is not a dataclass.')
        fields = dataclasses.fields(obj)
        out: Optional[Dict[str, Any]] = {} if self._create else None

        for field in fields:
            fieldname = field.name

            if fieldpath:
                subfieldpath = f'{fieldpath}.{fieldname}'
            else:
                subfieldpath = fieldname
            typestr = _field_type_str(type(obj), field)
            value = getattr(obj, fieldname)
            outvalue = self._value_to_output(subfieldpath, typestr, value)
            if self._create:
                assert out is not None
                out[fieldname] = outvalue

        return out


class _Inputter(Generic[T]):

    def __init__(self, cls: Type[T], coerce_to_float: bool):
        self._cls = cls
        self._coerce_to_float = coerce_to_float

    def run(self, values: dict) -> T:
        """Do the thing."""
        return self._dataclass_from_input(  # type: ignore
            self._cls, '', values)

    def _value_from_input(self, cls: Type, fieldpath: str, typestr: str,
                          value: Any) -> Any:
        """Convert an assigned value to what a dataclass field expects."""
        # pylint: disable=too-many-return-statements

        simpletype = SIMPLE_NAMES_TO_TYPES.get(typestr)
        if simpletype is not None:
            if type(value) is not simpletype:
                # Special case: if they want to coerce ints to floats, do so.
                if (self._coerce_to_float and simpletype is float
                        and type(value) is int):
                    return float(value)
                _raise_type_error(fieldpath, type(value), (simpletype, ))
            return value
        if typestr.startswith('List[') and typestr.endswith(']'):
            return self._sequence_from_input(cls, fieldpath, typestr, value,
                                             'List', list)
        if typestr.startswith('Set[') and typestr.endswith(']'):
            return self._sequence_from_input(cls, fieldpath, typestr, value,
                                             'Set', set)
        if typestr.startswith('Optional[') and typestr.endswith(']'):
            subtypestr = typestr[9:-1]
            # Handle the 'None' case special and do the default
            # thing otherwise.
            if value is None:
                return None
            return self._value_from_input(cls, fieldpath, subtypestr, value)

        # Ok, it's not a builtin type. It might be an enum or nested dataclass.
        cls2 = getattr(inspect.getmodule(cls), typestr, None)
        if cls2 is None:
            raise RuntimeError(f"Unable to resolve '{typestr}'"
                               f" used by class '{cls.__name__}';"
                               f' make sure all nested types are declared'
                               f' in the global namespace of the module where'
                               f" '{cls.__name__}' is defined.")

        if dataclasses.is_dataclass(cls2):
            return self._dataclass_from_input(cls2, fieldpath, value)

        if issubclass(cls2, Enum):
            return enum_by_value(cls2, value)

        raise TypeError(
            f"Field '{fieldpath}' of type '{typestr}' is unsupported here.")

    def _dataclass_from_input(self, cls: Type, fieldpath: str,
                              values: dict) -> Any:
        """Given a dict, instantiate a dataclass of the given type.

        The dict must be in the json-friendly format as emitted from
        dataclass_to_dict. This means that sequence values such as tuples or
        sets should be passed as lists, enums should be passed as their
        associated values, and nested dataclasses should be passed as dicts.
        """
        if not dataclasses.is_dataclass(cls):
            raise TypeError(f'Passed class {cls} is not a dataclass.')
        if not isinstance(values, dict):
            raise TypeError("Expected a dict for 'values' arg.")

        # noinspection PyDataclass
        fields = dataclasses.fields(cls)
        fields_by_name = {f.name: f for f in fields}
        args: Dict[str, Any] = {}
        for key, value in values.items():
            field = fields_by_name.get(key)
            if field is None:
                raise AttributeError(f"'{cls.__name__}' has no '{key}' field.")

            typestr = _field_type_str(cls, field)

            subfieldpath = (f'{fieldpath}.{field.name}'
                            if fieldpath else field.name)
            args[key] = self._value_from_input(cls, subfieldpath, typestr,
                                               value)

        return cls(**args)

    def _sequence_from_input(self, cls: Type, fieldpath: str, typestr: str,
                             value: Any, seqtypestr: str,
                             seqtype: Type) -> Any:
        # Because we are json-centric, we expect a list for all sequences.
        if type(value) is not list:
            raise TypeError(f'Invalid input value for "{fieldpath}";'
                            f' expected a list, got a {type(value).__name__}')
        subtypestr = typestr[len(seqtypestr) + 1:-1]
        return seqtype(
            self._value_from_input(cls, fieldpath, subtypestr, i)
            for i in value)
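
A minimal round-trip sketch (not part of the commit) for the two helpers above; the Size dataclass is hypothetical. Note the defining module must use 'from __future__ import annotations', as enforced by _field_type_str().

from __future__ import annotations

from dataclasses import dataclass

from efro.dataclasses import dataclass_to_dict, dataclass_from_dict

@dataclass
class Size:
    width: float
    height: float

# The int 2 is coerced to 2.0 since coerce_to_float defaults to True.
out = dataclass_to_dict(Size(width=2, height=3.5))
assert out == {'width': 2.0, 'height': 3.5}
assert dataclass_from_dict(Size, out) == Size(width=2.0, height=3.5)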
1351  dist/ba_data/python/efro/dataclassio.py  vendored  Normal file
File diff suppressed because it is too large.
50  dist/ba_data/python/efro/dataclassio/__init__.py  vendored  Normal file
@@ -0,0 +1,50 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for importing, exporting, and validating dataclasses.

This allows complex nested dataclasses to be flattened to json-compatible
data and restored from said data. It also gracefully handles and preserves
unrecognized attribute data, allowing older clients to interact with newer
data formats in a nondestructive manner.
"""

from __future__ import annotations

from efro.util import set_canonical_module
from efro.dataclassio._base import Codec, IOAttrs, IOExtendedData
from efro.dataclassio._prep import (
    ioprep,
    ioprepped,
    will_ioprep,
    is_ioprepped_dataclass,
)
from efro.dataclassio._pathcapture import DataclassFieldLookup
from efro.dataclassio._api import (
    JsonStyle,
    dataclass_to_dict,
    dataclass_to_json,
    dataclass_from_dict,
    dataclass_from_json,
    dataclass_validate,
)

__all__ = [
    'JsonStyle',
    'Codec',
    'IOAttrs',
    'IOExtendedData',
    'ioprep',
    'ioprepped',
    'will_ioprep',
    'is_ioprepped_dataclass',
    'DataclassFieldLookup',
    'dataclass_to_dict',
    'dataclass_to_json',
    'dataclass_from_dict',
    'dataclass_from_json',
    'dataclass_validate',
]

# Have these things present themselves cleanly as 'thismodule.SomeClass'
# instead of 'thismodule._internalmodule.SomeClass'.
set_canonical_module(module_globals=globals(), names=__all__)
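
A minimal sketch (not part of the commit) of the public API exported above; the Score dataclass is hypothetical.

from __future__ import annotations

from dataclasses import dataclass

from efro.dataclassio import ioprepped, dataclass_to_dict, dataclass_from_dict

@ioprepped
@dataclass
class Score:
    player: str
    points: int = 0

data = dataclass_to_dict(Score(player='zoe', points=12))
assert dataclass_from_dict(Score, data) == Score(player='zoe', points=12)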
BIN  dist/ba_data/python/efro/dataclassio/__pycache__/__init__.cpython-310.opt-1.pyc  vendored  Normal file
Binary file not shown.
BIN  dist/ba_data/python/efro/dataclassio/__pycache__/_api.cpython-310.opt-1.pyc  vendored  Normal file
Binary file not shown.
BIN  dist/ba_data/python/efro/dataclassio/__pycache__/_base.cpython-310.opt-1.pyc  vendored  Normal file
Binary file not shown.
BIN  dist/ba_data/python/efro/dataclassio/__pycache__/_inputter.cpython-310.opt-1.pyc  vendored  Normal file
Binary file not shown.
BIN  dist/ba_data/python/efro/dataclassio/__pycache__/_outputter.cpython-310.opt-1.pyc  vendored  Normal file
Binary file not shown.
BIN  dist/ba_data/python/efro/dataclassio/__pycache__/_pathcapture.cpython-310.opt-1.pyc  vendored  Normal file
Binary file not shown.
BIN  dist/ba_data/python/efro/dataclassio/__pycache__/_prep.cpython-310.opt-1.pyc  vendored  Normal file
Binary file not shown.
163  dist/ba_data/python/efro/dataclassio/_api.py  vendored  Normal file
@@ -0,0 +1,163 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for importing, exporting, and validating dataclasses.

This allows complex nested dataclasses to be flattened to json-compatible
data and restored from said data. It also gracefully handles and preserves
unrecognized attribute data, allowing older clients to interact with newer
data formats in a nondestructive manner.
"""

from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING, TypeVar

from efro.dataclassio._outputter import _Outputter
from efro.dataclassio._inputter import _Inputter
from efro.dataclassio._base import Codec

if TYPE_CHECKING:
    from typing import Any

T = TypeVar('T')


class JsonStyle(Enum):
    """Different style types for json."""

    # Single line, no spaces, no sorting. Not deterministic.
    # Use this for most storage purposes.
    FAST = 'fast'

    # Single line, no spaces, sorted keys. Deterministic.
    # Use this when output may be hashed or compared for equality.
    SORTED = 'sorted'

    # Multiple lines, spaces, sorted keys. Deterministic.
    # Use this for pretty human readable output.
    PRETTY = 'pretty'


def dataclass_to_dict(
    obj: Any, codec: Codec = Codec.JSON, coerce_to_float: bool = True
) -> dict:
    """Given a dataclass object, return a json-friendly dict.

    All values will be checked to ensure they match the types specified
    on fields. Note that a limited set of types and data configurations is
    supported.

    Values with type Any will be checked to ensure they match types supported
    directly by json. This does not include types such as tuples which are
    implicitly translated by Python's json module (as this would break
    the ability to do a lossless round-trip with data).

    If coerce_to_float is True, integer values present on float typed fields
    will be converted to float in the dict output. If False, a TypeError
    will be triggered.
    """
    out = _Outputter(
        obj, create=True, codec=codec, coerce_to_float=coerce_to_float
    ).run()
    assert isinstance(out, dict)
    return out


def dataclass_to_json(
    obj: Any,
    coerce_to_float: bool = True,
    pretty: bool = False,
    sort_keys: bool | None = None,
) -> str:
    """Utility function; return a json string from a dataclass instance.

    Basically json.dumps(dataclass_to_dict(...)).
    By default, keys are sorted for pretty output and not otherwise, but
    this can be overridden by supplying a value for the 'sort_keys' arg.
    """
    import json

    jdict = dataclass_to_dict(
        obj=obj, coerce_to_float=coerce_to_float, codec=Codec.JSON
    )
    if sort_keys is None:
        sort_keys = pretty
    if pretty:
        return json.dumps(jdict, indent=2, sort_keys=sort_keys)
    return json.dumps(jdict, separators=(',', ':'), sort_keys=sort_keys)


def dataclass_from_dict(
    cls: type[T],
    values: dict,
    codec: Codec = Codec.JSON,
    coerce_to_float: bool = True,
    allow_unknown_attrs: bool = True,
    discard_unknown_attrs: bool = False,
) -> T:
    """Given a dict, return a dataclass of a given type.

    The dict must be formatted to match the specified codec (generally
    json-friendly object types). This means that sequence values such as
    tuples or sets should be passed as lists, enums should be passed as their
    associated values, nested dataclasses should be passed as dicts, etc.

    All values are checked to ensure their types/values are valid.

    Data for attributes of type Any will be checked to ensure they match
    types supported directly by json. This does not include types such
    as tuples which are implicitly translated by Python's json module
    (as this would break the ability to do a lossless round-trip with data).

    If coerce_to_float is True, int values passed for float typed fields
    will be converted to float values. Otherwise, a TypeError is raised.

    If allow_unknown_attrs is False, AttributeErrors will be raised for
    attributes present in the dict but not on the data class. Otherwise, they
    will be preserved as part of the instance and included if it is
    exported back to a dict, unless discard_unknown_attrs is True, in which
    case they will simply be discarded.
    """
    return _Inputter(
        cls,
        codec=codec,
        coerce_to_float=coerce_to_float,
        allow_unknown_attrs=allow_unknown_attrs,
        discard_unknown_attrs=discard_unknown_attrs,
    ).run(values)


def dataclass_from_json(
    cls: type[T],
    json_str: str,
    coerce_to_float: bool = True,
    allow_unknown_attrs: bool = True,
    discard_unknown_attrs: bool = False,
) -> T:
    """Utility function; return a dataclass instance given a json string.

    Basically dataclass_from_dict(json.loads(...)).
    """
    import json

    return dataclass_from_dict(
        cls=cls,
        values=json.loads(json_str),
        coerce_to_float=coerce_to_float,
        allow_unknown_attrs=allow_unknown_attrs,
        discard_unknown_attrs=discard_unknown_attrs,
    )


def dataclass_validate(
    obj: Any, coerce_to_float: bool = True, codec: Codec = Codec.JSON
) -> None:
    """Ensure that values in a dataclass instance are the correct types."""

    # Simply run an output pass but tell it not to generate data;
    # only run validation.
    _Outputter(
        obj, create=False, codec=codec, coerce_to_float=coerce_to_float
    ).run()
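
A minimal sketch (not part of the commit) of the json helpers above; the Vec dataclass is hypothetical.

from __future__ import annotations

from dataclasses import dataclass

from efro.dataclassio import ioprepped, dataclass_to_json, dataclass_from_json

@ioprepped
@dataclass
class Vec:
    y: float = 2.0
    x: float = 1.0

compact = dataclass_to_json(Vec())              # single line, keys unsorted
pretty = dataclass_to_json(Vec(), pretty=True)  # indented, keys sorted
assert dataclass_from_json(Vec, compact) == Vec()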
276  dist/ba_data/python/efro/dataclassio/_base.py  vendored  Normal file
@@ -0,0 +1,276 @@
# Released under the MIT License. See LICENSE for details.
#
"""Core components of dataclassio."""

from __future__ import annotations

import dataclasses
import typing
import datetime
from enum import Enum
from typing import TYPE_CHECKING, get_args

# noinspection PyProtectedMember
from typing import _AnnotatedAlias  # type: ignore

if TYPE_CHECKING:
    from typing import Any, Callable

# Types which we can pass through as-is.
SIMPLE_TYPES = {int, bool, str, float, type(None)}

# Attr name for dict of extra attributes included on dataclass instances.
# Note that this is only added if extra attributes are present.
EXTRA_ATTRS_ATTR = '_DCIOEXATTRS'


def _raise_type_error(
    fieldpath: str, valuetype: type, expected: tuple[type, ...]
) -> None:
    """Raise an error when a field value's type does not match expected."""
    assert isinstance(expected, tuple)
    assert all(isinstance(e, type) for e in expected)
    if len(expected) == 1:
        expected_str = expected[0].__name__
    else:
        expected_str = ' | '.join(t.__name__ for t in expected)
    raise TypeError(
        f'Invalid value type for "{fieldpath}";'
        f' expected "{expected_str}", got'
        f' "{valuetype.__name__}".'
    )


class Codec(Enum):
    """Specifies expected data format exported to or imported from."""

    # Use only types that will translate cleanly to/from json: lists,
    # dicts with str keys, bools, ints, floats, and None.
    JSON = 'json'

    # Mostly like JSON but passes bytes and datetime objects through
    # as-is instead of converting them to json-friendly types.
    FIRESTORE = 'firestore'


class IOExtendedData:
    """A class that data types can inherit from for extra functionality."""

    def will_output(self) -> None:
        """Called before data is sent to an outputter.

        Can be overridden to validate or filter data before
        sending it on its way.
        """

    @classmethod
    def will_input(cls, data: dict) -> None:
        """Called on raw data before a class instance is created from it.

        Can be overridden to migrate old data formats to new, etc.
        """


def _is_valid_for_codec(obj: Any, codec: Codec) -> bool:
    """Return whether a value consists solely of json-supported types.

    Note that this does not include things like tuples which are
    implicitly translated to lists by python's json module.
    """
    if obj is None:
        return True

    objtype = type(obj)
    if objtype in (int, float, str, bool):
        return True
    if objtype is dict:
        # JSON 'objects' support only string keys, but all value types.
        return all(
            isinstance(k, str) and _is_valid_for_codec(v, codec)
            for k, v in obj.items()
        )
    if objtype is list:
        return all(_is_valid_for_codec(elem, codec) for elem in obj)

    # A few things are valid in firestore but not json.
    if issubclass(objtype, datetime.datetime) or objtype is bytes:
        return codec is Codec.FIRESTORE

    return False


class IOAttrs:
    """For specifying io behavior in annotations.

    'storagename', if passed, is the name used when storing to json/etc.
    'store_default' can be set to False to avoid writing values when equal
      to the default value. Note that this requires the dataclass field
      to define a default or default_factory or for its IOAttrs to
      define a soft_default value.
    'whole_days', if True, requires datetime values to be exactly on day
      boundaries (see efro.util.utc_today()).
    'whole_hours', if True, requires datetime values to lie exactly on hour
      boundaries (see efro.util.utc_this_hour()).
    'whole_minutes', if True, requires datetime values to lie exactly on
      minute boundaries (see efro.util.utc_this_minute()).
    'soft_default', if passed, injects a default value into dataclass
      instantiation when the field is not present in the input data.
      This allows dataclasses to add new non-optional fields while
      gracefully 'upgrading' old data. Note that when a soft_default is
      present it will take precedence over field defaults when determining
      whether to store a value for a field with store_default=False
      (since the soft_default value is what we'll get when reading that
      same data back in when the field is omitted).
    'soft_default_factory' is similar to 'default_factory' in dataclass
      fields; it should be used instead of 'soft_default' for mutable types
      such as lists to prevent a single default object from unintentionally
      changing over time.
    """

    # A sentinel object to detect if a parameter is supplied or not. Use
    # a class to give it a better repr.
    class _MissingType:
        pass

    MISSING = _MissingType()

    storagename: str | None = None
    store_default: bool = True
    whole_days: bool = False
    whole_hours: bool = False
    whole_minutes: bool = False
    soft_default: Any = MISSING
    soft_default_factory: Callable[[], Any] | _MissingType = MISSING

    def __init__(
        self,
        storagename: str | None = storagename,
        store_default: bool = store_default,
        whole_days: bool = whole_days,
        whole_hours: bool = whole_hours,
        whole_minutes: bool = whole_minutes,
        soft_default: Any = MISSING,
        soft_default_factory: Callable[[], Any] | _MissingType = MISSING,
    ):
        # Only store values that differ from class defaults to keep
        # our instances nice and lean.
        cls = type(self)
        if storagename != cls.storagename:
            self.storagename = storagename
        if store_default != cls.store_default:
            self.store_default = store_default
        if whole_days != cls.whole_days:
            self.whole_days = whole_days
        if whole_hours != cls.whole_hours:
            self.whole_hours = whole_hours
        if whole_minutes != cls.whole_minutes:
            self.whole_minutes = whole_minutes
        if soft_default is not cls.soft_default:
            # Do what dataclasses does with its default types and
            # tell the user to use factory for mutable ones.
            if isinstance(soft_default, (list, dict, set)):
                raise ValueError(
                    f'mutable {type(soft_default)} is not allowed'
                    f' for soft_default; use soft_default_factory.'
                )
            self.soft_default = soft_default
        if soft_default_factory is not cls.soft_default_factory:
            self.soft_default_factory = soft_default_factory
            if self.soft_default is not cls.soft_default:
                raise ValueError(
                    'Cannot set both soft_default and soft_default_factory'
                )

    def validate_for_field(self, cls: type, field: dataclasses.Field) -> None:
        """Ensure the IOAttrs instance is ok to use with the provided field."""

        # Turning off store_default requires the field to have either
        # a default or a default_factory, or for us to have soft equivalents.
        if not self.store_default:
            field_default_factory: Any = field.default_factory
            if (
                field_default_factory is dataclasses.MISSING
                and field.default is dataclasses.MISSING
                and self.soft_default is self.MISSING
                and self.soft_default_factory is self.MISSING
            ):
                raise TypeError(
                    f'Field {field.name} of {cls} has'
                    f' neither a default nor a default_factory'
                    f' and IOAttrs contains neither a soft_default'
                    f' nor a soft_default_factory;'
                    f' store_default=False cannot be set for it.'
                )

    def validate_datetime(
        self, value: datetime.datetime, fieldpath: str
    ) -> None:
        """Ensure a datetime value meets our value requirements."""
        if self.whole_days:
            if any(
                x != 0
                for x in (
                    value.hour,
                    value.minute,
                    value.second,
                    value.microsecond,
                )
            ):
                raise ValueError(
                    f'Value {value} at {fieldpath} is not a whole day.'
                )
        elif self.whole_hours:
            if any(
                x != 0 for x in (value.minute, value.second, value.microsecond)
            ):
                raise ValueError(
                    f'Value {value} at {fieldpath} is not a whole hour.'
                )
        elif self.whole_minutes:
            if any(x != 0 for x in (value.second, value.microsecond)):
                raise ValueError(
                    f'Value {value} at {fieldpath} is not a whole minute.'
                )


def _get_origin(anntype: Any) -> Any:
    """Given a type annotation, return its origin or itself if there is none.

    This differs from typing.get_origin in that it will never return None.
    This lets us use the same code path for handling typing.List
    that we do for handling list, which is good since they can be used
    interchangeably in annotations.
    """
    origin = typing.get_origin(anntype)
    return anntype if origin is None else origin


def _parse_annotated(anntype: Any) -> tuple[Any, IOAttrs | None]:
    """Parse Annotated() constructs, returning annotated type & IOAttrs."""
    # If we get an Annotated[foo, bar, eep] we take
    # foo as the actual type, and we look for IOAttrs instances in
    # bar/eep to affect our behavior.
    ioattrs: IOAttrs | None = None
    if isinstance(anntype, _AnnotatedAlias):
        annargs = get_args(anntype)
        for annarg in annargs[1:]:
            if isinstance(annarg, IOAttrs):
                if ioattrs is not None:
                    raise RuntimeError(
                        'Multiple IOAttrs instances found for a'
                        ' single annotation; this is not supported.'
                    )
                ioattrs = annarg

            # I occasionally just throw a 'x' down when I mean IOAttrs('x');
            # catch these mistakes.
            elif isinstance(annarg, (str, int, float, bool)):
                raise RuntimeError(
                    f'Raw {type(annarg)} found in Annotated[] entry:'
                    f' {anntype}; this is probably not what you intended.'
                )
        anntype = annargs[0]
    return anntype, ioattrs
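
A minimal sketch (not part of the commit) of attaching IOAttrs through typing.Annotated; the Profile dataclass is hypothetical. 'name' is stored under the short key 'n', and 'level' gains a soft default so older stored data lacking that field still loads.

from __future__ import annotations

from dataclasses import dataclass
from typing import Annotated

from efro.dataclassio import ioprepped, IOAttrs, dataclass_from_dict

@ioprepped
@dataclass
class Profile:
    name: Annotated[str, IOAttrs('n')]
    level: Annotated[int, IOAttrs(soft_default=1)]

old_data = {'n': 'zoe'}  # Stored before 'level' existed.
prof = dataclass_from_dict(Profile, old_data)
assert prof.name == 'zoe' and prof.level == 1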
555
dist/ba_data/python/efro/dataclassio/_inputter.py
vendored
Normal file
555
dist/ba_data/python/efro/dataclassio/_inputter.py
vendored
Normal file
|
|
@ -0,0 +1,555 @@
|
|||
# Released under the MIT License. See LICENSE for details.
|
||||
#
|
||||
"""Functionality for dataclassio related to pulling data into dataclasses."""
|
||||
|
||||
# Note: We do lots of comparing of exact types here which is normally
|
||||
# frowned upon (stuff like isinstance() is usually encouraged).
|
||||
# pylint: disable=unidiomatic-typecheck
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
import dataclasses
|
||||
import typing
|
||||
import types
|
||||
import datetime
|
||||
from typing import TYPE_CHECKING, Generic, TypeVar
|
||||
|
||||
from efro.util import enum_by_value, check_utc
|
||||
from efro.dataclassio._base import (
|
||||
Codec,
|
||||
_parse_annotated,
|
||||
EXTRA_ATTRS_ATTR,
|
||||
_is_valid_for_codec,
|
||||
_get_origin,
|
||||
SIMPLE_TYPES,
|
||||
_raise_type_error,
|
||||
IOExtendedData,
|
||||
)
|
||||
from efro.dataclassio._prep import PrepSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
from efro.dataclassio._base import IOAttrs
|
||||
from efro.dataclassio._outputter import _Outputter
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class _Inputter(Generic[T]):
|
||||
def __init__(
|
||||
self,
|
||||
cls: type[T],
|
||||
codec: Codec,
|
||||
coerce_to_float: bool,
|
||||
allow_unknown_attrs: bool = True,
|
||||
discard_unknown_attrs: bool = False,
|
||||
):
|
||||
self._cls = cls
|
||||
self._codec = codec
|
||||
self._coerce_to_float = coerce_to_float
|
||||
self._allow_unknown_attrs = allow_unknown_attrs
|
||||
self._discard_unknown_attrs = discard_unknown_attrs
|
||||
self._soft_default_validator: _Outputter | None = None
|
||||
|
||||
if not allow_unknown_attrs and discard_unknown_attrs:
|
||||
raise ValueError(
|
||||
'discard_unknown_attrs cannot be True'
|
||||
' when allow_unknown_attrs is False.'
|
||||
)
|
||||
|
||||
def run(self, values: dict) -> T:
|
||||
"""Do the thing."""
|
||||
|
||||
# For special extended data types, call their 'will_output' callback.
|
||||
tcls = self._cls
|
||||
if issubclass(tcls, IOExtendedData):
|
||||
tcls.will_input(values)
|
||||
|
||||
out = self._dataclass_from_input(self._cls, '', values)
|
||||
assert isinstance(out, self._cls)
|
||||
return out
|
||||
|
||||
def _value_from_input(
|
||||
self,
|
||||
cls: type,
|
||||
fieldpath: str,
|
||||
anntype: Any,
|
||||
value: Any,
|
||||
ioattrs: IOAttrs | None,
|
||||
) -> Any:
|
||||
"""Convert an assigned value to what a dataclass field expects."""
|
||||
# pylint: disable=too-many-return-statements
|
||||
# pylint: disable=too-many-branches
|
||||
|
||||
origin = _get_origin(anntype)
|
||||
|
||||
if origin is typing.Any:
|
||||
if not _is_valid_for_codec(value, self._codec):
|
||||
raise TypeError(
|
||||
f'Invalid value type for \'{fieldpath}\';'
|
||||
f' \'Any\' typed values must contain only'
|
||||
f' types directly supported by the specified'
|
||||
f' codec ({self._codec.name}); found'
|
||||
f' \'{type(value).__name__}\' which is not.'
|
||||
)
|
||||
return value
|
||||
|
||||
if origin is typing.Union or origin is types.UnionType:
|
||||
# Currently, the only unions we support are None/Value
|
||||
# (translated from Optional), which we verified on prep.
|
||||
# So let's treat this as a simple optional case.
|
||||
if value is None:
|
||||
return None
|
||||
childanntypes_l = [
|
||||
c for c in typing.get_args(anntype) if c is not type(None)
|
||||
] # noqa (pycodestyle complains about *is* with type)
|
||||
assert len(childanntypes_l) == 1
|
||||
return self._value_from_input(
|
||||
cls, fieldpath, childanntypes_l[0], value, ioattrs
|
||||
)
|
||||
|
||||
# Everything below this point assumes the annotation type resolves
|
||||
# to a concrete type. (This should have been verified at prep time).
|
||||
assert isinstance(origin, type)
|
||||
|
||||
if origin in SIMPLE_TYPES:
|
||||
if type(value) is not origin:
|
||||
# Special case: if they want to coerce ints to floats, do so.
|
||||
if (
|
||||
self._coerce_to_float
|
||||
and origin is float
|
||||
and type(value) is int
|
||||
):
|
||||
return float(value)
|
||||
_raise_type_error(fieldpath, type(value), (origin,))
|
||||
return value
|
||||
|
||||
if origin in {list, set}:
|
||||
return self._sequence_from_input(
|
||||
cls, fieldpath, anntype, value, origin, ioattrs
|
||||
)
|
||||
|
||||
if origin is tuple:
|
||||
return self._tuple_from_input(
|
||||
cls, fieldpath, anntype, value, ioattrs
|
||||
)
|
||||
|
||||
if origin is dict:
|
||||
return self._dict_from_input(
|
||||
cls, fieldpath, anntype, value, ioattrs
|
||||
)
|
||||
|
||||
if dataclasses.is_dataclass(origin):
|
||||
return self._dataclass_from_input(origin, fieldpath, value)
|
||||
|
||||
if issubclass(origin, Enum):
|
||||
return enum_by_value(origin, value)
|
||||
|
||||
if issubclass(origin, datetime.datetime):
|
||||
return self._datetime_from_input(cls, fieldpath, value, ioattrs)
|
||||
|
||||
if origin is bytes:
|
||||
return self._bytes_from_input(origin, fieldpath, value)
|
||||
|
||||
raise TypeError(
|
||||
f"Field '{fieldpath}' of type '{anntype}' is unsupported here."
|
||||
)
|
||||
|
||||
def _bytes_from_input(self, cls: type, fieldpath: str, value: Any) -> bytes:
|
||||
"""Given input data, returns bytes."""
|
||||
import base64
|
||||
|
||||
# For firestore, bytes are passed as-is. Otherwise, they're encoded
|
||||
# as base64.
|
||||
if self._codec is Codec.FIRESTORE:
|
||||
if not isinstance(value, bytes):
|
||||
raise TypeError(
|
||||
f'Expected a bytes object for {fieldpath}'
|
||||
f' on {cls.__name__}; got a {type(value)}.'
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
assert self._codec is Codec.JSON
|
||||
if not isinstance(value, str):
|
||||
raise TypeError(
|
||||
f'Expected a string object for {fieldpath}'
|
||||
f' on {cls.__name__}; got a {type(value)}.'
|
||||
)
|
||||
return base64.b64decode(value)
|
||||
|
||||
def _dataclass_from_input(
|
||||
self, cls: type, fieldpath: str, values: dict
|
||||
) -> Any:
|
||||
"""Given a dict, instantiates a dataclass of the given type.
|
||||
|
||||
The dict must be in the json-friendly format as emitted from
|
||||
dataclass_to_dict. This means that sequence values such as tuples or
|
||||
sets should be passed as lists, enums should be passed as their
|
||||
associated values, and nested dataclasses should be passed as dicts.
|
||||
"""
|
||||
# pylint: disable=too-many-locals
|
||||
# pylint: disable=too-many-branches
|
||||
if not isinstance(values, dict):
|
||||
raise TypeError(
|
||||
f'Expected a dict for {fieldpath} on {cls.__name__};'
|
||||
f' got a {type(values)}.'
|
||||
)
|
||||
|
||||
prep = PrepSession(explicit=False).prep_dataclass(
|
||||
cls, recursion_level=0
|
||||
)
|
||||
assert prep is not None
|
||||
|
||||
extra_attrs = {}
|
||||
|
||||
# noinspection PyDataclass
|
||||
fields = dataclasses.fields(cls)
|
||||
fields_by_name = {f.name: f for f in fields}
|
||||
|
||||
# Preprocess all fields to convert Annotated[] to contained types
|
||||
# and IOAttrs.
|
||||
parsed_field_annotations = {
|
||||
f.name: _parse_annotated(prep.annotations[f.name]) for f in fields
|
||||
}
|
||||
|
||||
# Go through all data in the input, converting it to either dataclass
|
||||
# args or extra data.
|
||||
args: dict[str, Any] = {}
|
||||
for rawkey, value in values.items():
|
||||
key = prep.storage_names_to_attr_names.get(rawkey, rawkey)
|
||||
field = fields_by_name.get(key)
|
||||
|
||||
# Store unknown attrs off to the side (or error if desired).
|
||||
if field is None:
|
||||
if self._allow_unknown_attrs:
|
||||
if self._discard_unknown_attrs:
|
||||
continue
|
||||
|
||||
# Treat this like 'Any' data; ensure that it is valid
|
||||
# raw json.
|
||||
if not _is_valid_for_codec(value, self._codec):
|
||||
raise TypeError(
|
||||
f'Unknown attr \'{key}\''
|
||||
f' on {fieldpath} contains data type(s)'
|
||||
f' not supported by the specified codec'
|
||||
f' ({self._codec.name}).'
|
||||
)
|
||||
extra_attrs[key] = value
|
||||
else:
|
||||
raise AttributeError(
|
||||
f"'{cls.__name__}' has no '{key}' field."
|
||||
)
|
||||
else:
|
||||
fieldname = field.name
|
||||
anntype, ioattrs = parsed_field_annotations[fieldname]
|
||||
subfieldpath = (
|
||||
f'{fieldpath}.{fieldname}' if fieldpath else fieldname
|
||||
)
|
||||
args[key] = self._value_from_input(
|
||||
cls, subfieldpath, anntype, value, ioattrs
|
||||
)
|
||||
|
||||
# Go through all fields looking for any not yet present in our data.
|
||||
# If we find any such fields with a soft-default value or factory
|
||||
# defined, inject that soft value into our args.
|
||||
for key, aparsed in parsed_field_annotations.items():
|
||||
if key in args:
|
||||
continue
|
||||
ioattrs = aparsed[1]
|
||||
if ioattrs is not None and (
|
||||
ioattrs.soft_default is not ioattrs.MISSING
|
||||
or ioattrs.soft_default_factory is not ioattrs.MISSING
|
||||
):
|
||||
if ioattrs.soft_default is not ioattrs.MISSING:
|
||||
soft_default = ioattrs.soft_default
|
||||
else:
|
||||
assert callable(ioattrs.soft_default_factory)
|
||||
soft_default = ioattrs.soft_default_factory()
|
||||
args[key] = soft_default
|
||||
|
||||
# Make sure these values are valid since we didn't run
|
||||
# them through our normal input type checking.
|
||||
|
||||
self._type_check_soft_default(
|
||||
value=soft_default,
|
||||
anntype=aparsed[0],
|
||||
fieldpath=(f'{fieldpath}.{key}' if fieldpath else key),
|
||||
)
|
||||
|
||||
try:
|
||||
out = cls(**args)
|
||||
except Exception as exc:
|
||||
raise ValueError(
|
||||
f'Error instantiating class {cls.__name__}'
|
||||
f' at {fieldpath}: {exc}'
|
||||
) from exc
|
||||
if extra_attrs:
|
||||
setattr(out, EXTRA_ATTRS_ATTR, extra_attrs)
|
||||
return out
|
||||
|
    def _type_check_soft_default(
        self, value: Any, anntype: Any, fieldpath: str
    ) -> None:
        from efro.dataclassio._outputter import _Outputter

        # Counter-intuitively, we create an outputter as part of
        # our inputter. Soft-default values are already internal types;
        # we need to make sure they can go out from there.
        if self._soft_default_validator is None:
            self._soft_default_validator = _Outputter(
                obj=None,
                create=False,
                codec=self._codec,
                coerce_to_float=self._coerce_to_float,
            )
        self._soft_default_validator.soft_default_check(
            value=value, anntype=anntype, fieldpath=fieldpath
        )

    def _dict_from_input(
        self,
        cls: type,
        fieldpath: str,
        anntype: Any,
        value: Any,
        ioattrs: IOAttrs | None,
    ) -> Any:
        # pylint: disable=too-many-branches
        # pylint: disable=too-many-locals

        if not isinstance(value, dict):
            raise TypeError(
                f'Expected a dict for \'{fieldpath}\' on {cls.__name__};'
                f' got a {type(value)}.'
            )

        childtypes = typing.get_args(anntype)
        assert len(childtypes) in (0, 2)

        out: dict

        # We treat 'Any' dicts simply as json; we don't do any translating.
        if not childtypes or childtypes[0] is typing.Any:
            if not isinstance(value, dict) or not _is_valid_for_codec(
                value, self._codec
            ):
                raise TypeError(
                    f'Got invalid value for Dict[Any, Any]'
                    f' at \'{fieldpath}\' on {cls.__name__};'
                    f' all keys and values must be'
                    f' compatible with the specified codec'
                    f' ({self._codec.name}).'
                )
            out = value
        else:
            out = {}
            keyanntype, valanntype = childtypes

            # Ok; we've got definite key/value types (which we verified as
            # valid during prep). Run all keys/values through it.

            # str keys we just take directly since that's supported by json.
            if keyanntype is str:
                for key, val in value.items():
                    if not isinstance(key, str):
                        raise TypeError(
                            f'Got invalid key type {type(key)} for'
                            f' dict key at \'{fieldpath}\' on {cls.__name__};'
                            f' expected a str.'
                        )
                    out[key] = self._value_from_input(
                        cls, fieldpath, valanntype, val, ioattrs
                    )

            # int keys are stored in json as str versions of themselves.
            elif keyanntype is int:
                for key, val in value.items():
                    if not isinstance(key, str):
                        raise TypeError(
                            f'Got invalid key type {type(key)} for'
                            f' dict key at \'{fieldpath}\' on {cls.__name__};'
                            f' expected a str.'
                        )
                    try:
                        keyint = int(key)
                    except ValueError as exc:
                        raise TypeError(
                            f'Got invalid key value {key} for'
                            f' dict key at \'{fieldpath}\' on {cls.__name__};'
                            f' expected an int in string form.'
                        ) from exc
                    out[keyint] = self._value_from_input(
                        cls, fieldpath, valanntype, val, ioattrs
                    )

            elif issubclass(keyanntype, Enum):
                # In prep, we verified that all these enums' values have
                # the same type, so we can just look at the first to see if
                # this is a string enum or an int enum.
                enumvaltype = type(next(iter(keyanntype)).value)
                assert enumvaltype in (int, str)
                if enumvaltype is str:
                    for key, val in value.items():
                        try:
                            enumval = enum_by_value(keyanntype, key)
                        except ValueError as exc:
                            raise ValueError(
                                f'Got invalid key value {key} for'
                                f' dict key at \'{fieldpath}\''
                                f' on {cls.__name__};'
                                f' expected a value corresponding to'
                                f' a {keyanntype}.'
                            ) from exc
                        out[enumval] = self._value_from_input(
                            cls, fieldpath, valanntype, val, ioattrs
                        )
                else:
                    for key, val in value.items():
                        try:
                            enumval = enum_by_value(keyanntype, int(key))
                        except (ValueError, TypeError) as exc:
                            raise ValueError(
                                f'Got invalid key value {key} for'
                                f' dict key at \'{fieldpath}\''
                                f' on {cls.__name__};'
                                f' expected {keyanntype} value (though'
                                f' in string form).'
                            ) from exc
                        out[enumval] = self._value_from_input(
                            cls, fieldpath, valanntype, val, ioattrs
                        )

            else:
                raise RuntimeError(f'Unhandled dict in-key-type {keyanntype}')

        return out

    def _sequence_from_input(
        self,
        cls: type,
        fieldpath: str,
        anntype: Any,
        value: Any,
        seqtype: type,
        ioattrs: IOAttrs | None,
    ) -> Any:
        # Because we are json-centric, we expect a list for all sequences.
        if type(value) is not list:
            raise TypeError(
                f'Invalid input value for "{fieldpath}";'
                f' expected a list, got a {type(value).__name__}'
            )

        childanntypes = typing.get_args(anntype)

        # 'Any' type children; make sure they are valid json values
        # and then just grab them.
        if len(childanntypes) == 0 or childanntypes[0] is typing.Any:
            for i, child in enumerate(value):
                if not _is_valid_for_codec(child, self._codec):
                    raise TypeError(
                        f'Item {i} of {fieldpath} contains'
                        f' data type(s) not supported by json.'
                    )
            return value if type(value) is seqtype else seqtype(value)

        # We contain elements of some specified type.
        assert len(childanntypes) == 1
        childanntype = childanntypes[0]
        return seqtype(
            self._value_from_input(cls, fieldpath, childanntype, i, ioattrs)
            for i in value
        )

    def _datetime_from_input(
        self, cls: type, fieldpath: str, value: Any, ioattrs: IOAttrs | None
    ) -> Any:
        # For firestore we expect a datetime object.
        if self._codec is Codec.FIRESTORE:
            # Don't compare exact type here, as firestore can give us
            # a subclass with extended precision.
            if not isinstance(value, datetime.datetime):
                raise TypeError(
                    f'Invalid input value for "{fieldpath}" on'
                    f' "{cls.__name__}";'
                    f' expected a datetime, got a {type(value).__name__}'
                )
            check_utc(value)
            return value

        assert self._codec is Codec.JSON

        # We expect a list of 7 ints.
        if type(value) is not list:
            raise TypeError(
                f'Invalid input value for "{fieldpath}" on "{cls.__name__}";'
                f' expected a list, got a {type(value).__name__}'
            )
        if len(value) != 7 or not all(isinstance(x, int) for x in value):
            raise ValueError(
                f'Invalid input value for "{fieldpath}" on "{cls.__name__}";'
                f' expected a list of 7 ints, got {[type(v) for v in value]}.'
            )
        out = datetime.datetime(  # type: ignore
            *value, tzinfo=datetime.timezone.utc
        )
        if ioattrs is not None:
            ioattrs.validate_datetime(out, fieldpath)
        return out

    def _tuple_from_input(
        self,
        cls: type,
        fieldpath: str,
        anntype: Any,
        value: Any,
        ioattrs: IOAttrs | None,
    ) -> Any:
        out: list = []

        # Because we are json-centric, we expect a list for all sequences.
        if type(value) is not list:
            raise TypeError(
                f'Invalid input value for "{fieldpath}";'
                f' expected a list, got a {type(value).__name__}'
            )

        childanntypes = typing.get_args(anntype)

        # We should have verified this to be non-zero at prep-time.
        assert childanntypes

        if len(value) != len(childanntypes):
            raise ValueError(
                f'Invalid tuple input for "{fieldpath}";'
                f' expected {len(childanntypes)} values,'
                f' found {len(value)}.'
            )

        for i, childanntype in enumerate(childanntypes):
            childval = value[i]

            # 'Any' type children; make sure they are valid json values
            # and then just grab them.
            if childanntype is typing.Any:
                if not _is_valid_for_codec(childval, self._codec):
                    raise TypeError(
                        f'Item {i} of {fieldpath} contains'
                        f' data type(s) not supported by json.'
                    )
                out.append(childval)
            else:
                out.append(
                    self._value_from_input(
                        cls, fieldpath, childanntype, childval, ioattrs
                    )
                )

        assert len(out) == len(childanntypes)
        return tuple(out)
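A round-trip sketch of the input conversions above (illustrative only, not part of the commit; it assumes the dataclass_from_dict/dataclass_to_dict/ioprepped helpers exported elsewhere in efro.dataclassio):

import datetime
from dataclasses import dataclass, field

from efro.dataclassio import dataclass_from_dict, dataclass_to_dict, ioprepped


@ioprepped
@dataclass
class Record:
    # int dict keys travel through json as string versions of themselves.
    scores: dict[int, float] = field(default_factory=dict)
    # datetimes must be timezone-aware utc; json stores them as 7 ints.
    when: datetime.datetime = field(
        default_factory=lambda: datetime.datetime.now(datetime.timezone.utc)
    )


rec = Record(scores={1: 2.5})
out = dataclass_to_dict(rec)
assert list(out['scores'].keys()) == ['1']  # int key -> str form
assert len(out['when']) == 7  # [year, month, day, hour, min, sec, usec]
rec2 = dataclass_from_dict(Record, out)
assert rec2.scores == {1: 2.5}  # str key -> int restored on input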
457
dist/ba_data/python/efro/dataclassio/_outputter.py
vendored
Normal file
@ -0,0 +1,457 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for dataclassio related to exporting data from dataclasses."""

# Note: We do lots of comparing of exact types here which is normally
# frowned upon (stuff like isinstance() is usually encouraged).
# pylint: disable=unidiomatic-typecheck

from __future__ import annotations

from enum import Enum
import dataclasses
import typing
import types
import datetime
from typing import TYPE_CHECKING

from efro.util import check_utc
from efro.dataclassio._base import (
    Codec,
    _parse_annotated,
    EXTRA_ATTRS_ATTR,
    _is_valid_for_codec,
    _get_origin,
    SIMPLE_TYPES,
    _raise_type_error,
    IOExtendedData,
)
from efro.dataclassio._prep import PrepSession

if TYPE_CHECKING:
    from typing import Any
    from efro.dataclassio._base import IOAttrs


class _Outputter:
    """Validates or exports data contained in a dataclass instance."""

    def __init__(
        self, obj: Any, create: bool, codec: Codec, coerce_to_float: bool
    ) -> None:
        self._obj = obj
        self._create = create
        self._codec = codec
        self._coerce_to_float = coerce_to_float

    def run(self) -> Any:
        """Do the thing."""

        assert dataclasses.is_dataclass(self._obj)

        # For special extended data types, call their 'will_output' callback.
        if isinstance(self._obj, IOExtendedData):
            self._obj.will_output()

        return self._process_dataclass(type(self._obj), self._obj, '')

    def soft_default_check(
        self, value: Any, anntype: Any, fieldpath: str
    ) -> None:
        """(internal)"""
        self._process_value(
            type(value),
            fieldpath=fieldpath,
            anntype=anntype,
            value=value,
            ioattrs=None,
        )

    def _process_dataclass(self, cls: type, obj: Any, fieldpath: str) -> Any:
        # pylint: disable=too-many-locals
        # pylint: disable=too-many-branches
        prep = PrepSession(explicit=False).prep_dataclass(
            type(obj), recursion_level=0
        )
        assert prep is not None
        fields = dataclasses.fields(obj)
        out: dict[str, Any] | None = {} if self._create else None
        for field in fields:
            fieldname = field.name
            if fieldpath:
                subfieldpath = f'{fieldpath}.{fieldname}'
            else:
                subfieldpath = fieldname
            anntype = prep.annotations[fieldname]
            value = getattr(obj, fieldname)

            anntype, ioattrs = _parse_annotated(anntype)

            # If we're not storing default values for this fella,
            # we can skip all output processing if we've got a default value.
            if ioattrs is not None and not ioattrs.store_default:
                # If both soft_defaults and regular field defaults
                # are present we want to go with soft_defaults since
                # those same values would be re-injected when reading
                # the same data back in if we've omitted the field.
                default_factory: Any = field.default_factory
                if ioattrs.soft_default is not ioattrs.MISSING:
                    if ioattrs.soft_default == value:
                        continue
                elif ioattrs.soft_default_factory is not ioattrs.MISSING:
                    assert callable(ioattrs.soft_default_factory)
                    if ioattrs.soft_default_factory() == value:
                        continue
                elif field.default is not dataclasses.MISSING:
                    if field.default == value:
                        continue
                elif default_factory is not dataclasses.MISSING:
                    if default_factory() == value:
                        continue
                else:
                    raise RuntimeError(
                        f'Field {fieldname} of {cls.__name__} has'
                        f' no source of default values; store_default=False'
                        f' cannot be set for it. (AND THIS SHOULD HAVE BEEN'
                        f' CAUGHT IN PREP!)'
                    )

            outvalue = self._process_value(
                cls, subfieldpath, anntype, value, ioattrs
            )
            if self._create:
                assert out is not None
                storagename = (
                    fieldname
                    if (ioattrs is None or ioattrs.storagename is None)
                    else ioattrs.storagename
                )
                out[storagename] = outvalue

        # If there's extra-attrs stored on us, check/include them.
        extra_attrs = getattr(obj, EXTRA_ATTRS_ATTR, None)
        if isinstance(extra_attrs, dict):
            if not _is_valid_for_codec(extra_attrs, self._codec):
                raise TypeError(
                    f'Extra attrs on \'{fieldpath}\' contains data type(s)'
                    f' not supported by \'{self._codec.value}\' codec:'
                    f' {extra_attrs}.'
                )
            if self._create:
                assert out is not None
                out.update(extra_attrs)
        return out

    def _process_value(
        self,
        cls: type,
        fieldpath: str,
        anntype: Any,
        value: Any,
        ioattrs: IOAttrs | None,
    ) -> Any:
        # pylint: disable=too-many-return-statements
        # pylint: disable=too-many-branches
        # pylint: disable=too-many-statements

        origin = _get_origin(anntype)

        if origin is typing.Any:
            if not _is_valid_for_codec(value, self._codec):
                raise TypeError(
                    f'Invalid value type for \'{fieldpath}\';'
                    f" 'Any' typed values must contain types directly"
                    f' supported by the specified codec ({self._codec.name});'
                    f' found \'{type(value).__name__}\' which is not.'
                )
            return value if self._create else None

        if origin is typing.Union or origin is types.UnionType:
            # Currently, the only unions we support are None/Value
            # (translated from Optional), which we verified on prep.
            # So let's treat this as a simple optional case.
            if value is None:
                return None
            childanntypes_l = [
                c for c in typing.get_args(anntype) if c is not type(None)
            ]  # noqa (pycodestyle complains about *is* with type)
            assert len(childanntypes_l) == 1
            return self._process_value(
                cls, fieldpath, childanntypes_l[0], value, ioattrs
            )

        # Everything below this point assumes the annotation type resolves
        # to a concrete type. (This should have been verified at prep time).
        assert isinstance(origin, type)

        # For simple flat types, look for exact matches:
        if origin in SIMPLE_TYPES:
            if type(value) is not origin:
                # Special case: if they want to coerce ints to floats, do so.
                if (
                    self._coerce_to_float
                    and origin is float
                    and type(value) is int
                ):
                    return float(value) if self._create else None
                _raise_type_error(fieldpath, type(value), (origin,))
            return value if self._create else None

        if origin is tuple:
            if not isinstance(value, tuple):
                raise TypeError(
                    f'Expected a tuple for {fieldpath};'
                    f' found a {type(value)}'
                )
            childanntypes = typing.get_args(anntype)

            # We should have verified this was non-zero at prep-time.
            assert childanntypes
            if len(value) != len(childanntypes):
                raise TypeError(
                    f'Tuple at {fieldpath} contains'
                    f' {len(value)} values; type specifies'
                    f' {len(childanntypes)}.'
                )
            if self._create:
                return [
                    self._process_value(
                        cls, fieldpath, childanntypes[i], x, ioattrs
                    )
                    for i, x in enumerate(value)
                ]
            for i, x in enumerate(value):
                self._process_value(
                    cls, fieldpath, childanntypes[i], x, ioattrs
                )
            return None

        if origin is list:
            if not isinstance(value, list):
                raise TypeError(
                    f'Expected a list for {fieldpath};'
                    f' found a {type(value)}'
                )
            childanntypes = typing.get_args(anntype)

            # 'Any' type children; make sure they are valid values for
            # the specified codec.
            if len(childanntypes) == 0 or childanntypes[0] is typing.Any:
                for i, child in enumerate(value):
                    if not _is_valid_for_codec(child, self._codec):
                        raise TypeError(
                            f'Item {i} of {fieldpath} contains'
                            f' data type(s) not supported by the specified'
                            f' codec ({self._codec.name}).'
                        )
                # Hmm; should we do a copy here?
                return value if self._create else None

            # We contain elements of some specified type.
            assert len(childanntypes) == 1
            if self._create:
                return [
                    self._process_value(
                        cls, fieldpath, childanntypes[0], x, ioattrs
                    )
                    for x in value
                ]
            for x in value:
                self._process_value(
                    cls, fieldpath, childanntypes[0], x, ioattrs
                )
            return None

        if origin is set:
            if not isinstance(value, set):
                raise TypeError(
                    f'Expected a set for {fieldpath}; found a {type(value)}'
                )
            childanntypes = typing.get_args(anntype)

            # 'Any' type children; make sure they are valid Any values.
            if len(childanntypes) == 0 or childanntypes[0] is typing.Any:
                for child in value:
                    if not _is_valid_for_codec(child, self._codec):
                        raise TypeError(
                            f'Set at {fieldpath} contains'
                            f' data type(s) not supported by the'
                            f' specified codec ({self._codec.name}).'
                        )
                return list(value) if self._create else None

            # We contain elements of some specified type.
            assert len(childanntypes) == 1
            if self._create:
                # Note: we output json-friendly values so this becomes
                # a list.
                return [
                    self._process_value(
                        cls, fieldpath, childanntypes[0], x, ioattrs
                    )
                    for x in value
                ]
            for x in value:
                self._process_value(
                    cls, fieldpath, childanntypes[0], x, ioattrs
                )
            return None

        if origin is dict:
            return self._process_dict(cls, fieldpath, anntype, value, ioattrs)

        if dataclasses.is_dataclass(origin):
            if not isinstance(value, origin):
                raise TypeError(
                    f'Expected a {origin} for {fieldpath};'
                    f' found a {type(value)}.'
                )
            return self._process_dataclass(cls, value, fieldpath)

        if issubclass(origin, Enum):
            if not isinstance(value, origin):
                raise TypeError(
                    f'Expected a {origin} for {fieldpath};'
                    f' found a {type(value)}.'
                )
            # At prep-time we verified that these enums had valid value
            # types, so we can blindly return it here.
            return value.value if self._create else None

        if issubclass(origin, datetime.datetime):
            if not isinstance(value, origin):
                raise TypeError(
                    f'Expected a {origin} for {fieldpath};'
                    f' found a {type(value)}.'
                )
            check_utc(value)
            if ioattrs is not None:
                ioattrs.validate_datetime(value, fieldpath)
            if self._codec is Codec.FIRESTORE:
                return value
            assert self._codec is Codec.JSON
            return (
                [
                    value.year,
                    value.month,
                    value.day,
                    value.hour,
                    value.minute,
                    value.second,
                    value.microsecond,
                ]
                if self._create
                else None
            )

        if origin is bytes:
            return self._process_bytes(cls, fieldpath, value)

        raise TypeError(
            f"Field '{fieldpath}' of type '{anntype}' is unsupported here."
        )

    def _process_bytes(self, cls: type, fieldpath: str, value: bytes) -> Any:
        import base64

        if not isinstance(value, bytes):
            raise TypeError(
                f'Expected bytes for {fieldpath} on {cls.__name__};'
                f' found a {type(value)}.'
            )

        if not self._create:
            return None

        # In JSON we convert to base64, but firestore directly supports bytes.
        if self._codec is Codec.JSON:
            return base64.b64encode(value).decode()

        assert self._codec is Codec.FIRESTORE
        return value

    def _process_dict(
        self,
        cls: type,
        fieldpath: str,
        anntype: Any,
        value: dict,
        ioattrs: IOAttrs | None,
    ) -> Any:
        # pylint: disable=too-many-branches
        if not isinstance(value, dict):
            raise TypeError(
                f'Expected a dict for {fieldpath}; found a {type(value)}.'
            )
        childtypes = typing.get_args(anntype)
        assert len(childtypes) in (0, 2)

        # We treat 'Any' dicts simply as json; we don't do any translating.
        if not childtypes or childtypes[0] is typing.Any:
            if not isinstance(value, dict) or not _is_valid_for_codec(
                value, self._codec
            ):
                raise TypeError(
                    f'Invalid value for Dict[Any, Any]'
                    f' at \'{fieldpath}\' on {cls.__name__};'
                    f' all keys and values must be directly compatible'
                    f' with the specified codec ({self._codec.name})'
                    f' when dict type is Any.'
                )
            return value if self._create else None

        # Ok; we've got a definite key type (which we verified as valid
        # during prep). Make sure all keys match it.
        out: dict | None = {} if self._create else None
        keyanntype, valanntype = childtypes

        # str keys we just export directly since that's supported by json.
        if keyanntype is str:
            for key, val in value.items():
                if not isinstance(key, str):
                    raise TypeError(
                        f'Got invalid key type {type(key)} for'
                        f' dict key at \'{fieldpath}\' on {cls.__name__};'
                        f' expected {keyanntype}.'
                    )
                outval = self._process_value(
                    cls, fieldpath, valanntype, val, ioattrs
                )
                if self._create:
                    assert out is not None
                    out[key] = outval

        # int keys are stored as str versions of themselves.
        elif keyanntype is int:
            for key, val in value.items():
                if not isinstance(key, int):
                    raise TypeError(
                        f'Got invalid key type {type(key)} for'
                        f' dict key at \'{fieldpath}\' on {cls.__name__};'
                        f' expected an int.'
                    )
                outval = self._process_value(
                    cls, fieldpath, valanntype, val, ioattrs
                )
                if self._create:
                    assert out is not None
                    out[str(key)] = outval

        elif issubclass(keyanntype, Enum):
            for key, val in value.items():
                if not isinstance(key, keyanntype):
                    raise TypeError(
                        f'Got invalid key type {type(key)} for'
                        f' dict key at \'{fieldpath}\' on {cls.__name__};'
                        f' expected a {keyanntype}.'
                    )
                outval = self._process_value(
                    cls, fieldpath, valanntype, val, ioattrs
                )
                if self._create:
                    assert out is not None
                    out[str(key.value)] = outval
        else:
            raise RuntimeError(f'Unhandled dict out-key-type {keyanntype}')

        return out
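A short sketch of the default-value pruning implemented by _process_dataclass() above (hypothetical usage, assuming the IOAttrs/ioprepped/dataclass_to_dict exports from efro.dataclassio): with store_default=False, a field equal to its default is omitted from the output entirely.

from dataclasses import dataclass
from typing import Annotated

from efro.dataclassio import IOAttrs, dataclass_to_dict, ioprepped


@ioprepped
@dataclass
class Prefs:
    volume: Annotated[float, IOAttrs('v', store_default=False)] = 1.0


assert dataclass_to_dict(Prefs()) == {}  # default value gets pruned
assert dataclass_to_dict(Prefs(volume=0.5)) == {'v': 0.5}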
115
dist/ba_data/python/efro/dataclassio/_pathcapture.py
vendored
Normal file
@ -0,0 +1,115 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality related to capturing nested dataclass paths."""

from __future__ import annotations

import dataclasses
from typing import TYPE_CHECKING, TypeVar, Generic

from efro.dataclassio._base import _parse_annotated, _get_origin
from efro.dataclassio._prep import PrepSession

if TYPE_CHECKING:
    from typing import Any, Callable

T = TypeVar('T')


class _PathCapture:
    """Utility for obtaining dataclass storage paths in a type safe way."""

    def __init__(self, obj: Any, pathparts: list[str] | None = None):
        self._is_dataclass = dataclasses.is_dataclass(obj)
        if pathparts is None:
            pathparts = []
        self._cls = obj if isinstance(obj, type) else type(obj)
        self._pathparts = pathparts

    def __getattr__(self, name: str) -> _PathCapture:
        # We only allow diving into sub-objects if we are a dataclass.
        if not self._is_dataclass:
            raise TypeError(
                f"Field path cannot include attribute '{name}' "
                f'under parent {self._cls}; parent types must be dataclasses.'
            )

        prep = PrepSession(explicit=False).prep_dataclass(
            self._cls, recursion_level=0
        )
        assert prep is not None
        try:
            anntype = prep.annotations[name]
        except KeyError as exc:
            raise AttributeError(f'{type(self)} has no {name} field.') from exc
        anntype, ioattrs = _parse_annotated(anntype)
        storagename = (
            name
            if (ioattrs is None or ioattrs.storagename is None)
            else ioattrs.storagename
        )
        origin = _get_origin(anntype)
        return _PathCapture(origin, pathparts=self._pathparts + [storagename])

    @property
    def path(self) -> str:
        """The final output path."""
        return '.'.join(self._pathparts)


class DataclassFieldLookup(Generic[T]):
    """Get info about nested dataclass fields in type-safe way."""

    def __init__(self, cls: type[T]) -> None:
        self.cls = cls

    def path(self, callback: Callable[[T], Any]) -> str:
        """Look up a path on child dataclass fields.

        example:
          DataclassFieldLookup(MyType).path(lambda obj: obj.foo.bar)

        The above example will return the string 'foo.bar' or something
        like 'f.b' if the dataclasses have custom storage names set.
        It will also be static-type-checked, triggering an error if
        MyType.foo.bar is not a valid path. Note, however, that the
        callback technically allows any return value but only nested
        dataclasses and their fields will succeed.
        """

        # We tell the type system that we are returning an instance
        # of our class, which allows it to perform type checking on
        # member lookups. In reality, however, we are providing a
        # special object which captures path lookups, so we can build
        # a string from them.
        if not TYPE_CHECKING:
            out = callback(_PathCapture(self.cls))
            if not isinstance(out, _PathCapture):
                raise TypeError(
                    f'Expected a valid path under'
                    f' the provided object; got a {type(out)}.'
                )
            return out.path
        return ''

    def paths(self, callback: Callable[[T], list[Any]]) -> list[str]:
        """Look up multiple paths on child dataclass fields.

        Functionality is identical to path() but for multiple paths at once.

        example:
          DataclassFieldLookup(MyType).paths(lambda obj: [obj.foo, obj.bar])
        """
        outvals: list[str] = []
        if not TYPE_CHECKING:
            outs = callback(_PathCapture(self.cls))
            assert isinstance(outs, list)
            for out in outs:
                if not isinstance(out, _PathCapture):
                    raise TypeError(
                        f'Expected a valid path under'
                        f' the provided object; got a {type(out)}.'
                    )
                outvals.append(out.path)
        return outvals
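Example usage of the lookup above (a sketch; it assumes DataclassFieldLookup is re-exported at the efro.dataclassio package level). The lambda is statically type-checked, and the resulting path honors custom storage names:

from dataclasses import dataclass, field
from typing import Annotated

from efro.dataclassio import DataclassFieldLookup, IOAttrs, ioprepped


@ioprepped
@dataclass
class Inner:
    count: Annotated[int, IOAttrs('c')] = 0


@ioprepped
@dataclass
class Outer:
    inner: Inner = field(default_factory=Inner)


path = DataclassFieldLookup(Outer).path(lambda obj: obj.inner.count)
assert path == 'inner.c'  # 'count' stores under the custom name 'c'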
459
dist/ba_data/python/efro/dataclassio/_prep.py
vendored
Normal file
@ -0,0 +1,459 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for prepping types for use with dataclassio."""

# Note: We do lots of comparing of exact types here which is normally
# frowned upon (stuff like isinstance() is usually encouraged).
# pylint: disable=unidiomatic-typecheck

from __future__ import annotations

import logging
from enum import Enum
import dataclasses
import typing
import types
import datetime
from typing import TYPE_CHECKING, TypeVar, get_type_hints

# noinspection PyProtectedMember
from efro.dataclassio._base import _parse_annotated, _get_origin, SIMPLE_TYPES

if TYPE_CHECKING:
    from typing import Any
    from efro.dataclassio._base import IOAttrs

T = TypeVar('T')

# How deep we go when prepping nested types
# (basically for detecting recursive types)
MAX_RECURSION = 10

# Attr name for data we store on dataclass types that have been prepped.
PREP_ATTR = '_DCIOPREP'

# We also store the prep-session while the prep is in progress.
# (necessary to support recursive types).
PREP_SESSION_ATTR = '_DCIOPREPSESSION'


def ioprep(cls: type, globalns: dict | None = None) -> None:
    """Prep a dataclass type for use with this module's functionality.

    Prepping ensures that all types contained in a data class as well as
    the usage of said types are supported by this module and pre-builds
    necessary constructs needed for encoding/decoding/etc.

    Prepping will happen on-the-fly as needed, but a warning will be
    emitted in such cases, as it is better to explicitly prep all used types
    early in a process to ensure any invalid types or configuration are caught
    immediately.

    Prepping a dataclass involves evaluating its type annotations, which,
    as of PEP 563, are stored simply as strings. This evaluation is done
    with localns set to the class dict (so that types defined in the class
    can be used) and globalns set to the containing module's dict.
    It is possible to override globalns for special cases such as when
    prepping happens as part of an execed string instead of within a
    module.
    """
    PrepSession(explicit=True, globalns=globalns).prep_dataclass(
        cls, recursion_level=0
    )


def ioprepped(cls: type[T]) -> type[T]:
    """Class decorator for easily prepping a dataclass at definition time.

    Note that in some cases it may not be possible to prep a dataclass
    immediately (such as when its type annotations refer to forward-declared
    types). In these cases, ioprep() should be explicitly called for
    the class as soon as possible; ideally at module import time to expose any
    errors as early as possible in execution.
    """
    ioprep(cls)
    return cls


def will_ioprep(cls: type[T]) -> type[T]:
    """Class decorator hinting that we will prep a class later.

    In some cases (such as recursive types) we cannot use the @ioprepped
    decorator and must instead call ioprep() explicitly later. However,
    some of our custom pylint checking behaves differently when the
    @ioprepped decorator is present, in that case requiring type annotations
    to be present and not simply forward declared under an "if TYPE_CHECKING"
    block (since they are used at runtime).

    The @will_ioprep decorator triggers the same pylint behavior
    differences as @ioprepped (which are necessary for the later ioprep() call
    to work correctly) but without actually running any prep itself.
    """
    return cls


def is_ioprepped_dataclass(obj: Any) -> bool:
    """Return whether the obj is an ioprepped dataclass type or instance."""
    cls = obj if isinstance(obj, type) else type(obj)
    return dataclasses.is_dataclass(cls) and hasattr(cls, PREP_ATTR)


@dataclasses.dataclass
class PrepData:
    """Data we prepare and cache for a class during prep.

    This data is used as part of the encoding/decoding/validating process.
    """

    # Resolved annotation data with 'live' classes.
    annotations: dict[str, Any]

    # Map of storage names to attr names.
    storage_names_to_attr_names: dict[str, str]


class PrepSession:
    """Context for a prep."""

    def __init__(self, explicit: bool, globalns: dict | None = None):
        self.explicit = explicit
        self.globalns = globalns

    def prep_dataclass(
        self, cls: type, recursion_level: int
    ) -> PrepData | None:
        """Run prep on a dataclass if necessary and return its prep data.

        The only case where this will return None is for recursive types
        if the type is already being prepped higher in the call order.
        """
        # pylint: disable=too-many-locals
        # pylint: disable=too-many-branches

        # We should only need to do this once per dataclass.
        existing_data = getattr(cls, PREP_ATTR, None)
        if existing_data is not None:
            assert isinstance(existing_data, PrepData)
            return existing_data

        # Sanity check.
        # Note that we now support recursive types via the PREP_SESSION_ATTR,
        # so we theoretically shouldn't run into this.
        if recursion_level > MAX_RECURSION:
            raise RuntimeError('Max recursion exceeded.')

        # We should only be passed classes which are dataclasses.
        if not isinstance(cls, type) or not dataclasses.is_dataclass(cls):
            raise TypeError(f'Passed arg {cls} is not a dataclass type.')

        # Add a pointer to the prep-session while doing the prep.
        # This way we can ignore types that we're already in the process
        # of prepping and can support recursive types.
        existing_prep = getattr(cls, PREP_SESSION_ATTR, None)
        if existing_prep is not None:
            if existing_prep is self:
                return None
            # We shouldn't need to support failed preps
            # or preps from multiple threads at once.
            raise RuntimeError('Found existing in-progress prep.')
        setattr(cls, PREP_SESSION_ATTR, self)

        # Generate a warning on non-explicit preps; we prefer prep to
        # happen explicitly at runtime so errors can be detected early on.
        if not self.explicit:
            logging.warning(
                'efro.dataclassio: implicitly prepping dataclass: %s.'
                ' It is highly recommended to explicitly prep dataclasses'
                ' as soon as possible after definition (via'
                ' efro.dataclassio.ioprep() or the'
                ' @efro.dataclassio.ioprepped decorator).',
                cls,
            )

        try:
            # NOTE: Now passing the class' __dict__ (vars()) as locals
            # which allows us to pick up nested classes, etc.
            resolved_annotations = get_type_hints(
                cls,
                localns=vars(cls),
                globalns=self.globalns,
                include_extras=True,
            )
        except Exception as exc:
            raise TypeError(
                f'dataclassio prep for {cls} failed with error: {exc}.'
                f' Make sure all types used in annotations are defined'
                f' at the module or class level or add them as part of an'
                f' explicit prep call.'
            ) from exc

        # noinspection PyDataclass
        fields = dataclasses.fields(cls)
        fields_by_name = {f.name: f for f in fields}

        all_storage_names: set[str] = set()
        storage_names_to_attr_names: dict[str, str] = {}

        # Ok; we've resolved actual types for this dataclass.
        # Now recurse through them, verifying that we support all contained
        # types and prepping any contained dataclass types.
        for attrname, anntype in resolved_annotations.items():
            anntype, ioattrs = _parse_annotated(anntype)

            # If we found attached IOAttrs data, make sure it contains
            # valid values for the field it is attached to.
            if ioattrs is not None:
                ioattrs.validate_for_field(cls, fields_by_name[attrname])
                if ioattrs.storagename is not None:
                    storagename = ioattrs.storagename
                    storage_names_to_attr_names[ioattrs.storagename] = attrname
                else:
                    storagename = attrname
            else:
                storagename = attrname

            # Make sure we don't have any clashes in our storage names.
            if storagename in all_storage_names:
                raise TypeError(
                    f'Multiple attrs on {cls} are using'
                    f' storage-name \'{storagename}\''
                )
            all_storage_names.add(storagename)

            self.prep_type(
                cls,
                attrname,
                anntype,
                ioattrs=ioattrs,
                recursion_level=recursion_level + 1,
            )

        # Success! Store our resolved stuff with the class and we're done.
        prepdata = PrepData(
            annotations=resolved_annotations,
            storage_names_to_attr_names=storage_names_to_attr_names,
        )
        setattr(cls, PREP_ATTR, prepdata)

        # Clear our prep-session tag.
        assert getattr(cls, PREP_SESSION_ATTR, None) is self
        delattr(cls, PREP_SESSION_ATTR)
        return prepdata

    def prep_type(
        self,
        cls: type,
        attrname: str,
        anntype: Any,
        ioattrs: IOAttrs | None,
        recursion_level: int,
    ) -> None:
        """Run prep on a single annotation type."""
        # pylint: disable=too-many-return-statements
        # pylint: disable=too-many-branches
        # pylint: disable=too-many-statements

        if recursion_level > MAX_RECURSION:
            raise RuntimeError('Max recursion exceeded.')

        origin = _get_origin(anntype)

        if origin is typing.Union or origin is types.UnionType:
            self.prep_union(
                cls, attrname, anntype, recursion_level=recursion_level + 1
            )
            return

        if anntype is typing.Any:
            return

        # Everything below this point assumes the annotation type resolves
        # to a concrete type.
        if not isinstance(origin, type):
            raise TypeError(
                f'Unsupported type found for \'{attrname}\' on {cls}:'
                f' {anntype}'
            )

        # If a soft_default value/factory was passed, we do some basic
        # type checking on the top-level value here. We also run full
        # recursive validation on values later during inputting, but this
        # should catch at least some errors early on, which can be
        # useful since soft_defaults are not static type checked.
        if ioattrs is not None:
            have_soft_default = False
            soft_default: Any = None
            if ioattrs.soft_default is not ioattrs.MISSING:
                have_soft_default = True
                soft_default = ioattrs.soft_default
            elif ioattrs.soft_default_factory is not ioattrs.MISSING:
                assert callable(ioattrs.soft_default_factory)
                have_soft_default = True
                soft_default = ioattrs.soft_default_factory()

            # Do a simple type check for the top level to catch basic
            # soft_default mismatches early; full check will happen at
            # input time.
            if have_soft_default:
                if not isinstance(soft_default, origin):
                    raise TypeError(
                        f'{cls} attr {attrname} has type {origin}'
                        f' but soft_default value is type {type(soft_default)}'
                    )

        if origin in SIMPLE_TYPES:
            return

        # For sets and lists, check out their single contained type (if any).
        if origin in (list, set):
            childtypes = typing.get_args(anntype)
            if len(childtypes) == 0:
                # This is equivalent to Any; nothing else needs checking.
                return
            if len(childtypes) > 1:
                raise TypeError(
                    f'Unrecognized typing arg count {len(childtypes)}'
                    f" for {anntype} attr '{attrname}' on {cls}"
                )
            self.prep_type(
                cls,
                attrname,
                childtypes[0],
                ioattrs=None,
                recursion_level=recursion_level + 1,
            )
            return

        if origin is dict:
            childtypes = typing.get_args(anntype)
            assert len(childtypes) in (0, 2)

            # For key types we support Any, str, int,
            # and Enums with uniform str/int values.
            if not childtypes or childtypes[0] is typing.Any:
                # 'Any' needs no further checks (just checked per-instance).
                pass
            elif childtypes[0] in (str, int):
                # str and int are all good as keys.
                pass
            elif issubclass(childtypes[0], Enum):
                # Allow our usual str or int enum types as keys.
                self.prep_enum(childtypes[0])
            else:
                raise TypeError(
                    f'Dict key type {childtypes[0]} for \'{attrname}\''
                    f' on {cls.__name__} is not supported by dataclassio.'
                )

            # For value types we support any of our normal types.
            if not childtypes or _get_origin(childtypes[1]) is typing.Any:
                # 'Any' needs no further checks (just checked per-instance).
                pass
            else:
                self.prep_type(
                    cls,
                    attrname,
                    childtypes[1],
                    ioattrs=None,
                    recursion_level=recursion_level + 1,
                )
            return

        # For Tuples, simply check individual member types.
        # (and, for now, explicitly disallow zero member types or usage
        # of ellipsis)
        if origin is tuple:
            childtypes = typing.get_args(anntype)
            if not childtypes:
                raise TypeError(
                    f'Tuple at \'{attrname}\''
                    f' has no type args; dataclassio requires type args.'
                )
            if childtypes[-1] is ...:
                raise TypeError(
                    f'Found ellipsis as part of type for'
                    f' \'{attrname}\' on {cls.__name__};'
                    f' these are not supported by dataclassio.'
                )
            for childtype in childtypes:
                self.prep_type(
                    cls,
                    attrname,
                    childtype,
                    ioattrs=None,
                    recursion_level=recursion_level + 1,
                )
            return

        if issubclass(origin, Enum):
            self.prep_enum(origin)
            return

        # We allow datetime objects (and google's extended subclass of them
        # used in firestore, which is why we don't look for exact type here).
        if issubclass(origin, datetime.datetime):
            return

        if dataclasses.is_dataclass(origin):
            self.prep_dataclass(origin, recursion_level=recursion_level + 1)
            return

        if origin is bytes:
            return

        raise TypeError(
            f"Attr '{attrname}' on {cls.__name__} contains"
            f" type '{anntype}'"
            f' which is not supported by dataclassio.'
        )

    def prep_union(
        self, cls: type, attrname: str, anntype: Any, recursion_level: int
    ) -> None:
        """Run prep on a Union type."""
        typeargs = typing.get_args(anntype)
        if (
            len(typeargs) != 2
            or len([c for c in typeargs if c is type(None)]) != 1
        ):  # noqa
            raise TypeError(
                f'Union {anntype} for attr \'{attrname}\' on'
                f' {cls.__name__} is not supported by dataclassio;'
                f' only 2 member Unions with one type being None'
                f' are supported.'
            )
        for childtype in typeargs:
            self.prep_type(
                cls,
                attrname,
                childtype,
                None,
                recursion_level=recursion_level + 1,
            )

    def prep_enum(self, enumtype: type[Enum]) -> None:
        """Run prep on an enum type."""

        valtype: Any = None

        # We currently support enums with str or int values; fail if we
        # find any others.
        for enumval in enumtype:
            if not isinstance(enumval.value, (str, int)):
                raise TypeError(
                    f'Enum value {enumval} has value type'
                    f' {type(enumval.value)}; only str and int are'
                    f' supported by dataclassio.'
                )
            if valtype is None:
                valtype = type(enumval.value)
            else:
                if type(enumval.value) is not valtype:
                    raise TypeError(
                        f'Enum type {enumtype} has multiple'
                        f' value types; dataclassio requires'
                        f' them to be uniform.'
                    )
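A minimal sketch (not from the commit) of why explicit prepping is encouraged: with the @ioprepped decorator, an unsupported annotation fails at class-definition time rather than at first serialization.

from dataclasses import dataclass

from efro.dataclassio import ioprepped

try:

    @ioprepped
    @dataclass
    class Broken:
        # complex is not a type dataclassio supports, so prep raises here.
        val: complex = 0j

except TypeError as exc:
    print(f'caught at definition time: {exc}')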
71
dist/ba_data/python/efro/dataclassio/extras.py
vendored
Normal file
@ -0,0 +1,71 @@
# Released under the MIT License. See LICENSE for details.
#
"""Extra rarely-needed functionality related to dataclasses."""

from __future__ import annotations

import dataclasses
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing import Any


def dataclass_diff(obj1: Any, obj2: Any) -> str:
    """Generate a string showing differences between two dataclass instances.

    Both must be of the exact same type.
    """
    diff = _diff(obj1, obj2, 2)
    return ' <no differences>' if diff == '' else diff


class DataclassDiff:
    """Wraps dataclass_diff() in an object for efficiency.

    It is preferable to pass this to logging calls instead of the
    final diff string since the diff will never be generated if
    the associated logging level is not being emitted.
    """

    def __init__(self, obj1: Any, obj2: Any):
        self._obj1 = obj1
        self._obj2 = obj2

    def __repr__(self) -> str:
        return dataclass_diff(self._obj1, self._obj2)


def _diff(obj1: Any, obj2: Any, indent: int) -> str:
    assert dataclasses.is_dataclass(obj1)
    assert dataclasses.is_dataclass(obj2)
    if type(obj1) is not type(obj2):
        raise TypeError(
            f'Passed objects are not of the same'
            f' type ({type(obj1)} and {type(obj2)}).'
        )
    bits: list[str] = []
    indentstr = ' ' * indent
    fields = dataclasses.fields(obj1)
    for field in fields:
        fieldname = field.name
        val1 = getattr(obj1, fieldname)
        val2 = getattr(obj2, fieldname)

        # For nested dataclasses, dive in and do nice piecewise compares.
        if (
            dataclasses.is_dataclass(val1)
            and dataclasses.is_dataclass(val2)
            and type(val1) is type(val2)
        ):
            diff = _diff(val1, val2, indent + 2)
            if diff != '':
                bits.append(f'{indentstr}{fieldname}:')
                bits.append(diff)

        # For all else just do a single line
        # (perhaps we could improve on this for other complex types)
        else:
            if val1 != val2:
                bits.append(f'{indentstr}{fieldname}: {val1} -> {val2}')
    return '\n'.join(bits)
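Example usage (a sketch): DataclassDiff defers building the diff string until a logger actually emits the record, which is why it is preferred over calling dataclass_diff() directly inside logging calls.

import logging
from dataclasses import dataclass

from efro.dataclassio.extras import DataclassDiff, dataclass_diff


@dataclass
class Settings:
    volume: float = 1.0
    name: str = 'default'


old = Settings()
new = Settings(volume=0.5)
print(dataclass_diff(old, new))  # '  volume: 1.0 -> 0.5'

# The diff below is only computed if DEBUG logging is actually enabled.
logging.debug('settings changed:\n%s', DataclassDiff(old, new))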
349
dist/ba_data/python/efro/debug.py
vendored
Normal file
@ -0,0 +1,349 @@
# Released under the MIT License. See LICENSE for details.
|
||||
#
|
||||
"""Utilities for debugging memory leaks or other issues.
|
||||
|
||||
IMPORTANT - these functions use the gc module which looks 'under the hood'
|
||||
at Python and sometimes returns not-fully-initialized objects, which may
|
||||
cause crashes or errors due to suddenly having references to them that they
|
||||
didn't expect, etc. See https://github.com/python/cpython/issues/59313.
|
||||
For this reason, these methods should NEVER be called in production code.
|
||||
Enable them only for debugging situations and be aware that their use may
|
||||
itself cause problems. The same is true for the gc module itself.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import gc
|
||||
import sys
|
||||
import types
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, TextIO
|
||||
|
||||
ABS_MAX_LEVEL = 10
|
||||
|
||||
# NOTE: In general we want this toolset to allow us to explore
|
||||
# which objects are holding references to others so we can diagnose
|
||||
# leaks/etc. It is a bit tricky to do that, however, without
|
||||
# affecting the objects we are looking at by adding temporary references
|
||||
# from module dicts, function scopes, etc. So we need to try to be
|
||||
# careful about cleaning up after ourselves and explicitly avoiding
|
||||
# returning these temporary references wherever possible.
|
||||
|
||||
# A good test is running printrefs() repeatedly on some object that is
|
||||
# known to be static. If the list of references or the ids or any
|
||||
# the listed references changes with each run, it's a good sign that
|
||||
# we're showing some temporary objects that we should be ignoring.
|
||||
|
||||
|
||||
def getobjs(
|
||||
cls: type | str, contains: str | None = None, expanded: bool = False
|
||||
) -> list[Any]:
|
||||
"""Return all garbage-collected objects matching criteria.
|
||||
|
||||
'type' can be an actual type or a string in which case objects
|
||||
whose types contain that string will be returned.
|
||||
|
||||
If 'contains' is provided, objects will be filtered to those
|
||||
containing that in their str() representations.
|
||||
"""
|
||||
|
||||
# Don't wanna return stuff waiting to be garbage-collected.
|
||||
gc.collect()
|
||||
|
||||
if not isinstance(cls, type | str):
|
||||
raise TypeError('Expected a type or string for cls')
|
||||
if not isinstance(contains, str | None):
|
||||
raise TypeError('Expected a string or None for contains')
|
||||
|
||||
allobjs = _get_all_objects(expanded=expanded)
|
||||
|
||||
if isinstance(cls, str):
|
||||
objs = [o for o in allobjs if cls in str(type(o))]
|
||||
else:
|
||||
objs = [o for o in allobjs if isinstance(o, cls)]
|
||||
if contains is not None:
|
||||
objs = [o for o in objs if contains in str(o)]
|
||||
|
||||
return objs
|
||||
|
||||
|
||||
# Recursively expand slists objects into olist, using seen to track
|
||||
# already processed objects.
|
||||
def _getr(slist: list[Any], olist: list[Any], seen: set[int]) -> None:
|
||||
for obj in slist:
|
||||
if id(obj) in seen:
|
||||
continue
|
||||
seen.add(id(obj))
|
||||
olist.append(obj)
|
||||
tll = gc.get_referents(obj)
|
||||
if tll:
|
||||
_getr(tll, olist, seen)
|
||||
|
||||
|
||||
def _get_all_objects(expanded: bool) -> list[Any]:
|
||||
"""Return an expanded list of all objects.
|
||||
|
||||
See https://utcc.utoronto.ca/~cks/space/blog/python/GetAllObjects
|
||||
"""
|
||||
gcl = gc.get_objects()
|
||||
if not expanded:
|
||||
return gcl
|
||||
olist: list[Any] = []
|
||||
seen: set[int] = set()
|
||||
# Just in case:
|
||||
seen.add(id(gcl))
|
||||
seen.add(id(olist))
|
||||
seen.add(id(seen))
|
||||
# _getr does the real work.
|
||||
_getr(gcl, olist, seen)
|
||||
return olist
|
||||
|
||||
|
||||
def getobj(objid: int, expanded: bool = False) -> Any:
|
||||
"""Return a garbage-collected object by its id.
|
||||
|
||||
Remember that this is VERY inefficient and should only ever be used
|
||||
for debugging.
|
||||
"""
|
||||
if not isinstance(objid, int):
|
||||
raise TypeError(f'Expected an int for objid; got a {type(objid)}.')
|
||||
|
||||
# Don't wanna return stuff waiting to be garbage-collected.
|
||||
gc.collect()
|
||||
|
||||
allobjs = _get_all_objects(expanded=expanded)
|
||||
for obj in allobjs:
|
||||
if id(obj) == objid:
|
||||
return obj
|
||||
raise RuntimeError(f'Object with id {objid} not found.')
|
||||
|
||||
|
||||
def getrefs(obj: Any) -> list[Any]:
|
||||
"""Given an object, return things referencing it."""
|
||||
v = vars() # Ignore ref coming from locals.
|
||||
return [o for o in gc.get_referrers(obj) if o is not v]
|
||||
|
||||
|
||||
def printfiles(file: TextIO | None = None) -> None:
|
||||
"""Print info about open files in the current app."""
|
||||
import io
|
||||
|
||||
file = sys.stderr if file is None else file
|
||||
try:
|
||||
import psutil
|
||||
except ImportError:
|
||||
print(
|
||||
"Error: printfiles requires the 'psutil' module to be installed.",
|
||||
file=file,
|
||||
)
|
||||
return
|
||||
|
||||
proc = psutil.Process()
|
||||
|
||||
# Let's grab all Python file handles so we can associate raw files
|
||||
# with their Python objects when possible.
|
||||
fileio_ids = {obj.fileno(): obj for obj in getobjs(io.FileIO)}
|
||||
textio_ids = {obj.fileno(): obj for obj in getobjs(io.TextIOWrapper)}
|
||||
|
||||
# FIXME: we could do a more limited version of this when psutil is
|
||||
# not present that simply includes Python's files.
|
||||
print('Files open by this app (not limited to Python\'s):', file=file)
|
||||
for i, ofile in enumerate(proc.open_files()):
|
||||
# Mypy doesn't know about mode apparently.
|
||||
# (and can't use type: ignore because we don't require psutil
|
||||
# and then mypy complains about unused ignore comment when its
|
||||
# not present)
|
||||
mode = getattr(ofile, 'mode')
|
||||
assert isinstance(mode, str)
|
||||
textio = textio_ids.get(ofile.fd)
|
||||
textio_s = id(textio) if textio is not None else '<not found>'
|
||||
fileio = fileio_ids.get(ofile.fd)
|
||||
fileio_s = id(fileio) if fileio is not None else '<not found>'
|
||||
print(
|
||||
f'#{i+1}: path={ofile.path!r},'
|
||||
f' fd={ofile.fd}, mode={mode!r}, TextIOWrapper={textio_s},'
|
||||
f' FileIO={fileio_s}'
|
||||
)
|
||||
|
||||
|
||||
def printrefs(
|
||||
obj: Any,
|
||||
max_level: int = 2,
|
||||
exclude_objs: list[Any] | None = None,
|
||||
expand_ids: list[int] | None = None,
|
||||
file: TextIO | None = None,
|
||||
) -> None:
|
||||
"""Print human readable list of objects referring to an object.
|
||||
|
||||
'max_level' specifies how many levels of recursion are printed.
|
||||
'exclude_objs' can be a list of exact objects to skip if found in the
|
||||
referrers list. This can be useful to avoid printing the local context
|
||||
where the object was passed in from (locals(), etc).
|
||||
'expand_ids' can be a list of object ids; if that particular object is
|
||||
found, it will always be expanded even if max_level has been reached.
|
||||
"""
|
||||
_printrefs(
|
||||
obj,
|
||||
level=0,
|
||||
max_level=max_level,
|
||||
exclude_objs=[] if exclude_objs is None else exclude_objs,
|
||||
expand_ids=[] if expand_ids is None else expand_ids,
|
||||
file=sys.stderr if file is None else file,
|
||||
)
|
||||
|
||||
|
||||
def printtypes(
|
||||
limit: int = 50, file: TextIO | None = None, expanded: bool = False
|
||||
) -> None:
|
||||
"""Print a human readable list of which types have the most instances."""
|
||||
assert limit > 0
|
||||
objtypes: dict[str, int] = {}
|
||||
gc.collect() # Recommended before get_objects().
|
||||
allobjs = _get_all_objects(expanded=expanded)
|
||||
allobjc = len(allobjs)
|
||||
for obj in allobjs:
|
||||
modname = type(obj).__module__
|
||||
tpname = type(obj).__qualname__
|
||||
if modname != 'builtins':
|
||||
tpname = f'{modname}.{tpname}'
|
||||
objtypes[tpname] = objtypes.get(tpname, 0) + 1
|
||||
|
||||
# Presumably allobjs contains stack-frame/dict type stuff
|
||||
# from this function call which in turn contain refs to allobjs.
|
||||
# Let's try to prevent these huge lists from accumulating until
|
||||
# the cyclical collector (hopefully) gets to them.
|
||||
allobjs.clear()
|
||||
del allobjs
|
||||
|
||||
print(f'Types most allocated ({allobjc} total objects):', file=file)
|
||||
for i, tpitem in enumerate(
|
||||
sorted(objtypes.items(), key=lambda x: x[1], reverse=True)[:limit]
|
||||
):
|
||||
tpname, tpval = tpitem
|
||||
percent = tpval / allobjc * 100.0
|
||||
print(f'{i+1}: {tpname}: {tpval} ({percent:.2f}%)', file=file)
|
||||
|
||||
|
||||
def printsizes(
    limit: int = 50, file: TextIO | None = None, expanded: bool = False
) -> None:
    """Print total allocated sizes of different types."""
    assert limit > 0
    objsizes: dict[str, int] = {}
    gc.collect()  # Recommended before get_objects().
    allobjs = _get_all_objects(expanded=expanded)
    totalobjsize = 0

    for obj in allobjs:
        modname = type(obj).__module__
        tpname = type(obj).__qualname__
        if modname != 'builtins':
            tpname = f'{modname}.{tpname}'
        objsize = sys.getsizeof(obj)
        objsizes[tpname] = objsizes.get(tpname, 0) + objsize
        totalobjsize += objsize

    totalobjmb = totalobjsize / (1024 * 1024)
    print(
        f'Types with most allocated bytes ({totalobjmb:.2f} mb total):',
        file=file,
    )
    for i, tpitem in enumerate(
        sorted(objsizes.items(), key=lambda x: x[1], reverse=True)[:limit]
    ):
        tpname, tpval = tpitem
        percent = tpval / totalobjsize * 100.0
        print(f'{i+1}: {tpname}: {tpval} ({percent:.2f}%)', file=file)


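# Usage sketch: print the 20 most numerous types and the 20 largest
# aggregate allocations (output goes to stdout unless 'file' is given):
#
#     printtypes(limit=20)
#     printsizes(limit=20)

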
def _desctype(obj: Any) -> str:
    cls = type(obj)
    if cls is types.ModuleType:
        return f'{type(obj).__name__} {obj.__name__}'
    if cls is types.MethodType:
        bnd = 'bound' if hasattr(obj, '__self__') else 'unbound'
        return f'{bnd} {type(obj).__name__} {obj.__name__}'
    return f'{type(obj).__name__}'


def _desc(obj: Any) -> str:
    extra: str | None = None
    if isinstance(obj, list | tuple):
        # Print length and the first few types.
        tps = [_desctype(i) for i in obj[:3]]
        tpsj = ', '.join(tps)
        tpss = (
            f', contains [{tpsj}, ...]'
            if len(obj) > 3
            else f', contains [{tpsj}]' if tps else ''
        )
        extra = f' (len {len(obj)}{tpss})'
    elif isinstance(obj, dict):
        # If this seems to be the vars() dict for a type or module,
        # try to identify which one.
        for ref in getrefs(obj):
            if hasattr(ref, '__dict__') and vars(ref) is obj:
                extra = f' (vars for {_desctype(ref)} @ {id(ref)})'

        # Generic dict: print length and the first few key:type pairs.
        if extra is None:
            pairs = [
                f'{repr(n)}: {_desctype(v)}' for n, v in list(obj.items())[:3]
            ]
            pairsj = ', '.join(pairs)
            pairss = (
                f', contains {{{pairsj}, ...}}'
                if len(obj) > 3
                else f', contains {{{pairsj}}}' if pairs else ''
            )
            extra = f' (len {len(obj)}{pairss})'
    if extra is None:
        extra = ''
    return f'{_desctype(obj)} @ {id(obj)}{extra}'


def _printrefs(
    obj: Any,
    level: int,
    max_level: int,
    exclude_objs: list,
    expand_ids: list[int],
    file: TextIO,
) -> None:
    ind = ' ' * level
    print(ind + _desc(obj), file=file)
    v = vars()
    if level < max_level or (id(obj) in expand_ids and level < ABS_MAX_LEVEL):
        refs = getrefs(obj)
        for ref in refs:

            # It seems we tend to get a transient cell object with contents
            # set to obj. Would be nice to understand why that happens,
            # but we just ignore it for now.
            if isinstance(ref, types.CellType) and ref.cell_contents is obj:
                continue

            # Ignore anything we were asked to ignore.
            if any(ref is eobj for eobj in exclude_objs):
                continue

            # Ignore references from our locals.
            if ref is v:
                continue

            # The 'refs' list we just made will be listed as a referrer
            # of this obj, so explicitly exclude it from the obj's listing.
            _printrefs(
                ref,
                level=level + 1,
                max_level=max_level,
                exclude_objs=exclude_objs + [refs],
                expand_ids=expand_ids,
                file=file,
            )

36
dist/ba_data/python/efro/entity/__init__.py
vendored
Normal file
@@ -0,0 +1,36 @@
# Released under the MIT License. See LICENSE for details.
#
"""Entity functionality.

A system for defining structured data, supporting both static and runtime
type safety, serialization, efficient/sparse storage, per-field value
limits, etc. This is a heavyweight option in comparison to things such as
dataclasses, but the increased features can make the overhead worth it for
certain use cases.

Advantages compared to nested dataclasses:
- Field names are separated from their data representation, so json data
  can be more concise and variable names can change while preserving
  backward compatibility.
- Can wrap and preserve unmapped data (so fields can be added to new versions
  of something without breaking old versions' ability to read the data).
- Incorrectly typed data is caught at runtime (for dataclasses we rely on
  type-checking and explicit validation calls).

Disadvantages compared to nested dataclasses:
- More complex to use.
- Significantly more heavyweight (roughly 10 times slower in quick tests).
- Can't currently be initialized in constructors (this would probably require
  a Mypy plugin to do in a type-safe way).
"""
# pylint: disable=unused-import

from efro.entity._entity import EntityMixin, Entity
from efro.entity._field import (Field, CompoundField, ListField, DictField,
                                CompoundListField, CompoundDictField)
from efro.entity._value import (
    EnumValue, OptionalEnumValue, IntValue, OptionalIntValue, StringValue,
    OptionalStringValue, BoolValue, OptionalBoolValue, FloatValue,
    OptionalFloatValue, DateTimeValue, OptionalDateTimeValue, Float3Value,
    CompoundValue)

from efro.entity._support import FieldInspector

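# Usage sketch (hypothetical field names and defaults; a minimal example
# of the system described above):
#
#     class Player(Entity):
#         name = Field('n', StringValue('unknown'))
#         score = Field('s', IntValue(0))
#
#     p = Player()
#     p.score = 10          # Type-checked statically and at runtime.
#     s = p.to_json_str()   # Serialized under the short keys 'n'/'s'.
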
133
dist/ba_data/python/efro/entity/_base.py
vendored
Normal file
@@ -0,0 +1,133 @@
# Released under the MIT License. See LICENSE for details.
#
"""Base classes for the entity system."""

from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING

from efro.util import enum_by_value

if TYPE_CHECKING:
    from typing import Any, Type


def dict_key_to_raw(key: Any, keytype: Type) -> Any:
    """Given a key value from the world, filter to the stored key."""
    if not isinstance(key, keytype):
        raise TypeError(
            f'Invalid key type; expected {keytype}, got {type(key)}.')
    if issubclass(keytype, Enum):
        val = key.value
        # We convert int enums to strings since that is what firestore
        # supports.
        if isinstance(val, int):
            val = str(val)
        return val
    return key


def dict_key_from_raw(key: Any, keytype: Type) -> Any:
    """Given an internal key, filter to the world-visible type."""
    if issubclass(keytype, Enum):
        # We store all enum keys as strings; if the enum uses
        # int values, convert back.
        for enumval in keytype:
            if isinstance(enumval.value, int):
                return enum_by_value(keytype, int(key))
            break
        return enum_by_value(keytype, key)
    return key


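# Round-trip sketch (hypothetical enum; illustrates the int-to-string
# conversion these helpers apply for storage):
#
#     class Slot(Enum):
#         HEAD = 0
#
#     raw = dict_key_to_raw(Slot.HEAD, Slot)  # -> '0' (str; firestore-safe)
#     key = dict_key_from_raw(raw, Slot)      # -> Slot.HEAD

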
class DataHandler:
    """Base class for anything that can wrangle entity data.

    This contains common functionality shared by Fields and Values.
    """

    def get_default_data(self) -> Any:
        """Return the default internal data value for this object.

        This will be inserted when initing nonexistent entity data.
        """
        raise RuntimeError(f'get_default_data() unimplemented for {self}')

    def filter_input(self, data: Any, error: bool) -> Any:
        """Given arbitrary input data, return valid internal data.

        If error is True, exceptions should be thrown for any non-trivial
        mismatch (more than just int vs float/etc.). Otherwise the invalid
        data should be replaced with valid defaults and the problem noted
        via the logging module.
        The passed-in data can be modified in-place or returned as-is, or
        completely new data can be returned. Compound types are responsible
        for setting defaults and/or calling this recursively for their
        children. Data that is not used by the field (such as orphaned values
        in a dict field) can be left alone.

        Supported types for internal data are:
        - anything that works with json (lists, dicts, bools, floats, ints,
          strings, None) - no tuples!
        - datetime.datetime objects
        """
        del error  # Unused.
        return data

    def filter_output(self, data: Any) -> Any:
        """Given valid internal data, return user-facing data.

        Note that entity data is expected to be filtered to correctness on
        input, so when internal and user-facing data are the same type this
        can be a no-op. Value types such as Vec3 may store data internally
        as simple float tuples but return Vec3 objects to the user/etc.;
        this is the mechanism by which they do so.
        """
        return data

    def prune_data(self, data: Any) -> bool:
        """Prune internal data to strip out default values/etc.

        Should return a bool indicating whether root data itself can be
        pruned. The object is responsible for pruning any sub-fields
        before returning.
        """


class BaseField(DataHandler):
    """Base class for all field types."""

    def __init__(self, d_key: str | None = None) -> None:

        # Key for this field's data in its parent dict/list (when
        # applicable; some fields, such as the child field under a list
        # field, represent more than a single field entry, so this is
        # unused there).
        self.d_key = d_key

    # IMPORTANT: this method should only be overridden in the eyes of the
    # type-checker (to specify exact return types). Subclasses should instead
    # override get_with_data() for doing the actual work, since that method
    # may sometimes be called explicitly instead of through __get__.
    def __get__(self, obj: Any, type_in: Any = None) -> Any:
        if obj is None:
            # When called on the type, we return the field itself.
            return self
        return self.get_with_data(obj.d_data)

    # IMPORTANT: same deal as __get__() (see note above).
    def __set__(self, obj: Any, value: Any) -> None:
        assert obj is not None
        self.set_with_data(obj.d_data, value, error=True)

    def get_with_data(self, data: Any) -> Any:
        """Get the field value given an explicit data source."""
        assert self.d_key is not None
        return self.filter_output(data[self.d_key])

    def set_with_data(self, data: Any, value: Any, error: bool) -> Any:
        """Set the field value given an explicit data target.

        If error is True, exceptions should be thrown for invalid data;
        otherwise the problem should be logged but corrected.
        """
        assert self.d_key is not None
        data[self.d_key] = self.filter_input(value, error=error)

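# Data-flow sketch (hypothetical concrete field; not part of this module).
# A field is a descriptor that reads and writes its entity's d_data dict
# under its d_key, filtering values on the way in and out:
#
#     field = SomeConcreteField(d_key='hp')       # BaseField subclass
#     data: dict = {}
#     field.set_with_data(data, 100, error=True)  # data == {'hp': 100}
#     value = field.get_with_data(data)           # -> filtered output
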
222
dist/ba_data/python/efro/entity/_entity.py
vendored
Normal file
@@ -0,0 +1,222 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for the actual Entity types."""

from __future__ import annotations

import json
from typing import TYPE_CHECKING, TypeVar

from efro.entity._support import FieldInspector, BoundCompoundValue
from efro.entity._value import CompoundValue
from efro.json import ExtendedJSONEncoder, ExtendedJSONDecoder

if TYPE_CHECKING:
    from typing import Dict, Any, Type, Union, Optional

T = TypeVar('T', bound='EntityMixin')


class EntityMixin:
    """Mixin class to add data-storage to CompoundValue, forming an Entity.

    Distinct Entity types should inherit from this first and a CompoundValue
    (sub)type second. This order ensures that constructor arguments for this
    class are accessible on the new type.
    """

    def __init__(self,
                 d_data: Optional[Dict[str, Any]] = None,
                 error: bool = True) -> None:
        super().__init__()
        if not isinstance(self, CompoundValue):
            raise RuntimeError('EntityMixin class must be combined'
                               ' with a CompoundValue class.')

        # Underlying data for this entity; fields simply operate on this.
        self.d_data: Dict[str, Any] = {}
        assert isinstance(self, EntityMixin)
        self.set_data(d_data if d_data is not None else {}, error=error)

    def reset(self) -> None:
        """Reset data to defaults."""
        self.set_data({}, error=True)

    def set_data(self, data: Dict, error: bool = True) -> None:
        """Set the data for this entity and apply all value filters to it.

        Note that it is more efficient to pass data to an Entity's
        constructor than it is to create a default Entity and then call
        this on it.
        """
        assert isinstance(self, CompoundValue)
        self.d_data = self.filter_input(data, error=error)

    def copy_data(self, target: Union[CompoundValue,
                                      BoundCompoundValue]) -> None:
        """Copy data from a target Entity or compound-value.

        This first verifies that the target has a matching set of fields
        and then copies its data into ourself. To copy data into a nested
        compound field, the assignment operator can be used.
        """
        import copy
        from efro.entity.util import have_matching_fields
        tvalue: CompoundValue
        if isinstance(target, CompoundValue):
            tvalue = target
        elif isinstance(target, BoundCompoundValue):
            tvalue = target.d_value
        else:
            raise TypeError(
                'Target must be a CompoundValue or BoundCompoundValue.')
        target_data = getattr(target, 'd_data', None)
        if target_data is None:
            raise ValueError('Target is not bound to data.')
        assert isinstance(self, CompoundValue)
        if not have_matching_fields(self, tvalue):
            raise ValueError(
                f'Fields for target {type(tvalue)} do not match ours'
                f" ({type(self)}); can't copy data.")
        self.d_data = copy.deepcopy(target_data)

    def steal_data(self, target: EntityMixin) -> None:
        """Steal data from another entity.

        This is more efficient than copy_data, as data is moved instead
        of copied. However this leaves the target object in an invalid
        state, and it must no longer be used after this call.
        This can be convenient for entities to use to update themselves
        with the result of a database transaction (which generally returns
        fresh entities).
        """
        from efro.entity.util import have_matching_fields
        if not isinstance(target, EntityMixin):
            raise TypeError('EntityMixin is required.')
        assert isinstance(target, CompoundValue)
        assert isinstance(self, CompoundValue)
        if not have_matching_fields(self, target):
            raise ValueError(
                f'Fields for target {type(target)} do not match ours'
                f" ({type(self)}); can't steal data.")
        assert target.d_data is not None
        self.d_data = target.d_data

        # Make sure the target blows up if someone tries to use it.
        # noinspection PyTypeHints
        target.d_data = None  # type: ignore

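    # Usage sketch (hypothetical entity type; illustrates the copy/steal
    # distinction documented above):
    #
    #     a, b = MyEntity(), MyEntity()
    #     a.copy_data(b)    # 'a' gets a deep copy; 'b' stays usable.
    #     a.steal_data(b)   # 'a' takes b's data; 'b' must be discarded.
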
    def pruned_data(self) -> Dict[str, Any]:
        """Return a pruned version of this instance's data.

        This varies from d_data in that values may be stripped out if
        they are equal to defaults (for fields with that option enabled).
        """
        import copy
        data = copy.deepcopy(self.d_data)
        assert isinstance(self, CompoundValue)
        self.prune_fields_data(data)
        return data

    def to_json_str(self,
                    prune: bool = True,
                    pretty: bool = False,
                    sort_keys_override: Optional[bool] = None) -> str:
        """Convert the entity to a json string.

        This uses efro.json.ExtendedJSONEncoder/Decoder
        to support data types not natively storable in json.
        Be sure to use the corresponding loading functions here for
        this same reason.
        By default, keys are sorted when pretty-printing and not otherwise,
        but this can be overridden by passing a bool as sort_keys_override.
        """
        if prune:
            data = self.pruned_data()
        else:
            data = self.d_data
        if pretty:
            return json.dumps(
                data,
                indent=2,
                sort_keys=(sort_keys_override
                           if sort_keys_override is not None else True),
                cls=ExtendedJSONEncoder)

        # When not doing pretty, go for quick and compact.
        return json.dumps(data,
                          separators=(',', ':'),
                          sort_keys=(sort_keys_override if sort_keys_override
                                     is not None else False),
                          cls=ExtendedJSONEncoder)

    @staticmethod
    def json_loads(s: Union[str, bytes]) -> Any:
        """Load a json string using our special extended decoder.

        Note that this simply returns loaded json data; no
        Entities are involved.
        """
        return json.loads(s, cls=ExtendedJSONDecoder)

    def load_from_json_str(self,
                           s: Union[str, bytes],
                           error: bool = True) -> None:
        """Set the entity's data in-place from a json string.

        The 'error' argument determines whether exceptions will be raised
        for invalid data values. Values will be reset/conformed to valid
        ones if error is False. Note that exceptions will always be raised
        in the case of invalidly formatted json.
        """
        data = self.json_loads(s)
        self.set_data(data, error=error)

    @classmethod
    def from_json_str(cls: Type[T],
                      s: Union[str, bytes],
                      error: bool = True) -> T:
        """Instantiate a new instance with the provided json string.

        The 'error' argument determines whether exceptions will be raised
        on invalid data values. Values will be reset/conformed to valid
        ones if error is False. Note that exceptions will always be raised
        in the case of invalidly formatted json.
        """
        obj = cls(d_data=cls.json_loads(s), error=error)
        return obj

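    # Round-trip sketch (hypothetical entity type):
    #
    #     s = ent.to_json_str()              # Compact; pruned by default.
    #     clone = MyEntity.from_json_str(s)  # Fresh instance, same data.
    #     ent.load_from_json_str(s)          # Or update in place.
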
    # Note: though d_fields actually returns a FieldInspector,
    # in type-checking-land we currently just say it returns self.
    # This allows the type-checker to at least validate subfield access,
    # though the types will be incorrect (values instead of inspectors).
    # This means that anything taking FieldInspectors needs to take 'Any'
    # at the moment. Hopefully we can make this cleaner via a mypy
    # plugin at some point.
    if TYPE_CHECKING:

        @property
        def d_fields(self: T) -> T:
            """For accessing entity field objects (as opposed to values)."""
            ...
    else:

        @property
        def d_fields(self):
            """For accessing entity field objects (as opposed to values)."""
            return FieldInspector(self, self, [], [])


class Entity(EntityMixin, CompoundValue):
    """A data class consisting of Fields and their underlying data.

    Fields and Values simply define a data layout; Entities are concrete
    objects using those layouts.

    Inherit from this class and add Fields to define a simple Entity type.
    Alternately, combine an EntityMixin with any CompoundValue child class
    to accomplish the same. The latter allows sharing CompoundValue
    layouts between different concrete Entity types. For example, a
    'Weapon' CompoundValue could be embedded as part of a 'Character'
    Entity but also exist as a distinct 'WeaponEntity' in an armory
    database.
    """

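# Layout-sharing sketch (field definitions hypothetical; the Weapon /
# Character naming follows the docstring example above):
#
#     class Weapon(CompoundValue):
#         damage = Field('d', IntValue(10))
#
#     class Character(Entity):
#         weapon = CompoundField('w', Weapon())
#
#     class WeaponEntity(EntityMixin, Weapon):
#         pass  # Same layout as Weapon, stored as its own entity.
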
602
dist/ba_data/python/efro/entity/_field.py
vendored
Normal file
@@ -0,0 +1,602 @@
# Released under the MIT License. See LICENSE for details.
#
"""Field types for the entity system."""

from __future__ import annotations

import copy
import logging
from enum import Enum
from typing import TYPE_CHECKING, Generic, TypeVar, overload

from efro.entity._base import BaseField, dict_key_to_raw, dict_key_from_raw
from efro.entity._support import (BoundCompoundValue, BoundListField,
                                  BoundDictField, BoundCompoundListField,
                                  BoundCompoundDictField)
from efro.entity.util import have_matching_fields

if TYPE_CHECKING:
    from typing import Dict, Type, List, Any
    from efro.entity._value import TypedValue, CompoundValue

T = TypeVar('T')
TK = TypeVar('TK')
TC = TypeVar('TC', bound='CompoundValue')


class Field(BaseField, Generic[T]):
    """Field consisting of a single value."""

    def __init__(self,
                 d_key: str,
                 value: TypedValue[T],
                 store_default: bool = True) -> None:
        super().__init__(d_key)
        self.d_value = value
        self._store_default = store_default

    def __repr__(self) -> str:
        return f'<Field "{self.d_key}" with {self.d_value}>'

    def get_default_data(self) -> Any:
        return self.d_value.get_default_data()

    def filter_input(self, data: Any, error: bool) -> Any:
        return self.d_value.filter_input(data, error)

    def filter_output(self, data: Any) -> Any:
        return self.d_value.filter_output(data)

    def prune_data(self, data: Any) -> bool:
        return self.d_value.prune_data(data)

    if TYPE_CHECKING:
        # Use default runtime get/set but let the type-checker know our
        # types. Note: we actually return a bound-field when accessed on
        # a type instead of an instance, but we don't reflect that here yet
        # (we would need to write a mypy plugin so sub-field access works
        # first).

        @overload
        def __get__(self, obj: None, cls: Any = None) -> Field[T]:
            ...

        @overload
        def __get__(self, obj: Any, cls: Any = None) -> T:
            ...

        def __get__(self, obj: Any, cls: Any = None) -> Any:
            ...

        def __set__(self, obj: Any, value: T) -> None:
            ...


class CompoundField(BaseField, Generic[TC]):
    """Field consisting of a single compound value."""

    def __init__(self,
                 d_key: str,
                 value: TC,
                 store_default: bool = True) -> None:
        super().__init__(d_key)
        if __debug__:
            from efro.entity._value import CompoundValue
            assert isinstance(value, CompoundValue)
            assert not hasattr(value, 'd_data')
        self.d_value = value
        self._store_default = store_default

    def get_default_data(self) -> dict:
        return self.d_value.get_default_data()

    def filter_input(self, data: Any, error: bool) -> dict:
        return self.d_value.filter_input(data, error)

    def prune_data(self, data: Any) -> bool:
        return self.d_value.prune_data(data)

    # Note:
    # Currently, to the type-checker we just return a simple instance
    # of our CompoundValue so it can properly type-check access to its
    # attrs. However at runtime we return a FieldInspector or
    # BoundCompoundField, which both use magic to provide the same attrs
    # dynamically (but which the type-checker doesn't understand).
    # Perhaps at some point we can write a mypy plugin to correct this.
    if TYPE_CHECKING:

        def __get__(self, obj: Any, cls: Any = None) -> TC:
            ...

        # Theoretically this type-checking may be too tight;
        # we could support assigning a parent class to a child class if
        # their fields match. Not sure if that'll ever come up though;
        # gonna leave this for now as I prefer to have *some* checking.
        # Also, once we get BoundCompoundValues working with mypy we'll
        # need to accept those too.
        def __set__(self: CompoundField[TC], obj: Any, value: TC) -> None:
            ...

    def get_with_data(self, data: Any) -> Any:
        assert self.d_key in data
        return BoundCompoundValue(self.d_value, data[self.d_key])

    def set_with_data(self, data: Any, value: Any, error: bool) -> Any:
        from efro.entity._value import CompoundValue

        # Ok here's the deal: our type checking above allows any subtype
        # of our CompoundValue in here, but we want to be more picky than
        # that. Let's check fields for equality. This way we'll allow
        # assigning something like a CarEntity to a Car field
        # (where the data is the same), but won't allow assigning a Car
        # to a Vehicle field (as Car probably adds more fields).
        value1: CompoundValue
        if isinstance(value, BoundCompoundValue):
            value1 = value.d_value
        elif isinstance(value, CompoundValue):
            value1 = value
        else:
            raise ValueError(f"Can't assign from object type {type(value)}")
        dataval = getattr(value, 'd_data', None)
        if dataval is None:
            raise ValueError(f"Can't assign from unbound object {value}")
        if self.d_value.get_fields() != value1.get_fields():
            raise ValueError(f"Can't assign to {self.d_value} from"
                             f' incompatible type {type(value1)};'
                             f' sub-fields do not match.')

        # If we're allowing this to go through, we can simply copy the
        # data from the passed-in value. The fields match, so it should
        # be in a valid state already.
        data[self.d_key] = copy.deepcopy(dataval)


class ListField(BaseField, Generic[T]):
    """Field consisting of repeated values."""

    def __init__(self,
                 d_key: str,
                 value: TypedValue[T],
                 store_default: bool = True) -> None:
        super().__init__(d_key)
        self.d_value = value
        self._store_default = store_default

    def get_default_data(self) -> list:
        return []

    def filter_input(self, data: Any, error: bool) -> Any:

        # If we were passed a BoundListField, operate on its raw values.
        if isinstance(data, BoundListField):
            data = data.d_data

        if not isinstance(data, list):
            if error:
                raise TypeError(f'list value expected; got {type(data)}')
            logging.error('Ignoring non-list data for %s: %s', self, data)
            data = []
        for i, entry in enumerate(data):
            data[i] = self.d_value.filter_input(entry, error=error)
        return data

    def prune_data(self, data: Any) -> bool:
        # We never prune individual values since that would fundamentally
        # change the list, but we can prune completely if empty (and
        # allowed).
        return not data and not self._store_default

    # When accessed on a FieldInspector we return a sub-field
    # FieldInspector. When accessed on an instance we return a
    # BoundListField.

    if TYPE_CHECKING:

        # Access via the type gives our field; via an instance gives a
        # bound field.
        @overload
        def __get__(self, obj: None, cls: Any = None) -> ListField[T]:
            ...

        @overload
        def __get__(self, obj: Any, cls: Any = None) -> BoundListField[T]:
            ...

        def __get__(self, obj: Any, cls: Any = None) -> Any:
            ...

        # Allow setting via a raw value list or a bound list field.
        @overload
        def __set__(self, obj: Any, value: List[T]) -> None:
            ...

        @overload
        def __set__(self, obj: Any, value: BoundListField[T]) -> None:
            ...

        def __set__(self, obj: Any, value: Any) -> None:
            ...

    def get_with_data(self, data: Any) -> Any:
        return BoundListField(self, data[self.d_key])


class DictField(BaseField, Generic[TK, T]):
    """A field of values in a dict with a specified index type."""

    def __init__(self,
                 d_key: str,
                 keytype: Type[TK],
                 field: TypedValue[T],
                 store_default: bool = True) -> None:
        super().__init__(d_key)
        self.d_value = field
        self._store_default = store_default
        self._keytype = keytype

    def get_default_data(self) -> dict:
        return {}

    def filter_input(self, data: Any, error: bool) -> Any:

        # If we were passed a BoundDictField, operate on its raw values.
        if isinstance(data, BoundDictField):
            data = data.d_data

        if not isinstance(data, dict):
            if error:
                raise TypeError('dict value expected')
            logging.error('Ignoring non-dict data for %s: %s', self, data)
            data = {}
        data_out = {}
        for key, val in data.items():

            # For enum keys, make sure it's a valid enum.
            if issubclass(self._keytype, Enum):
                # Our input data can either be an enum or the underlying
                # type.
                if isinstance(key, self._keytype):
                    key = dict_key_to_raw(key, self._keytype)
                else:
                    try:
                        _enumval = dict_key_from_raw(key, self._keytype)
                    except Exception as exc:
                        if error:
                            raise ValueError(
                                f'No enum of type {self._keytype}'
                                f' exists with value {key}') from exc
                        logging.error('Ignoring invalid key type for %s: %s',
                                      self, data)
                        continue

            # For all other keys we can check for exact types.
            elif not isinstance(key, self._keytype):
                if error:
                    raise TypeError(
                        f'Invalid key type; expected {self._keytype},'
                        f' got {type(key)}.')
                logging.error('Ignoring invalid key type for %s: %s', self,
                              data)
                continue

            data_out[key] = self.d_value.filter_input(val, error=error)
        return data_out

    def prune_data(self, data: Any) -> bool:
        # We never prune individual values since that would fundamentally
        # change the dict, but we can prune completely if empty (and
        # allowed).
        return not data and not self._store_default

    if TYPE_CHECKING:

        # Return our field if accessed via the type and a bound-dict-field
        # if via an instance.
        @overload
        def __get__(self, obj: None, cls: Any = None) -> DictField[TK, T]:
            ...

        @overload
        def __get__(self, obj: Any, cls: Any = None) -> BoundDictField[TK, T]:
            ...

        def __get__(self, obj: Any, cls: Any = None) -> Any:
            ...

        # Allow setting via matching dict values or BoundDictFields.
        @overload
        def __set__(self, obj: Any, value: Dict[TK, T]) -> None:
            ...

        @overload
        def __set__(self, obj: Any, value: BoundDictField[TK, T]) -> None:
            ...

        def __set__(self, obj: Any, value: Any) -> None:
            ...

    def get_with_data(self, data: Any) -> Any:
        return BoundDictField(self._keytype, self, data[self.d_key])


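# Usage sketch (hypothetical enum/value types; shows enum-keyed access
# where keys are stored in their raw form):
#
#     class Slot(Enum):
#         HEAD = 'h'
#
#     class Gear(Entity):
#         items = DictField('i', Slot, StringValue())
#
#     g = Gear()
#     g.items[Slot.HEAD] = 'helmet'   # Stored internally under 'h'.

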
class CompoundListField(BaseField, Generic[TC]):
    """A field consisting of repeated instances of a compound-value.

    Element access returns the sub-field, allowing nested field access;
    i.e.: mylist[10].fieldattr = 'foo'
    """

    def __init__(self,
                 d_key: str,
                 valuetype: TC,
                 store_default: bool = True) -> None:
        super().__init__(d_key)
        self.d_value = valuetype

        # This doesn't actually exist for us, but we want the type-checker
        # to think it does (see TYPE_CHECKING note below).
        self.d_data: Any
        self._store_default = store_default

    def filter_input(self, data: Any, error: bool) -> list:

        if not isinstance(data, list):
            if error:
                raise TypeError('list value expected')
            logging.error('Ignoring non-list data for %s: %s', self, data)
            data = []
        assert isinstance(data, list)

        # Ok, we've got a list; now run everything in it through validation.
        for i, subdata in enumerate(data):
            data[i] = self.d_value.filter_input(subdata, error=error)
        return data

    def get_default_data(self) -> list:
        return []

    def prune_data(self, data: Any) -> bool:
        # Run pruning on all individual entries' data through our child
        # field. However we don't *completely* prune values from the list
        # since that would change it.
        for subdata in data:
            self.d_value.prune_fields_data(subdata)

        # We can also optionally prune the whole list if empty and allowed.
        return not data and not self._store_default

    if TYPE_CHECKING:

        @overload
        def __get__(self, obj: None, cls: Any = None) -> CompoundListField[TC]:
            ...

        @overload
        def __get__(self,
                    obj: Any,
                    cls: Any = None) -> BoundCompoundListField[TC]:
            ...

        def __get__(self, obj: Any, cls: Any = None) -> Any:
            ...

        # Note:
        # When setting the list, we tell the type-checker that we also
        # accept a raw list of CompoundValue objects, but at runtime we
        # actually always deal with BoundCompoundValue objects (see note
        # in BoundCompoundListField for why we accept CompoundValue objs).
        @overload
        def __set__(self, obj: Any, value: List[TC]) -> None:
            ...

        @overload
        def __set__(self, obj: Any, value: BoundCompoundListField[TC]) -> None:
            ...

        def __set__(self, obj: Any, value: Any) -> None:
            ...

    def get_with_data(self, data: Any) -> Any:
        assert self.d_key in data
        return BoundCompoundListField(self, data[self.d_key])

    def set_with_data(self, data: Any, value: Any, error: bool) -> Any:

        # If we were passed a BoundCompoundListField,
        # simply convert it to a flat list of BoundCompoundValue objects,
        # which is what we work with natively here.
        if isinstance(value, BoundCompoundListField):
            value = list(value)

        if not isinstance(value, list):
            raise TypeError(f'CompoundListField expected list value on set;'
                            f' got {type(value)}.')

        # Allow assigning only from a sequence of our existing children.
        # (Could look into expanding this to other children if we can
        # be sure the underlying data will line up; for example two
        # CompoundListFields with different child_field values should not
        # be inter-assignable.)
        if not all(isinstance(i, BoundCompoundValue) for i in value):
            raise ValueError('CompoundListField assignment must be a '
                             'list containing only BoundCompoundValue objs.')

        # Make sure the data all has the same CompoundValue type and
        # compare that type against ours once to make sure its fields
        # match. (This will not allow passing CompoundValues from multiple
        # sources, but I don't know if that would ever come up.)
        for i, val in enumerate(value):
            if i == 0:
                # Do the full field comparison on the first value only.
                if not have_matching_fields(val.d_value, self.d_value):
                    raise ValueError(
                        'CompoundListField assignment must be a '
                        'list containing matching CompoundValues.')
            else:
                # For all remaining values, just ensure they match the
                # first.
                if val.d_value is not value[0].d_value:
                    raise ValueError(
                        'CompoundListField assignment cannot contain '
                        'multiple CompoundValue types as sources.')

        data[self.d_key] = self.filter_input([i.d_data for i in value],
                                             error=error)


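# Usage sketch (hypothetical types; shows nested access through the
# bound field returned on instances):
#
#     class Shot(CompoundValue):
#         power = Field('p', FloatValue(1.0))
#
#     class Volley(Entity):
#         shots = CompoundListField('s', Shot())
#
#     v = Volley()
#     shot = v.shots.append()   # Append a default entry and return it.
#     shot.power = 2.5          # Writes through to v.d_data.

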
class CompoundDictField(BaseField, Generic[TK, TC]):
    """A field consisting of key-indexed instances of a compound-value.

    Element access returns the sub-field, allowing nested field access;
    i.e.: mydict[key].fieldattr = 'foo'
    """

    def __init__(self,
                 d_key: str,
                 keytype: Type[TK],
                 valuetype: TC,
                 store_default: bool = True) -> None:
        super().__init__(d_key)
        self.d_value = valuetype

        # This doesn't actually exist for us, but we want the type-checker
        # to think it does (see TYPE_CHECKING note below).
        self.d_data: Any

        self.d_keytype = keytype
        self._store_default = store_default

    def filter_input(self, data: Any, error: bool) -> dict:
        if not isinstance(data, dict):
            if error:
                raise TypeError('dict value expected')
            logging.error('Ignoring non-dict data for %s: %s', self, data)
            data = {}
        data_out = {}
        for key, val in data.items():

            # For enum keys, make sure it's a valid enum.
            if issubclass(self.d_keytype, Enum):
                # Our input data can either be an enum or the underlying
                # type.
                if isinstance(key, self.d_keytype):
                    key = dict_key_to_raw(key, self.d_keytype)
                else:
                    try:
                        _enumval = dict_key_from_raw(key, self.d_keytype)
                    except Exception as exc:
                        if error:
                            raise ValueError(
                                f'No enum of type {self.d_keytype}'
                                f' exists with value {key}') from exc
                        logging.error('Ignoring invalid key type for %s: %s',
                                      self, data)
                        continue

            # For all other keys we can check for exact types.
            elif not isinstance(key, self.d_keytype):
                if error:
                    raise TypeError(
                        f'Invalid key type; expected {self.d_keytype},'
                        f' got {type(key)}.')
                logging.error('Ignoring invalid key type for %s: %s', self,
                              data)
                continue

            data_out[key] = self.d_value.filter_input(val, error=error)
        return data_out

    def get_default_data(self) -> dict:
        return {}

    def prune_data(self, data: Any) -> bool:
        # Run pruning on all individual entries' data through our child
        # field. However we don't *completely* prune values from the dict
        # since that would change it.
        for subdata in data.values():
            self.d_value.prune_fields_data(subdata)

        # We can also optionally prune the whole dict if empty and allowed.
        return not data and not self._store_default

    # ONLY overriding these in type-checker land to clarify types
    # (see note in BaseField).
    if TYPE_CHECKING:

        @overload
        def __get__(self,
                    obj: None,
                    cls: Any = None) -> CompoundDictField[TK, TC]:
            ...

        @overload
        def __get__(self,
                    obj: Any,
                    cls: Any = None) -> BoundCompoundDictField[TK, TC]:
            ...

        def __get__(self, obj: Any, cls: Any = None) -> Any:
            ...

        # Note:
        # When setting the dict, we tell the type-checker that we also
        # accept a raw dict of CompoundValue objects, but at runtime we
        # actually always deal with BoundCompoundValue objects (see note
        # in BoundCompoundDictField for why we accept CompoundValue objs).
        @overload
        def __set__(self, obj: Any, value: Dict[TK, TC]) -> None:
            ...

        @overload
        def __set__(self, obj: Any,
                    value: BoundCompoundDictField[TK, TC]) -> None:
            ...

        def __set__(self, obj: Any, value: Any) -> None:
            ...

    def get_with_data(self, data: Any) -> Any:
        assert self.d_key in data
        return BoundCompoundDictField(self, data[self.d_key])

    def set_with_data(self, data: Any, value: Any, error: bool) -> Any:

        # If we were passed a BoundCompoundDictField,
        # simply convert it to a flat dict of BoundCompoundValue objects,
        # which is what we work with natively here.
        if isinstance(value, BoundCompoundDictField):
            value = dict(value.items())

        if not isinstance(value, dict):
            raise TypeError('CompoundDictField expected dict value on set.')

        # Allow assigning only from a sequence of our existing children.
        # (Could look into expanding this to other children if we can
        # be sure the underlying data will line up; for example two
        # CompoundDictFields with different child_field values should not
        # be inter-assignable.)
        if (not all(isinstance(i, BoundCompoundValue)
                    for i in value.values())):
            raise ValueError('CompoundDictField assignment must be a '
                             'dict containing only BoundCompoundValues.')

        # Make sure the data all has the same CompoundValue type and
        # compare that type against ours once to make sure its fields
        # match. (This will not allow passing CompoundValues from multiple
        # sources, but I don't know if that would ever come up.)
        first_value: Any = None
        for i, val in enumerate(value.values()):
            if i == 0:
                first_value = val.d_value
                # Do the full field comparison on the first value only.
                if not have_matching_fields(val.d_value, self.d_value):
                    raise ValueError(
                        'CompoundDictField assignment must be a '
                        'dict containing matching CompoundValues.')
            else:
                # For all remaining values, just ensure they match the
                # first.
                if val.d_value is not first_value:
                    raise ValueError(
                        'CompoundDictField assignment cannot contain '
                        'multiple CompoundValue types as sources.')

        data[self.d_key] = self.filter_input(
            {key: val.d_data for key, val in value.items()}, error=error)
468
dist/ba_data/python/efro/entity/_support.py
vendored
Normal file
@@ -0,0 +1,468 @@
# Released under the MIT License. See LICENSE for details.
#
"""Various support classes for accessing data and info on fields and values."""

from __future__ import annotations

from typing import TYPE_CHECKING, TypeVar, Generic, overload

from efro.entity._base import (BaseField, dict_key_to_raw, dict_key_from_raw)

if TYPE_CHECKING:
    from typing import (Optional, Tuple, Type, Any, Dict, List, Union)
    from efro.entity._value import CompoundValue
    from efro.entity._field import (ListField, DictField, CompoundListField,
                                    CompoundDictField)

T = TypeVar('T')
TKey = TypeVar('TKey')
TCompound = TypeVar('TCompound', bound='CompoundValue')
TBoundList = TypeVar('TBoundList', bound='BoundCompoundListField')


class BoundCompoundValue:
    """Wraps a CompoundValue object and its entity data.

    Allows access to its values through our own equivalent attributes.
    """

    def __init__(self, value: CompoundValue, d_data: Union[List[Any],
                                                           Dict[str, Any]]):
        self.d_value: CompoundValue
        self.d_data: Union[List[Any], Dict[str, Any]]

        # Need to use the base setters to avoid triggering our own
        # overrides.
        object.__setattr__(self, 'd_value', value)
        object.__setattr__(self, 'd_data', d_data)

    def __eq__(self, other: Any) -> Any:
        # Allow comparing to compound and bound-compound objects.
        from efro.entity.util import compound_eq
        return compound_eq(self, other)

    def __getattr__(self, name: str, default: Any = None) -> Any:
        # If this attribute corresponds to a field on our compound value's
        # unbound type, ask it to give us a value using our data.
        d_value = type(object.__getattribute__(self, 'd_value'))
        field = getattr(d_value, name, None)
        if isinstance(field, BaseField):
            return field.get_with_data(self.d_data)
        raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        # Same deal as __getattr__ basically.
        field = getattr(type(object.__getattribute__(self, 'd_value')), name,
                        None)
        if isinstance(field, BaseField):
            field.set_with_data(self.d_data, value, error=True)
            return
        super().__setattr__(name, value)

    def reset(self) -> None:
        """Reset this field's data to defaults."""
        value = object.__getattribute__(self, 'd_value')
        data = object.__getattribute__(self, 'd_data')
        assert isinstance(data, dict)

        # Need to clear our dict in-place since we have no
        # access to our parent, which we'd need in order to assign an
        # empty one.
        data.clear()

        # Now fill in default data.
        value.apply_fields_to_data(data, error=True)

    def __repr__(self) -> str:
        fstrs: List[str] = []
        for field in self.d_value.get_fields():
            try:
                fstrs.append(str(field) + '=' + repr(getattr(self, field)))
            except Exception:
                fstrs.append('FAIL' + str(field) + ' ' + str(type(self)))
        return type(self.d_value).__name__ + '(' + ', '.join(fstrs) + ')'


class FieldInspector:
    """Used for inspecting fields."""

    def __init__(self, root: Any, obj: Any, path: List[str],
                 dbpath: List[str]) -> None:
        self._root = root
        self._obj = obj
        self._path = path
        self._dbpath = dbpath

    def __repr__(self) -> str:
        path = '.'.join(self._path)
        typename = type(self._root).__name__
        if path == '':
            return f'<FieldInspector: {typename}>'
        return f'<FieldInspector: {typename}: {path}>'

    def __getattr__(self, name: str, default: Any = None) -> Any:
        # pylint: disable=cyclic-import
        from efro.entity._field import CompoundField

        # If this attribute corresponds to a field on our obj's
        # unbound type, return a new inspector for it.
        if isinstance(self._obj, CompoundField):
            target = self._obj.d_value
        else:
            target = self._obj
        field = getattr(type(target), name, None)
        if isinstance(field, BaseField):
            newpath = list(self._path)
            newpath.append(name)
            newdbpath = list(self._dbpath)
            assert field.d_key is not None
            newdbpath.append(field.d_key)
            return FieldInspector(self._root, field, newpath, newdbpath)
        raise AttributeError(name)

    def get_root(self) -> Any:
        """Return the root object this inspector is targeting."""
        return self._root

    def get_path(self) -> List[str]:
        """Return the python path components of this inspector."""
        return self._path

    def get_db_path(self) -> List[str]:
        """Return the database path components of this inspector."""
        return self._dbpath


class BoundListField(Generic[T]):
|
||||
"""ListField bound to data; used for accessing field values."""
|
||||
|
||||
def __init__(self, field: ListField[T], d_data: List[Any]):
|
||||
self.d_field = field
|
||||
assert isinstance(d_data, list)
|
||||
self.d_data = d_data
|
||||
self._i = 0
|
||||
|
||||
def __eq__(self, other: Any) -> Any:
|
||||
# Just convert us into a regular list and run a compare with that.
|
||||
flattened = [
|
||||
self.d_field.d_value.filter_output(value) for value in self.d_data
|
||||
]
|
||||
return flattened == other
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return '[' + ', '.join(
|
||||
repr(self.d_field.d_value.filter_output(i))
|
||||
for i in self.d_data) + ']'
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.d_data)
|
||||
|
||||
def __iter__(self) -> Any:
|
||||
self._i = 0
|
||||
return self
|
||||
|
||||
def append(self, val: T) -> None:
|
||||
"""Append the provided value to the list."""
|
||||
self.d_data.append(self.d_field.d_value.filter_input(val, error=True))
|
||||
|
||||
def __next__(self) -> T:
|
||||
if self._i < len(self.d_data):
|
||||
self._i += 1
|
||||
val: T = self.d_field.d_value.filter_output(self.d_data[self._i -
|
||||
1])
|
||||
return val
|
||||
raise StopIteration
|
||||
|
||||
@overload
|
||||
def __getitem__(self, key: int) -> T:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, key: slice) -> List[T]:
|
||||
...
|
||||
|
||||
def __getitem__(self, key: Any) -> Any:
|
||||
if isinstance(key, slice):
|
||||
dofilter = self.d_field.d_value.filter_output
|
||||
return [
|
||||
dofilter(self.d_data[i])
|
||||
for i in range(*key.indices(len(self)))
|
||||
]
|
||||
assert isinstance(key, int)
|
||||
return self.d_field.d_value.filter_output(self.d_data[key])
|
||||
|
||||
def __setitem__(self, key: int, value: T) -> None:
|
||||
if not isinstance(key, int):
|
||||
raise TypeError('Expected int index.')
|
||||
self.d_data[key] = self.d_field.d_value.filter_input(value, error=True)
|
||||
|
||||
|
||||
class BoundDictField(Generic[TKey, T]):
|
||||
"""DictField bound to its data; used for accessing its values."""
|
||||
|
||||
def __init__(self, keytype: Type[TKey], field: DictField[TKey, T],
|
||||
d_data: Dict[TKey, T]):
|
||||
self._keytype = keytype
|
||||
self.d_field = field
|
||||
assert isinstance(d_data, dict)
|
||||
self.d_data = d_data
|
||||
|
||||
def __eq__(self, other: Any) -> Any:
|
||||
# Just convert us into a regular dict and run a compare with that.
|
||||
flattened = {
|
||||
key: self.d_field.d_value.filter_output(value)
|
||||
for key, value in self.d_data.items()
|
||||
}
|
||||
return flattened == other
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return '{' + ', '.join(
|
||||
repr(dict_key_from_raw(key, self._keytype)) + ': ' +
|
||||
repr(self.d_field.d_value.filter_output(val))
|
||||
for key, val in self.d_data.items()) + '}'
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.d_data)
|
||||
|
||||
def __getitem__(self, key: TKey) -> T:
|
||||
keyfilt = dict_key_to_raw(key, self._keytype)
|
||||
typedval: T = self.d_field.d_value.filter_output(self.d_data[keyfilt])
|
||||
return typedval
|
||||
|
||||
def get(self, key: TKey, default: Optional[T] = None) -> Optional[T]:
|
||||
"""Get a value if present, or a default otherwise."""
|
||||
keyfilt = dict_key_to_raw(key, self._keytype)
|
||||
if keyfilt not in self.d_data:
|
||||
return default
|
||||
typedval: T = self.d_field.d_value.filter_output(self.d_data[keyfilt])
|
||||
return typedval
|
||||
|
||||
def __setitem__(self, key: TKey, value: T) -> None:
|
||||
keyfilt = dict_key_to_raw(key, self._keytype)
|
||||
self.d_data[keyfilt] = self.d_field.d_value.filter_input(value,
|
||||
error=True)
|
||||
|
||||
def __contains__(self, key: TKey) -> bool:
|
||||
keyfilt = dict_key_to_raw(key, self._keytype)
|
||||
return keyfilt in self.d_data
|
||||
|
||||
def __delitem__(self, key: TKey) -> None:
|
||||
keyfilt = dict_key_to_raw(key, self._keytype)
|
||||
del self.d_data[keyfilt]
|
||||
|
||||
def keys(self) -> List[TKey]:
|
||||
"""Return a list of our keys."""
|
||||
return [
|
||||
dict_key_from_raw(k, self._keytype) for k in self.d_data.keys()
|
||||
]
|
||||
|
||||
def values(self) -> List[T]:
|
||||
"""Return a list of our values."""
|
||||
return [
|
||||
self.d_field.d_value.filter_output(value)
|
||||
for value in self.d_data.values()
|
||||
]
|
||||
|
||||
def items(self) -> List[Tuple[TKey, T]]:
|
||||
"""Return a list of item/value pairs."""
|
||||
return [(dict_key_from_raw(key, self._keytype),
|
||||
self.d_field.d_value.filter_output(value))
|
||||
for key, value in self.d_data.items()]
|
||||
|
||||
|
||||
class BoundCompoundListField(Generic[TCompound]):
|
||||
"""A CompoundListField bound to its entity sub-data."""
|
||||
|
||||
def __init__(self, field: CompoundListField[TCompound], d_data: List[Any]):
|
||||
self.d_field = field
|
||||
self.d_data = d_data
|
||||
self._i = 0
|
||||
|
||||
def __eq__(self, other: Any) -> Any:
|
||||
from efro.entity.util import have_matching_fields
|
||||
|
||||
# We can only be compared to other bound-compound-fields
|
||||
if not isinstance(other, BoundCompoundListField):
|
||||
return NotImplemented
|
||||
|
||||
# If our compound values have differing fields, we're unequal.
|
||||
if not have_matching_fields(self.d_field.d_value,
|
||||
other.d_field.d_value):
|
||||
return False
|
||||
|
||||
# Ok our data schemas match; now just compare our data..
|
||||
return self.d_data == other.d_data
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.d_data)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return '[' + ', '.join(
|
||||
repr(BoundCompoundValue(self.d_field.d_value, i))
|
||||
for i in self.d_data) + ']'
|
||||
|
||||
# Note: to the type checker our gets/sets simply deal with CompoundValue
|
||||
# objects so the type-checker can cleanly handle their sub-fields.
|
||||
# However at runtime we deal in BoundCompoundValue objects which use magic
|
||||
# to tie the CompoundValue object to its data but which the type checker
|
||||
# can't understand.
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@overload
|
||||
def __getitem__(self, key: int) -> TCompound:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, key: slice) -> List[TCompound]:
|
||||
...
|
||||
|
||||
def __getitem__(self, key: Any) -> Any:
|
||||
...
|
||||
|
||||
def __next__(self) -> TCompound:
|
||||
...
|
||||
|
||||
def append(self) -> TCompound:
|
||||
"""Append and return a new field entry to the array."""
|
||||
...
|
||||
else:
|
||||
|
||||
def __getitem__(self, key: Any) -> Any:
|
||||
if isinstance(key, slice):
|
||||
return [
|
||||
BoundCompoundValue(self.d_field.d_value, self.d_data[i])
|
||||
for i in range(*key.indices(len(self)))
|
||||
]
|
||||
assert isinstance(key, int)
|
||||
return BoundCompoundValue(self.d_field.d_value, self.d_data[key])
|
||||
|
||||
def __next__(self):
|
||||
if self._i < len(self.d_data):
|
||||
self._i += 1
|
||||
return BoundCompoundValue(self.d_field.d_value,
|
||||
self.d_data[self._i - 1])
|
||||
raise StopIteration
|
||||
|
||||
def append(self) -> Any:
|
||||
"""Append and return a new field entry to the array."""
|
||||
# push the entity default into data and then let it fill in
|
||||
# any children/etc.
|
||||
self.d_data.append(
|
||||
self.d_field.d_value.filter_input(
|
||||
self.d_field.d_value.get_default_data(), error=True))
|
||||
return BoundCompoundValue(self.d_field.d_value, self.d_data[-1])
|
||||
|
||||
def __iter__(self: TBoundList) -> TBoundList:
|
||||
self._i = 0
|
||||
return self
|
||||
|
||||
|
||||
class BoundCompoundDictField(Generic[TKey, TCompound]):
    """A CompoundDictField bound to its entity sub-data."""

    def __init__(self, field: CompoundDictField[TKey, TCompound],
                 d_data: Dict[Any, Any]):
        self.d_field = field
        self.d_data = d_data

    def __eq__(self, other: Any) -> Any:
        from efro.entity.util import have_matching_fields

        # We can only be compared to other bound-compound-dict-fields.
        if not isinstance(other, BoundCompoundDictField):
            return NotImplemented

        # If our compound values have differing fields, we're unequal.
        if not have_matching_fields(self.d_field.d_value,
                                    other.d_field.d_value):
            return False

        # Ok; our data schemas match. Now just compare our data.
        return self.d_data == other.d_data

    def __repr__(self) -> str:
        return '{' + ', '.join(
            repr(key) + ': ' +
            repr(BoundCompoundValue(self.d_field.d_value, value))
            for key, value in self.d_data.items()) + '}'

    # In the typechecker's eyes, gets/sets on us simply deal in
    # CompoundValue objects. This allows type-checking to work nicely
    # for their sub-fields.
    # However, in real life we return BoundCompoundValues which use
    # magic to tie the CompoundValue to its data (but which the
    # typechecker would not be able to make sense of).
    if TYPE_CHECKING:

        def get(self, key: TKey) -> Optional[TCompound]:
            """Return a value if present; otherwise None."""

        def __getitem__(self, key: TKey) -> TCompound:
            ...

        def values(self) -> List[TCompound]:
            """Return a list of our values."""

        def items(self) -> List[Tuple[TKey, TCompound]]:
            """Return key/value pairs for all dict entries."""

        def add(self, key: TKey) -> TCompound:
            """Add an entry into the dict, returning it.

            Any existing value is replaced."""

    else:

        def get(self, key):
            """Return a value if present; otherwise None."""
            keyfilt = dict_key_to_raw(key, self.d_field.d_keytype)
            data = self.d_data.get(keyfilt)
            if data is not None:
                return BoundCompoundValue(self.d_field.d_value, data)
            return None

        def __getitem__(self, key):
            keyfilt = dict_key_to_raw(key, self.d_field.d_keytype)
            return BoundCompoundValue(self.d_field.d_value,
                                      self.d_data[keyfilt])

        def values(self):
            """Return a list of our values."""
            return list(
                BoundCompoundValue(self.d_field.d_value, i)
                for i in self.d_data.values())

        def items(self):
            """Return key/value pairs for all dict entries."""
            return [(dict_key_from_raw(key, self.d_field.d_keytype),
                     BoundCompoundValue(self.d_field.d_value, value))
                    for key, value in self.d_data.items()]

        def add(self, key: TKey) -> TCompound:
            """Add an entry into the dict, returning it.

            Any existing value is replaced."""
            keyfilt = dict_key_to_raw(key, self.d_field.d_keytype)

            # Push the entity default into data and then let it fill in
            # any children/etc.
            self.d_data[keyfilt] = (self.d_field.d_value.filter_input(
                self.d_field.d_value.get_default_data(), error=True))
            return BoundCompoundValue(self.d_field.d_value,
                                      self.d_data[keyfilt])

    def __len__(self) -> int:
        return len(self.d_data)

    def __contains__(self, key: TKey) -> bool:
        keyfilt = dict_key_to_raw(key, self.d_field.d_keytype)
        return keyfilt in self.d_data

    def __delitem__(self, key: TKey) -> None:
        keyfilt = dict_key_to_raw(key, self.d_field.d_keytype)
        del self.d_data[keyfilt]

    def keys(self) -> List[TKey]:
        """Return a list of our keys."""
        return [
            dict_key_from_raw(k, self.d_field.d_keytype)
            for k in self.d_data.keys()
        ]

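# Example (hedged sketch): what this bound dict API looks like from an
# entity. 'Registry' and its 'players' CompoundDictField are hypothetical
# names invented for illustration; they are not defined in this package.
#
#   registry = Registry()
#   player = registry.players.add('p1')    # Insert (or replace) an entry.
#   assert 'p1' in registry.players
#   assert registry.players.get('missing') is None  # No KeyError.
#   for key, value in registry.players.items():
#       print(key, value)
#   del registry.players['p1']
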
537 dist/ba_data/python/efro/entity/_value.py vendored Normal file
@@ -0,0 +1,537 @@
# Released under the MIT License. See LICENSE for details.
#
"""Value types for the entity system."""

from __future__ import annotations

import datetime
import inspect
import logging
from collections import abc
from enum import Enum
from typing import TYPE_CHECKING, TypeVar, Generic
# Our Pylint class_generics_filter gives us a false-positive unused-import.
from typing import Tuple, Optional  # pylint: disable=W0611

from efro.entity._base import DataHandler, BaseField
from efro.entity.util import compound_eq

if TYPE_CHECKING:
    from typing import Optional, Set, List, Dict, Any, Type

T = TypeVar('T')
TE = TypeVar('TE', bound=Enum)

_sanity_tested_types: Set[Type] = set()
_type_field_cache: Dict[Type, Dict[str, BaseField]] = {}


class TypedValue(DataHandler, Generic[T]):
    """Base class for all value types dealing with a single data type."""


class SimpleValue(TypedValue[T]):
    """Standard base class for simple single-value types.

    This class provides enough functionality to handle most simple
    types such as int/float/etc without too many subclass overrides.
    """

    def __init__(self,
                 default: T,
                 store_default: bool,
                 target_type: Optional[Type] = None,
                 convert_source_types: Tuple[Type, ...] = (),
                 allow_none: bool = False) -> None:
        """Init the value field.

        If store_default is False, the field value will not be included
        in final entity data if it is a default value. Be sure to set
        this to True for any fields that will be used for server-side
        queries so they are included in indexing.

        target_type and convert_source_types are used in the default
        filter_input implementation; if passed-in data's type is present
        in convert_source_types, a target_type will be instantiated
        using it (this allows simple conversions to bool, int, etc).
        Data will also be allowed through untouched if it matches
        target_type (types needing further introspection should override
        filter_input). Lastly, the value of allow_none is also used in
        filter_input to determine whether values of None should be
        allowed.
        """
        super().__init__()

        self._store_default = store_default
        self._target_type = target_type
        self._convert_source_types = convert_source_types
        self._allow_none = allow_none

        # We store _default_data in our internal data format, so we need
        # to run user-facing values through our input filter.
        # Make sure we do this last since filter_input depends on the
        # values above.
        self._default_data: T = self.filter_input(default, error=True)

    def __repr__(self) -> str:
        if self._target_type is not None:
            return f'<Value of type {self._target_type.__name__}>'
        return '<Value of unknown type>'

    def get_default_data(self) -> Any:
        return self._default_data

    def prune_data(self, data: Any) -> bool:
        return not self._store_default and data == self._default_data

    def filter_input(self, data: Any, error: bool) -> Any:

        # Let data pass through untouched if it's already our target type.
        if self._target_type is not None:
            if isinstance(data, self._target_type):
                return data

        # ...and also if it's None and we're into that sort of thing.
        if self._allow_none and data is None:
            return data

        # If it's one of our convertible types, convert.
        if (self._convert_source_types
                and isinstance(data, self._convert_source_types)):
            assert self._target_type is not None
            return self._target_type(data)
        if error:
            errmsg = (f'value of type {self._target_type} or None expected'
                      if self._allow_none else
                      f'value of type {self._target_type} expected')
            errmsg += f'; got {type(data)}'
            raise TypeError(errmsg)
        errmsg = f'Ignoring incompatible data for {self};'
        errmsg += (f' expected {self._target_type} or None;'
                   if self._allow_none else f' expected {self._target_type};')
        errmsg += f' got {type(data)}'
        logging.error(errmsg)
        return self.get_default_data()

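# Example (hedged sketch): the conversion machinery above is easiest to
# see in a tiny subclass. 'CelsiusValue' is a made-up name; only the
# constructor arguments documented above are assumed.
class CelsiusValue(SimpleValue[float]):
    """Hypothetical float value that also accepts plain ints."""

    def __init__(self, default: float = 0.0,
                 store_default: bool = True) -> None:
        # target_type=float with convert_source_types=(int,) means int
        # input is converted via float(data); anything else either raises
        # (error=True) or is replaced by the default (error=False).
        super().__init__(default, store_default, float, (int,))
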
class StringValue(SimpleValue[str]):
    """Value consisting of a single string."""

    def __init__(self, default: str = '', store_default: bool = True) -> None:
        super().__init__(default, store_default, str)


class OptionalStringValue(SimpleValue[Optional[str]]):
    """Value consisting of a single string or None."""

    def __init__(self,
                 default: Optional[str] = None,
                 store_default: bool = True) -> None:
        super().__init__(default, store_default, str, allow_none=True)


class BoolValue(SimpleValue[bool]):
    """Value consisting of a single bool."""

    def __init__(self,
                 default: bool = False,
                 store_default: bool = True) -> None:
        super().__init__(default, store_default, bool, (int, float))


class OptionalBoolValue(SimpleValue[Optional[bool]]):
    """Value consisting of a single bool or None."""

    def __init__(self,
                 default: Optional[bool] = None,
                 store_default: bool = True) -> None:
        super().__init__(default,
                         store_default,
                         bool, (int, float),
                         allow_none=True)

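# Example (hedged sketch): how the conversion rules play out for
# BoolValue; int/float input converts, other types fall back or raise.
_bval = BoolValue()
assert _bval.filter_input(1, error=True) is True        # int -> bool
assert _bval.filter_input(True, error=True) is True     # passes untouched
assert _bval.filter_input('yes', error=False) is False  # logs; uses default
# _bval.filter_input('yes', error=True)  # would raise TypeError
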
def verify_time_input(data: Any, error: bool, allow_none: bool) -> Any:
    """Check input data for time values."""
    pytz_utc: Any

    # We don't *require* pytz since it must be installed through pip,
    # but it is used by the firestore client for its date values
    # (in which case it should be installed as a dependency anyway).
    try:
        import pytz
        pytz_utc = pytz.utc
    except ModuleNotFoundError:
        pytz_utc = None

    # Filter disallowed None values.
    if not allow_none and data is None:
        if error:
            raise ValueError('datetime value cannot be None')
        logging.error('ignoring datetime value of None')
        data = (None if allow_none else datetime.datetime.now(
            datetime.timezone.utc))

    # Parent filter_input does what we need, but let's just make
    # sure we *only* accept datetime values that know they're UTC.
    elif (isinstance(data, datetime.datetime)
          and data.tzinfo is not datetime.timezone.utc
          and (pytz_utc is None or data.tzinfo is not pytz_utc)):
        if error:
            raise ValueError(
                'datetime values must have timezone set as timezone.utc')
        logging.error(
            'ignoring datetime value without timezone.utc set: %s %s',
            type(datetime.timezone.utc), type(data.tzinfo))
        data = (None if allow_none else datetime.datetime.now(
            datetime.timezone.utc))
    return data


class DateTimeValue(SimpleValue[datetime.datetime]):
    """Value consisting of a datetime.datetime object.

    The default value for this is always the current time in UTC.
    """

    def __init__(self, store_default: bool = True) -> None:
        # Pass a dummy datetime value as default just to satisfy the
        # constructor; we override get_default_data so this never gets
        # used.
        dummy_default = datetime.datetime.now(datetime.timezone.utc)
        super().__init__(dummy_default, store_default, datetime.datetime)

    def get_default_data(self) -> Any:
        # For this class we don't use a static default value;
        # default is always now.
        return datetime.datetime.now(datetime.timezone.utc)

    def filter_input(self, data: Any, error: bool) -> Any:
        data = verify_time_input(data, error, allow_none=False)
        return super().filter_input(data, error)


class OptionalDateTimeValue(SimpleValue[Optional[datetime.datetime]]):
    """Value consisting of a datetime.datetime object or None."""

    def __init__(self, store_default: bool = True) -> None:
        super().__init__(None,
                         store_default,
                         datetime.datetime,
                         allow_none=True)

    def filter_input(self, data: Any, error: bool) -> Any:
        data = verify_time_input(data, error, allow_none=True)
        return super().filter_input(data, error)

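# Example (hedged sketch): DateTimeValue insists on timezone-aware UTC
# datetimes; naive values are rejected (error=True) or replaced.
_now = datetime.datetime.now(datetime.timezone.utc)
_dtval = DateTimeValue()
assert _dtval.filter_input(_now, error=True) is _now
# _dtval.filter_input(datetime.datetime.now(), error=True)  # -> ValueError
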
class IntValue(SimpleValue[int]):
    """Value consisting of a single int."""

    def __init__(self, default: int = 0, store_default: bool = True) -> None:
        super().__init__(default, store_default, int, (bool, float))


class OptionalIntValue(SimpleValue[Optional[int]]):
    """Value consisting of a single int or None."""

    def __init__(self,
                 default: Optional[int] = None,
                 store_default: bool = True) -> None:
        super().__init__(default,
                         store_default,
                         int, (bool, float),
                         allow_none=True)


class FloatValue(SimpleValue[float]):
    """Value consisting of a single float."""

    def __init__(self,
                 default: float = 0.0,
                 store_default: bool = True) -> None:
        super().__init__(default, store_default, float, (bool, int))


class OptionalFloatValue(SimpleValue[Optional[float]]):
    """Value consisting of a single float or None."""

    def __init__(self,
                 default: Optional[float] = None,
                 store_default: bool = True) -> None:
        super().__init__(default,
                         store_default,
                         float, (bool, int),
                         allow_none=True)


class Float3Value(SimpleValue[Tuple[float, float, float]]):
    """Value consisting of 3 floats."""

    def __init__(self,
                 default: Tuple[float, float, float] = (0.0, 0.0, 0.0),
                 store_default: bool = True) -> None:
        super().__init__(default, store_default)

    def __repr__(self) -> str:
        return '<Value of type float3>'

    def filter_input(self, data: Any, error: bool) -> Any:
        if (not isinstance(data, abc.Sequence) or len(data) != 3
                or any(not isinstance(i, (int, float)) for i in data)):
            if error:
                raise TypeError('Sequence of 3 float values expected.')
            logging.error('Ignoring non-3-float-sequence data for %s: %s',
                          self, data)
            data = self.get_default_data()

        # Actually store as a list.
        return [float(data[0]), float(data[1]), float(data[2])]

    def filter_output(self, data: Any) -> Any:
        """Override."""
        assert len(data) == 3
        return tuple(data)

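# Example (hedged sketch): Float3Value stores its data as a plain list
# internally but hands it back out as a tuple.
_f3 = Float3Value()
_stored = _f3.filter_input((1, 2, 3.5), error=True)
assert _stored == [1.0, 2.0, 3.5]
assert _f3.filter_output(_stored) == (1.0, 2.0, 3.5)
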
class BaseEnumValue(TypedValue[T]):
    """Value class for storing Python Enums.

    Internally enums are stored as their corresponding int/str/etc. values.
    """

    def __init__(self,
                 enumtype: Type[T],
                 default: Optional[T] = None,
                 store_default: bool = True,
                 allow_none: bool = False) -> None:
        super().__init__()
        assert issubclass(enumtype, Enum)

        vals: List[T] = list(enumtype)

        # Bit of sanity checking: make sure this enum has at least
        # one value and that its underlying values are all of simple
        # json-friendly types.
        if not vals:
            raise TypeError(f'enum {enumtype} has no values')
        for val in vals:
            assert isinstance(val, Enum)
            if not isinstance(val.value, (int, bool, float, str)):
                raise TypeError(f'enum value {val} has an invalid'
                                f' value type {type(val.value)}')
        self._enumtype: Type[Enum] = enumtype
        self._store_default: bool = store_default
        self._allow_none: bool = allow_none

        # We store default data in internal format, so we need to run
        # the user-provided value through our input filter.
        # Make sure to set this last since it could depend on other
        # stuff we set here.
        if default is None and not self._allow_none:
            # Special case: we allow passing None as default even if
            # we don't support None as a value; in that case we sub
            # in the first enum value.
            default = vals[0]
        self._default_data: Enum = self.filter_input(default, error=True)

    def get_default_data(self) -> Any:
        return self._default_data

    def prune_data(self, data: Any) -> bool:
        return not self._store_default and data == self._default_data

    def filter_input(self, data: Any, error: bool) -> Any:

        # Allow passing in enum objects directly of course.
        if isinstance(data, self._enumtype):
            data = data.value
        elif self._allow_none and data is None:
            pass
        else:
            # At this point we assume it's an enum value.
            try:
                self._enumtype(data)
            except ValueError:
                if error:
                    raise ValueError(
                        f'Invalid value for {self._enumtype}: {data}'
                    ) from None
                logging.error('Ignoring invalid value for %s: %s',
                              self._enumtype, data)
                data = self._default_data
        return data

    def filter_output(self, data: Any) -> Any:
        if self._allow_none and data is None:
            return None
        return self._enumtype(data)


class EnumValue(BaseEnumValue[TE]):
    """Value class for storing Python Enums.

    Internally enums are stored as their corresponding int/str/etc. values.
    """

    def __init__(self,
                 enumtype: Type[TE],
                 default: Optional[TE] = None,
                 store_default: bool = True) -> None:
        super().__init__(enumtype, default, store_default, allow_none=False)


class OptionalEnumValue(BaseEnumValue[Optional[TE]]):
    """Value class for storing Python Enums (or None).

    Internally enums are stored as their corresponding int/str/etc. values.
    """

    def __init__(self,
                 enumtype: Type[TE],
                 default: Optional[TE] = None,
                 store_default: bool = True) -> None:
        super().__init__(enumtype, default, store_default, allow_none=True)

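# Example (hedged sketch): enums go in as members, get stored as their
# raw values, and come back out as members. 'Mode' is invented here
# purely for illustration.
class Mode(Enum):
    EASY = 'easy'
    HARD = 'hard'

_modeval = EnumValue(Mode)
_raw = _modeval.filter_input(Mode.HARD, error=True)
assert _raw == 'hard'
assert _modeval.filter_output(_raw) is Mode.HARD
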
class CompoundValue(DataHandler):
    """A value containing one or more named child fields of its own.

    Custom classes can be defined that inherit from this and include
    any number of Field instances within themselves.
    """

    def __init__(self, store_default: bool = True) -> None:
        super().__init__()
        self._store_default = store_default

        # Run sanity checks on this type if we haven't.
        self.run_type_sanity_checks()

    def __eq__(self, other: Any) -> Any:
        # Allow comparing to compound and bound-compound objects.
        return compound_eq(self, other)

    def get_default_data(self) -> dict:
        return {}

    # NOTE: once we've got bound-compound-fields working in mypy
    # we should get rid of this here.
    # For now it needs to be here though since bound-compound fields
    # come across as these in type-land.
    def reset(self) -> None:
        """Reset data to default."""
        raise ValueError('Unbound CompoundValue cannot be reset.')

    def filter_input(self, data: Any, error: bool) -> dict:
        if not isinstance(data, dict):
            if error:
                raise TypeError('dict value expected')
            logging.error('Ignoring non-dict data for %s: %s', self, data)
            data = {}
        assert isinstance(data, dict)
        self.apply_fields_to_data(data, error=error)
        return data

    def prune_data(self, data: Any) -> bool:
        # Let all of our sub-fields prune themselves...
        self.prune_fields_data(data)

        # Now we can optionally prune ourself completely if there's
        # nothing left in our data dict...
        return not data and not self._store_default

    def prune_fields_data(self, d_data: Dict[str, Any]) -> None:
        """Given a CompoundValue and data, prune any unnecessary data.

        This includes values matching their defaults on fields with
        store_default False.
        """

        # Allow all fields to take a pruning pass.
        assert isinstance(d_data, dict)
        for field in self.get_fields().values():
            assert isinstance(field.d_key, str)

            # This is supposed to be valid data so there should be
            # *something* there for all fields.
            if field.d_key not in d_data:
                raise RuntimeError(f'expected to find {field.d_key} in data'
                                   f' for {self}; got data {d_data}')

            # Now ask the field if this data is necessary. If not, prune it.
            if field.prune_data(d_data[field.d_key]):
                del d_data[field.d_key]

    def apply_fields_to_data(self, d_data: Dict[str, Any],
                             error: bool) -> None:
        """Apply all of our fields to target data.

        If error is True, exceptions will be raised for invalid data;
        otherwise it will be overwritten (with logging notices emitted).
        """
        assert isinstance(d_data, dict)
        for field in self.get_fields().values():
            assert isinstance(field.d_key, str)

            # First off, make sure *something* is there for this field.
            if field.d_key not in d_data:
                d_data[field.d_key] = field.get_default_data()

            # Now let the field tweak the data as needed so it's valid.
            d_data[field.d_key] = field.filter_input(d_data[field.d_key],
                                                     error=error)

    def __repr__(self) -> str:
        if not hasattr(self, 'd_data'):
            return f'<unbound {type(self).__name__} at {hex(id(self))}>'
        fstrs: List[str] = []
        assert isinstance(self, CompoundValue)
        for field in self.get_fields():
            fstrs.append(str(field) + '=' + repr(getattr(self, field)))
        return type(self).__name__ + '(' + ', '.join(fstrs) + ')'

    @classmethod
    def get_fields(cls) -> Dict[str, BaseField]:
        """Return all field instances for this type."""
        assert issubclass(cls, CompoundValue)

        # If we haven't yet, calculate and cache a complete list of fields
        # for this exact type.
        if cls not in _type_field_cache:
            fields: Dict[str, BaseField] = {}
            for icls in inspect.getmro(cls):
                for name, field in icls.__dict__.items():
                    if isinstance(field, BaseField):
                        fields[name] = field
            _type_field_cache[cls] = fields
        retval: Dict[str, BaseField] = _type_field_cache[cls]
        assert isinstance(retval, dict)
        return retval

    @classmethod
    def run_type_sanity_checks(cls) -> None:
        """Given a type, run one-time sanity checks on it.

        These tests ensure child fields are using valid
        non-repeating names/etc.
        """
        if cls not in _sanity_tested_types:
            _sanity_tested_types.add(cls)

            # Make sure all embedded fields have a key set and there are
            # no duplicates.
            field_keys: Set[str] = set()
            for field in cls.get_fields().values():
                assert isinstance(field.d_key, str)
                if field.d_key is None:
                    raise RuntimeError(f'Child field {field} under {cls}'
                                       ' has d_key None')
                if field.d_key == '':
                    raise RuntimeError(f'Child field {field} under {cls}'
                                       ' has empty d_key')

                # Allow alphanumeric and underscore only.
                if not field.d_key.replace('_', '').isalnum():
                    raise RuntimeError(
                        f'Child field "{field.d_key}" under {cls}'
                        f' contains invalid characters; only alphanumeric'
                        f' and underscore allowed.')
                if field.d_key in field_keys:
                    raise RuntimeError('Multiple child fields with key'
                                       f' "{field.d_key}" found in {cls}')
                field_keys.add(field.d_key)

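# Example (hedged sketch): defining a custom compound. The Field wrapper
# referenced below is an assumption about this package's layout (fields
# are declared elsewhere in efro.entity with a storage key plus a value
# type); it is not defined in this module.
#
#   class SpawnPoint(CompoundValue):
#       """Hypothetical compound with two child fields."""
#       name = Field('n', StringValue())
#       pos = Field('p', Float3Value())
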
131 dist/ba_data/python/efro/entity/util.py vendored Normal file
@@ -0,0 +1,131 @@
# Released under the MIT License. See LICENSE for details.
#
"""Misc utility functionality related to the entity system."""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing import Any, Union, Tuple, List
    from efro.entity._value import CompoundValue
    from efro.entity._support import BoundCompoundValue


def diff_compound_values(
        obj1: Union[BoundCompoundValue, CompoundValue],
        obj2: Union[BoundCompoundValue, CompoundValue]) -> str:
    """Generate a string showing differences between two compound values.

    Both must be associated with data and have the same set of fields.
    """

    # Ensure fields match and both are attached to data...
    value1, data1 = get_compound_value_and_data(obj1)
    if data1 is None:
        raise ValueError(f'Invalid unbound compound value: {obj1}')
    value2, data2 = get_compound_value_and_data(obj2)
    if data2 is None:
        raise ValueError(f'Invalid unbound compound value: {obj2}')
    if not have_matching_fields(value1, value2):
        raise ValueError(
            f"Can't diff objs with non-matching fields: {value1} and {value2}")

    # Ok; let 'er rip...
    diff = _diff(obj1, obj2, 2)
    return ' <no differences>' if diff == '' else diff

class CompoundValueDiff:
    """Wraps diff_compound_values() in an object for efficiency.

    It is preferable to pass this to logging calls instead of the
    final diff string since the diff will never be generated if
    the associated logging level is not being emitted.
    """

    def __init__(self, obj1: Union[BoundCompoundValue, CompoundValue],
                 obj2: Union[BoundCompoundValue, CompoundValue]):
        self._obj1 = obj1
        self._obj2 = obj2

    def __repr__(self) -> str:
        return diff_compound_values(self._obj1, self._obj2)

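# Example (hedged sketch): hand the lazy wrapper to logging so the diff
# string is only built when DEBUG output is actually enabled; 'before'
# and 'after' stand in for two bound compound values.
#
#   import logging
#   logging.debug('entity changed:\n%s', CompoundValueDiff(before, after))
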
def _diff(obj1: Union[BoundCompoundValue, CompoundValue],
          obj2: Union[BoundCompoundValue, CompoundValue], indent: int) -> str:
    from efro.entity._support import BoundCompoundValue
    bits: List[str] = []
    indentstr = ' ' * indent
    vobj1, _data1 = get_compound_value_and_data(obj1)
    fields = sorted(vobj1.get_fields().keys())
    for field in fields:
        val1 = getattr(obj1, field)
        val2 = getattr(obj2, field)
        # For nested compounds, dive in and do nice piecewise compares.
        if isinstance(val1, BoundCompoundValue):
            assert isinstance(val2, BoundCompoundValue)
            diff = _diff(val1, val2, indent + 2)
            if diff != '':
                bits.append(f'{indentstr}{field}:')
                bits.append(diff)
        # For all else just do a single line
        # (perhaps we could improve on this for other complex types).
        else:
            if val1 != val2:
                bits.append(f'{indentstr}{field}: {val1} -> {val2}')
    return '\n'.join(bits)


def have_matching_fields(val1: CompoundValue, val2: CompoundValue) -> bool:
    """Return whether two compound-values have matching sets of fields.

    Note this just refers to the field configuration, not data.
    """
    # Quick-out: matching types will always have identical fields.
    if type(val1) is type(val2):
        return True

    # Otherwise do a full comparison.
    return val1.get_fields() == val2.get_fields()


def get_compound_value_and_data(
        obj: Union[BoundCompoundValue,
                   CompoundValue]) -> Tuple[CompoundValue, Any]:
    """Return value and data for bound or unbound compound values."""
    # pylint: disable=cyclic-import
    from efro.entity._support import BoundCompoundValue
    from efro.entity._value import CompoundValue
    if isinstance(obj, BoundCompoundValue):
        value = obj.d_value
        data = obj.d_data
    elif isinstance(obj, CompoundValue):
        value = obj
        data = getattr(obj, 'd_data', None)  # May not exist.
    else:
        raise TypeError(
            f'Expected a BoundCompoundValue or CompoundValue; got {type(obj)}')
    return value, data


def compound_eq(obj1: Union[BoundCompoundValue, CompoundValue],
                obj2: Union[BoundCompoundValue, CompoundValue]) -> Any:
    """Compare two compound value/bound-value objects for equality."""

    # Criteria for comparison: both need to be a compound value
    # and both must have data (which implies they are either an entity
    # or bound to a subfield in an entity).
    value1, data1 = get_compound_value_and_data(obj1)
    if data1 is None:
        return NotImplemented
    value2, data2 = get_compound_value_and_data(obj2)
    if data2 is None:
        return NotImplemented

    # Ok, we can compare them. To consider them equal we look for
    # matching sets of fields and matching data. Note that there
    # could be unbound data causing inequality despite their field
    # values all matching; not sure if that's what we want.
    return have_matching_fields(value1, value2) and data1 == data2

244 dist/ba_data/python/efro/error.py vendored Normal file
@@ -0,0 +1,244 @@
# Released under the MIT License. See LICENSE for details.
#
"""Common errors and related functionality."""
from __future__ import annotations

from typing import TYPE_CHECKING
import errno

if TYPE_CHECKING:
    pass


class CleanError(Exception):
    """An error that can be presented to the user as a simple message.

    These errors should be completely self-explanatory, to the point where
    a traceback or other context would not be useful.

    A CleanError with no message can be used to inform a script to fail
    without printing any message.

    This should generally be limited to errors that will *always* be
    presented to the user (such as those in high level tool code).
    Exceptions that may be caught and handled by other code should use
    more descriptive exception types.
    """

    def pretty_print(self, flush: bool = False) -> None:
        """Print the error to stdout, using red colored output if available.

        If the error has an empty message, prints nothing (not even a
        newline).
        """
        from efro.terminal import Clr

        errstr = str(self)
        if errstr:
            print(f'{Clr.SRED}{errstr}{Clr.RST}', flush=flush)

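# Example (hedged sketch): the intended pattern in high-level tool code.
# 'run_tool' is a stand-in for a real entry point.
import sys


def run_tool() -> None:
    """Stand-in for a real tool entry point."""


try:
    run_tool()
except CleanError as exc:
    exc.pretty_print()
    sys.exit(1)
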
class CommunicationError(Exception):
    """A communication related error has occurred.

    This covers anything network-related going wrong in the sending
    of data or receiving of a response. Basically anything that is out
    of our control should get lumped in here. This error does not imply
    that data was not received on the other end; only that a full
    acknowledgement round trip was not completed.

    These errors should be gracefully handled whenever possible, as
    occasional network issues are unavoidable.
    """


class RemoteError(Exception):
    """An error occurred on the other end of some connection.

    This occurs when communication succeeds but another type of error
    occurs remotely. The error string can consist of a remote stack
    trace or a simple message depending on the context.

    Communication systems should raise more specific error types locally
    when more introspection/control is needed; this is intended somewhat
    as a catch-all.
    """

    def __init__(self, msg: str, peer_desc: str):
        super().__init__(msg)
        self._peer_desc = peer_desc

    def __str__(self) -> str:
        s = ''.join(str(arg) for arg in self.args)
        # Indent so we can more easily tell what is the remote part when
        # this is in the middle of a long exception chain.
        padding = '  '
        s = ''.join(padding + line for line in s.splitlines(keepends=True))
        return f'The following occurred on {self._peer_desc}:\n{s}'

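# Example (hedged sketch): how RemoteError renders; peer_desc is whatever
# description the calling code chooses for the far side.
_err = RemoteError('ValueError: bad input', peer_desc='server')
print(_err)
# The following occurred on server:
#   ValueError: bad input
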
class IntegrityError(ValueError):
    """Data has been tampered with or corrupted in some form."""


class AuthenticationError(Exception):
    """Authentication has failed for some operation.

    This can be raised if server-side verification does not match
    client-supplied credentials, if an invalid password is supplied
    for a sign-in attempt, etc.
    """

def is_urllib_communication_error(exc: BaseException, url: str | None) -> bool:
    """Is the provided exception from urllib a communication-related error?

    The url, if provided, can give extra context for deciding whether to
    treat an error as such.

    This should be passed an exception which resulted from opening or
    reading a urllib Request. It returns True for any errors that could
    conceivably arise due to unavailable/poor network connections,
    firewall/connectivity issues, or other issues out of our control.
    These errors can often be safely ignored or presented to the user
    as general 'network-unavailable' states.
    """
    import urllib.error
    import http.client
    import socket

    if isinstance(
        exc,
        (
            urllib.error.URLError,
            ConnectionError,
            http.client.IncompleteRead,
            http.client.BadStatusLine,
            http.client.RemoteDisconnected,
            socket.timeout,
        ),
    ):

        # Special case: although an HTTPError is a subclass of URLError,
        # we don't consider it a communication error. It generally means
        # we have successfully communicated with the server but what we
        # are asking for is not there/etc.
        if isinstance(exc, urllib.error.HTTPError):

            # Special sub-case: appspot.com hosting seems to give 403
            # errors (forbidden) to some countries; I'm assuming for
            # legal reasons?.. Let's consider that a communication error
            # since it's out of our control, so we don't fill up logs
            # with it.
            if exc.code == 403 and url is not None and '.appspot.com' in url:
                return True

            return False

        return True

    if isinstance(exc, OSError):
        if exc.errno == 10051:  # Windows unreachable network error.
            return True
        if exc.errno in {
            errno.ETIMEDOUT,
            errno.EHOSTUNREACH,
            errno.ENETUNREACH,
        }:
            return True
    return False

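# Example (hedged sketch): typical call-site usage around urllib; the
# function name here is illustrative, not part of this module.
def fetch_or_none(req_url: str) -> bytes | None:
    """Fetch a url, treating benign network errors as an 'offline' state."""
    import urllib.request
    try:
        with urllib.request.urlopen(req_url) as resp:
            return resp.read()
    except Exception as exc:
        if is_urllib_communication_error(exc, url=req_url):
            return None
        raise
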
def is_requests_communication_error(exc: BaseException) -> bool:
    """Is the provided exception a communication-related error from requests?"""
    import requests

    # Looks like this maps pretty well onto requests' ConnectionError.
    return isinstance(exc, requests.ConnectionError)

def is_udp_communication_error(exc: BaseException) -> bool:
    """Should this udp-related exception be considered a communication error?

    This should be passed an exception which resulted from creating and
    using a socket.SOCK_DGRAM type socket. It should return True for any
    errors that could conceivably arise due to unavailable/poor network
    conditions, firewall/connectivity issues, etc. These issues can often
    be safely ignored or presented to the user as general
    'network-unavailable' states.
    """
    if isinstance(exc, ConnectionRefusedError | TimeoutError):
        return True
    if isinstance(exc, OSError):
        if exc.errno == 10051:  # Windows unreachable network error.
            return True
        if exc.errno in {
            errno.EADDRNOTAVAIL,
            errno.ETIMEDOUT,
            errno.EHOSTUNREACH,
            errno.ENETUNREACH,
            errno.EINVAL,
            errno.EPERM,
            errno.EACCES,
            # Windows 'invalid argument' error.
            10022,
            # Windows 'a socket operation was attempted to
            # an unreachable network' error.
            10051,
        }:
            return True
    return False

def is_asyncio_streams_communication_error(exc: BaseException) -> bool:
    """Should this streams error be considered a communication error?

    This should be passed an exception which resulted from creating and
    using asyncio streams. It should return True for any errors that could
    conceivably arise due to unavailable/poor network connections,
    firewall/connectivity issues, etc. These issues can often be safely
    ignored or presented to the user as general 'connection-lost' events.
    """
    # pylint: disable=too-many-return-statements
    import ssl

    if isinstance(
        exc,
        (
            ConnectionError,
            TimeoutError,
            EOFError,
        ),
    ):
        return True

    # Also some specific errno ones.
    if isinstance(exc, OSError):
        if exc.errno == 10051:  # Windows unreachable network error.
            return True
        if exc.errno in {
            errno.ETIMEDOUT,
            errno.EHOSTUNREACH,
            errno.ENETUNREACH,
        }:
            return True

    # I am occasionally getting a specific SSL error on shutdown which I
    # believe is harmless (APPLICATION_DATA_AFTER_CLOSE_NOTIFY).
    # It sounds like it may soon be ignored by Python (as of March 2022).
    # Let's still complain, however, if we get any SSL errors besides
    # this one. https://bugs.python.org/issue39951
    if isinstance(exc, ssl.SSLError):
        excstr = str(exc)
        if 'APPLICATION_DATA_AFTER_CLOSE_NOTIFY' in excstr:
            return True

        # Also occasionally am getting WRONG_VERSION_NUMBER ssl errors;
        # assuming this just means a client is attempting to connect from
        # some outdated browser or whatnot.
        if 'SSL: WRONG_VERSION_NUMBER' in excstr:
            return True

        # And seeing this very rarely; assuming it's just data corruption?
        if 'SSL: DECRYPTION_FAILED_OR_BAD_RECORD_MAC' in excstr:
            return True

    return False

72 dist/ba_data/python/efro/json.py vendored Normal file
@@ -0,0 +1,72 @@
# Released under the MIT License. See LICENSE for details.
#
"""Custom json compressor/decompressor with support for more data types/etc."""

from __future__ import annotations

import datetime
import json
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing import Any

# Special attr we include for our extended type information
# (extended-json-type).
TYPE_TAG = '_xjtp'

_pytz_utc: Any

# We don't *require* pytz since it must be installed through pip,
# but it is used by the firestore client for its utc tzinfos
# (in which case it should be installed as a dependency anyway).
try:
    import pytz
    _pytz_utc = pytz.utc
except ModuleNotFoundError:
    _pytz_utc = None  # pylint: disable=invalid-name

class ExtendedJSONEncoder(json.JSONEncoder):
    """Custom json encoder supporting additional types."""

    def default(self, obj: Any) -> Any:  # pylint: disable=W0221
        if isinstance(obj, datetime.datetime):

            # We only support timezone-aware utc times.
            if (obj.tzinfo is not datetime.timezone.utc
                    and (_pytz_utc is None or obj.tzinfo is not _pytz_utc)):
                raise ValueError(
                    'datetime values must have timezone set as timezone.utc')
            return {
                TYPE_TAG:
                    'dt',
                'v': [
                    obj.year, obj.month, obj.day, obj.hour, obj.minute,
                    obj.second, obj.microsecond
                ],
            }
        return super().default(obj)


class ExtendedJSONDecoder(json.JSONDecoder):
    """Custom json decoder supporting extended types."""

    def __init__(self, *args: Any, **kwargs: Any):
        json.JSONDecoder.__init__(self,
                                  object_hook=self.object_hook,
                                  *args,
                                  **kwargs)

    def object_hook(self, obj: Any) -> Any:  # pylint: disable=E0202
        """Custom hook."""
        if TYPE_TAG not in obj:
            return obj
        objtype = obj[TYPE_TAG]
        if objtype == 'dt':
            vals = obj.get('v', [])
            if len(vals) != 7:
                raise ValueError('malformed datetime value')
            return datetime.datetime(  # type: ignore
                *vals, tzinfo=datetime.timezone.utc)
        return obj

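# Example (hedged sketch): round-tripping a UTC datetime through the
# extended encoder/decoder pair defined above.
_now = datetime.datetime.now(datetime.timezone.utc)
_blob = json.dumps({'when': _now}, cls=ExtendedJSONEncoder)
_back = json.loads(_blob, cls=ExtendedJSONDecoder)
assert _back['when'] == _now
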
626 dist/ba_data/python/efro/log.py vendored Normal file
@@ -0,0 +1,626 @@
# Released under the MIT License. See LICENSE for details.
#
"""Logging functionality."""
from __future__ import annotations

import sys
import time
import asyncio
import logging
import datetime
import itertools
from enum import Enum
from collections import deque
from dataclasses import dataclass
from typing import TYPE_CHECKING, Annotated
from threading import Thread, current_thread, Lock

from efro.util import utc_now
from efro.call import tpartial
from efro.terminal import TerminalColor
from efro.dataclassio import ioprepped, IOAttrs, dataclass_to_json

if TYPE_CHECKING:
    from pathlib import Path
    from typing import Any, Callable, TextIO

class LogLevel(Enum):
    """Severity level for a log entry.

    These enums have numeric values so they can be compared in severity.
    Note that these values are not currently interchangeable with the
    logging.ERROR, logging.DEBUG, etc. values.
    """

    DEBUG = 0
    INFO = 1
    WARNING = 2
    ERROR = 3
    CRITICAL = 4

    @property
    def python_logging_level(self) -> int:
        """Give the corresponding logging level."""
        return LOG_LEVEL_LEVELNOS[self]

    @classmethod
    def from_python_logging_level(cls, levelno: int) -> LogLevel:
        """Given a Python logging level, return a LogLevel."""
        return LEVELNO_LOG_LEVELS[levelno]


# Python logging levels from LogLevels
LOG_LEVEL_LEVELNOS = {
    LogLevel.DEBUG: logging.DEBUG,
    LogLevel.INFO: logging.INFO,
    LogLevel.WARNING: logging.WARNING,
    LogLevel.ERROR: logging.ERROR,
    LogLevel.CRITICAL: logging.CRITICAL,
}

# LogLevels from Python logging levels
LEVELNO_LOG_LEVELS = {
    logging.DEBUG: LogLevel.DEBUG,
    logging.INFO: LogLevel.INFO,
    logging.WARNING: LogLevel.WARNING,
    logging.ERROR: LogLevel.ERROR,
    logging.CRITICAL: LogLevel.CRITICAL,
}

LEVELNO_COLOR_CODES: dict[int, tuple[str, str]] = {
    logging.DEBUG: (TerminalColor.CYAN.value, TerminalColor.RESET.value),
    logging.INFO: ('', ''),
    logging.WARNING: (TerminalColor.YELLOW.value, TerminalColor.RESET.value),
    logging.ERROR: (TerminalColor.RED.value, TerminalColor.RESET.value),
    logging.CRITICAL: (
        TerminalColor.STRONG_MAGENTA.value
        + TerminalColor.BOLD.value
        + TerminalColor.BG_BLACK.value,
        TerminalColor.RESET.value,
    ),
}

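# Example (hedged sketch): converting between LogLevel and the stdlib's
# numeric logging levels via the mappings above.
assert LogLevel.WARNING.python_logging_level == logging.WARNING
assert LogLevel.from_python_logging_level(logging.ERROR) is LogLevel.ERROR
assert LogLevel.CRITICAL.value > LogLevel.DEBUG.value  # severity ordering
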
@ioprepped
@dataclass
class LogEntry:
    """Single logged message."""

    name: Annotated[str, IOAttrs('n', soft_default='root', store_default=False)]
    message: Annotated[str, IOAttrs('m')]
    level: Annotated[LogLevel, IOAttrs('l')]
    time: Annotated[datetime.datetime, IOAttrs('t')]


@ioprepped
@dataclass
class LogArchive:
    """Info and data for a log."""

    # Total number of entries submitted to the log.
    log_size: Annotated[int, IOAttrs('t')]

    # Offset for the entries contained here.
    # (10 means our first entry is the 10th in the log, etc.)
    start_index: Annotated[int, IOAttrs('c')]

    entries: Annotated[list[LogEntry], IOAttrs('e')]

class LogHandler(logging.Handler):
    """Fancy-pants handler for logging output.

    Writes logs to disk in structured json format and echoes them
    to stdout/stderr with pretty colors.
    """

    _event_loop: asyncio.AbstractEventLoop

    # IMPORTANT: Any debug prints we do here should ONLY go to echofile.
    # Otherwise we can get infinite loops as those prints come back to us
    # as new log entries.

    def __init__(
        self,
        path: str | Path | None,
        echofile: TextIO | None,
        suppress_non_root_debug: bool,
        cache_size_limit: int,
        cache_time_limit: datetime.timedelta | None,
    ):
        super().__init__()
        # pylint: disable=consider-using-with
        self._file = None if path is None else open(path, 'w', encoding='utf-8')
        self._echofile = echofile
        self._callbacks_lock = Lock()
        self._callbacks: list[Callable[[LogEntry], None]] = []
        self._suppress_non_root_debug = suppress_non_root_debug
        self._file_chunks: dict[str, list[str]] = {'stdout': [], 'stderr': []}
        self._file_chunk_ship_task: dict[str, asyncio.Task | None] = {
            'stdout': None,
            'stderr': None,
        }
        self._cache_size = 0
        assert cache_size_limit >= 0
        self._cache_size_limit = cache_size_limit
        self._cache_time_limit = cache_time_limit
        self._cache = deque[tuple[int, LogEntry]]()
        self._cache_index_offset = 0
        self._cache_lock = Lock()
        self._printed_callback_error = False
        self._thread_bootstrapped = False
        self._thread = Thread(target=self._log_thread_main, daemon=True)
        if __debug__:
            self._last_slow_emit_warning_time: float | None = None
        self._thread.start()

        # Spin until our thread is up and running; otherwise we could
        # wind up trying to push stuff to our event loop before the
        # loop exists.
        while not self._thread_bootstrapped:
            time.sleep(0.001)

    def add_callback(self, call: Callable[[LogEntry], None]) -> None:
        """Add a callback to be run for each LogEntry.

        Note that this callback will always run in a background thread.
        """
        with self._callbacks_lock:
            self._callbacks.append(call)

    def _log_thread_main(self) -> None:
        self._event_loop = asyncio.new_event_loop()

        # In our background thread event loop we do a fair amount of
        # slow synchronous stuff such as mucking with the log cache.
        # Let's avoid getting tons of warnings about this in debug mode.
        self._event_loop.slow_callback_duration = 2.0  # Default is 0.1

        # NOTE: if we ever use the default threadpool at all we should
        # allow setting it for our loop.
        asyncio.set_event_loop(self._event_loop)
        self._thread_bootstrapped = True
        try:
            if self._cache_time_limit is not None:
                self._event_loop.create_task(self._time_prune_cache())
            self._event_loop.run_forever()
        except BaseException:
            # If this ever goes down we're in trouble.
            # We won't be able to log about it though...
            # Try to make some noise however we can.
            print('LogHandler died!!!', file=sys.stderr)
            import traceback

            traceback.print_exc()
            raise

    async def _time_prune_cache(self) -> None:
        assert self._cache_time_limit is not None
        while True:
            await asyncio.sleep(61.27)
            now = utc_now()
            with self._cache_lock:

                # Prune the oldest entry as long as there is a first one
                # that is too old.
                while (
                    self._cache
                    and (now - self._cache[0][1].time) >= self._cache_time_limit
                ):
                    popped = self._cache.popleft()
                    self._cache_size -= popped[0]
                    self._cache_index_offset += 1

    def get_cached(
        self, start_index: int = 0, max_entries: int | None = None
    ) -> LogArchive:
        """Build and return an archive of cached log entries.

        This will only include entries that have been processed by the
        background thread, so it may not include just-submitted logs or
        entries for partially written stdout/stderr lines.
        Entries from the range [start_index:start_index+max_entries]
        which are still present in the cache will be returned.
        """

        assert start_index >= 0
        if max_entries is not None:
            assert max_entries >= 0
        with self._cache_lock:
            # Transform start_index to our present cache space.
            start_index -= self._cache_index_offset
            # Calc end-index in our present cache space.
            end_index = (
                len(self._cache)
                if max_entries is None
                else start_index + max_entries
            )

            # Clamp both indexes to both ends of our present space.
            start_index = max(0, min(start_index, len(self._cache)))
            end_index = max(0, min(end_index, len(self._cache)))

            return LogArchive(
                log_size=self._cache_index_offset + len(self._cache),
                start_index=start_index + self._cache_index_offset,
                entries=self._cache_slice(start_index, end_index),
            )

    def _cache_slice(
        self, start: int, end: int, step: int = 1
    ) -> list[LogEntry]:
        # Deque doesn't natively support slicing but we can do it manually.
        # It sounds like rotating the deque and pulling from the beginning
        # is the most efficient way to do this. The downside is the deque
        # gets temporarily modified in the process so we need to make sure
        # we're holding the lock.
        assert self._cache_lock.locked()
        cache = self._cache
        cache.rotate(-start)
        slc = [e[1] for e in itertools.islice(cache, 0, end - start, step)]
        cache.rotate(start)
        return slc

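    # Example (hedged sketch): paging through the cache from outside; the
    # 'handler' here would come from setup_logging() with a nonzero
    # cache_size_limit.
    #
    #   archive = handler.get_cached(start_index=0, max_entries=50)
    #   for entry in archive.entries:
    #       print(entry.time, entry.level.name, entry.message)
    #   next_start = archive.start_index + len(archive.entries)
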
    @classmethod
    def _is_immutable_log_data(cls, data: Any) -> bool:
        if isinstance(data, (str, bool, int, float, bytes)):
            return True
        if isinstance(data, tuple):
            return all(cls._is_immutable_log_data(x) for x in data)
        return False

    def emit(self, record: logging.LogRecord) -> None:
        if __debug__:
            starttime = time.monotonic()

        # Called by logging to send us records.

        # Special case: filter out this common extra-chatty category.
        # TODO - perhaps should use a standard logging.Filter for this.
        if (
            self._suppress_non_root_debug
            and record.name != 'root'
            and record.levelname == 'DEBUG'
        ):
            return

        # Optimization: if our log args are all simple immutable values,
        # we can just kick the whole thing over to our background thread to
        # be formatted there at our leisure. If anything is mutable and
        # thus could possibly change between now and then, or if we want
        # to do immediate file echoing, then we need to bite the bullet
        # and do that stuff here at the call site.
        fast_path = self._echofile is None and self._is_immutable_log_data(
            record.args
        )

        if fast_path:
            if __debug__:
                formattime = echotime = time.monotonic()
            self._event_loop.call_soon_threadsafe(
                tpartial(
                    self._emit_in_thread,
                    record.name,
                    record.levelno,
                    record.created,
                    record,
                )
            )
        else:
            # Slow case; do formatting and echoing here at the log call
            # site.
            msg = self.format(record)

            if __debug__:
                formattime = time.monotonic()

            # Also immediately print pretty colored output to our echo file
            # (generally stderr). We do this part here instead of in our bg
            # thread because the delay can throw off command line prompts or
            # make tight debugging harder.
            if self._echofile is not None:
                ends = LEVELNO_COLOR_CODES.get(record.levelno)
                if ends is not None:
                    self._echofile.write(f'{ends[0]}{msg}{ends[1]}\n')
                else:
                    self._echofile.write(f'{msg}\n')

            if __debug__:
                echotime = time.monotonic()

            self._event_loop.call_soon_threadsafe(
                tpartial(
                    self._emit_in_thread,
                    record.name,
                    record.levelno,
                    record.created,
                    msg,
                )
            )

        if __debug__:
            # Make noise if we're taking a significant amount of time here.
            # Limit the noise to once every so often though; otherwise we
            # could get a feedback loop where every log emit results in a
            # warning log which results in another, etc.
            now = time.monotonic()
            # noinspection PyUnboundLocalVariable
            duration = now - starttime
            # noinspection PyUnboundLocalVariable
            format_duration = formattime - starttime
            # noinspection PyUnboundLocalVariable
            echo_duration = echotime - formattime
            if duration > 0.05 and (
                self._last_slow_emit_warning_time is None
                or now > self._last_slow_emit_warning_time + 10.0
            ):
                # Logging calls from *within* a logging handler
                # sound sketchy, so let's just kick this over to
                # the bg event loop thread we've already got.
                self._last_slow_emit_warning_time = now
                self._event_loop.call_soon_threadsafe(
                    tpartial(
                        logging.warning,
                        'efro.log.LogHandler emit took too long'
                        ' (%.2fs total; %.2fs format, %.2fs echo,'
                        ' fast_path=%s).',
                        duration,
                        format_duration,
                        echo_duration,
                        fast_path,
                    )
                )

    def _emit_in_thread(
        self,
        name: str,
        levelno: int,
        created: float,
        message: str | logging.LogRecord,
    ) -> None:
        try:

            # If they passed a raw record here, bake it down to a string.
            if isinstance(message, logging.LogRecord):
                message = self.format(message)

            self._emit_entry(
                LogEntry(
                    name=name,
                    message=message,
                    level=LEVELNO_LOG_LEVELS.get(levelno, LogLevel.INFO),
                    time=datetime.datetime.fromtimestamp(
                        created, datetime.timezone.utc
                    ),
                )
            )
        except Exception:
            import traceback

            traceback.print_exc(file=self._echofile)

    def file_write(self, name: str, output: str) -> None:
        """Send raw stdout/stderr output to the logger to be collated."""

        self._event_loop.call_soon_threadsafe(
            tpartial(self._file_write_in_thread, name, output)
        )

    def _file_write_in_thread(self, name: str, output: str) -> None:
        try:
            assert name in ('stdout', 'stderr')

            # Here we try to be somewhat smart about breaking arbitrary
            # print output into discrete log entries.

            self._file_chunks[name].append(output)

            # Individual parts of a print come across as separate writes,
            # and the end of a print will be a standalone '\n' by default.
            # Let's use that as a hint that we're likely at the end of
            # a full print statement and ship what we've got.
            if output == '\n':
                self._ship_file_chunks(name, cancel_ship_task=True)
            else:
                # By default just keep adding chunks.
                # However we keep a timer running anytime we've got
                # unshipped chunks so that we can ship what we've got
                # after a short bit if we never get a newline.
                ship_task = self._file_chunk_ship_task[name]
                if ship_task is None:
                    self._file_chunk_ship_task[
                        name
                    ] = self._event_loop.create_task(
                        self._ship_chunks_task(name),
                        name='log ship file chunks',
                    )

        except Exception:
            import traceback

            traceback.print_exc(file=self._echofile)

    def file_flush(self, name: str) -> None:
        """Send raw stdout/stderr flush to the logger to be collated."""

        self._event_loop.call_soon_threadsafe(
            tpartial(self._file_flush_in_thread, name)
        )

    def _file_flush_in_thread(self, name: str) -> None:
        try:
            assert name in ('stdout', 'stderr')

            # Immediately ship whatever chunks we've got.
            if self._file_chunks[name]:
                self._ship_file_chunks(name, cancel_ship_task=True)

        except Exception:
            import traceback

            traceback.print_exc(file=self._echofile)

    async def _ship_chunks_task(self, name: str) -> None:
        self._ship_file_chunks(name, cancel_ship_task=False)

    def _ship_file_chunks(self, name: str, cancel_ship_task: bool) -> None:
        # Note: Raw print input generally ends in a newline, but that is
        # redundant when we break things into log entries and results
        # in extra empty lines. So strip off a single trailing newline.
        text = ''.join(self._file_chunks[name]).removesuffix('\n')

        self._emit_entry(
            LogEntry(
                name=name, message=text, level=LogLevel.INFO, time=utc_now()
            )
        )
        self._file_chunks[name] = []
        ship_task = self._file_chunk_ship_task[name]
        if cancel_ship_task and ship_task is not None:
            ship_task.cancel()
        self._file_chunk_ship_task[name] = None

    def _emit_entry(self, entry: LogEntry) -> None:
        assert current_thread() is self._thread

        # Store to our cache.
        if self._cache_size_limit > 0:
            with self._cache_lock:
                # Do a rough calc of how many bytes this entry consumes.
                entry_size = sum(
                    sys.getsizeof(x)
                    for x in (
                        entry,
                        entry.name,
                        entry.message,
                        entry.level,
                        entry.time,
                    )
                )
                self._cache.append((entry_size, entry))
                self._cache_size += entry_size

                # Prune old entries until we are back at or under our limit.
                while self._cache_size > self._cache_size_limit:
                    popped = self._cache.popleft()
                    self._cache_size -= popped[0]
                    self._cache_index_offset += 1

        # Pass to callbacks.
        with self._callbacks_lock:
            for call in self._callbacks:
                try:
                    call(entry)
                except Exception:
                    # Only print one callback error to avoid insanity.
                    if not self._printed_callback_error:
                        import traceback

                        traceback.print_exc(file=self._echofile)
                        self._printed_callback_error = True

        # Dump to our structured log file.
        # TODO: set a timer for flushing; don't flush every line.
        if self._file is not None:
            entry_s = dataclass_to_json(entry)
            assert '\n' not in entry_s  # Make sure it's a single line.
            print(entry_s, file=self._file, flush=True)

class FileLogEcho:
|
||||
"""A file-like object for forwarding stdout/stderr to a LogHandler."""
|
||||
|
||||
def __init__(
|
||||
self, original: TextIO, name: str, handler: LogHandler
|
||||
) -> None:
|
||||
assert name in ('stdout', 'stderr')
|
||||
self._original = original
|
||||
self._name = name
|
||||
self._handler = handler
|
||||
|
||||
def write(self, output: Any) -> None:
|
||||
"""Override standard write call."""
|
||||
self._original.write(output)
|
||||
self._handler.file_write(self._name, output)
|
||||
|
||||
def flush(self) -> None:
|
||||
"""Flush the file."""
|
||||
self._original.flush()
|
||||
|
||||
# We also use this as a hint to ship whatever file chunks
|
||||
# we've accumulated (we have to try and be smart about breaking
|
||||
# our arbitrary file output into discrete entries).
|
||||
self._handler.file_flush(self._name)
|
||||
|
||||
def isatty(self) -> bool:
|
||||
"""Are we a terminal?"""
|
||||
return self._original.isatty()
|
||||
|
||||
|
||||
def setup_logging(
|
||||
log_path: str | Path | None,
|
||||
level: LogLevel,
|
||||
suppress_non_root_debug: bool = False,
|
||||
log_stdout_stderr: bool = False,
|
||||
echo_to_stderr: bool = True,
|
||||
cache_size_limit: int = 0,
|
||||
cache_time_limit: datetime.timedelta | None = None,
|
||||
) -> LogHandler:
|
||||
"""Set up our logging environment.
|
||||
|
||||
Returns the custom handler which can be used to fetch information
|
||||
about logs that have passed through it. (worst log-levels, caches, etc.).
|
||||
"""
|
||||
|
||||
lmap = {
|
||||
LogLevel.DEBUG: logging.DEBUG,
|
||||
LogLevel.INFO: logging.INFO,
|
||||
LogLevel.WARNING: logging.WARNING,
|
||||
LogLevel.ERROR: logging.ERROR,
|
||||
LogLevel.CRITICAL: logging.CRITICAL,
|
||||
}
|
||||
|
||||
# Wire logger output to go to a structured log file.
|
||||
# Also echo it to stderr IF we're running in a terminal.
|
||||
# UPDATE: Actually gonna always go to stderr. Is there a
|
||||
# reason we shouldn't? This makes debugging possible if all
|
||||
# we have is access to a non-interactive terminal or file dump.
|
||||
# We could add a '--quiet' arg or whatnot to change this behavior.
|
||||
|
||||
# Note: by passing in the *original* stderr here before we
|
||||
# (potentially) replace it, we ensure that our log echos
|
||||
# won't themselves be intercepted and sent to the logger
|
||||
# which would create an infinite loop.
|
||||
loghandler = LogHandler(
|
||||
log_path,
|
||||
echofile=sys.stderr if echo_to_stderr else None,
|
||||
suppress_non_root_debug=suppress_non_root_debug,
|
||||
cache_size_limit=cache_size_limit,
|
||||
cache_time_limit=cache_time_limit,
|
||||
)
|
||||
|
||||
# Note: going ahead with force=True here so that we replace any
|
||||
# existing logger. Though we warn if it looks like we are doing
|
||||
# that so we can try to avoid creating the first one.
|
||||
had_previous_handlers = bool(logging.root.handlers)
|
||||
logging.basicConfig(
|
||||
level=lmap[level],
|
||||
format='%(message)s',
|
||||
handlers=[loghandler],
|
||||
force=True,
|
||||
)
|
||||
if had_previous_handlers:
|
||||
logging.warning('setup_logging: force-replacing previous handlers.')
|
||||
|
||||
# Optionally intercept Python's stdout/stderr output and generate
|
||||
# log entries from it.
|
||||
if log_stdout_stderr:
|
||||
sys.stdout = FileLogEcho( # type: ignore
|
||||
sys.stdout, 'stdout', loghandler
|
||||
)
|
||||
sys.stderr = FileLogEcho( # type: ignore
|
||||
sys.stderr, 'stderr', loghandler
|
||||
)
|
||||
|
||||
return loghandler
|
||||
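# --- Illustrative usage sketch (not part of the original file) ---
# A minimal wiring of setup_logging(), shown as a hedged example; the
# log path and level choices here are hypothetical, not defaults.
#
# import logging
# from pathlib import Path
# from efro.log import setup_logging, LogLevel
#
# handler = setup_logging(
#     log_path=Path('server.log'),  # structured JSON entries land here
#     level=LogLevel.INFO,
#     log_stdout_stderr=True,  # also collate raw print() output
# )
# logging.info('hello')  # echoed to stderr and written to server.log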
991
dist/ba_data/python/efro/message.py
vendored
Normal file

@@ -0,0 +1,991 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for sending and responding to messages.

Supports static typing for message types and possible return types.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, TypeVar, Annotated
from dataclasses import dataclass
from enum import Enum
import inspect
import logging
import json
import traceback

from efro.error import CleanError, RemoteError
from efro.dataclassio import (ioprepped, is_ioprepped_dataclass, IOAttrs,
                              dataclass_to_dict, dataclass_from_dict)

if TYPE_CHECKING:
    from typing import Any, Callable, Optional, Sequence, Union, Awaitable

TM = TypeVar('TM', bound='MessageSender')


class Message:
    """Base class for messages."""

    @classmethod
    def get_response_types(cls) -> list[type[Response]]:
        """Return all response types this Message can result in when sent.

        The default implementation specifies EmptyResponse, so messages with
        no particular response needs can leave this untouched.
        Note that ErrorResponse is handled as a special case and does not
        need to be specified here.
        """
        return [EmptyResponse]


class Response:
    """Base class for responses to messages."""


# Some standard response types:


class ErrorType(Enum):
    """Type of error that occurred in remote message handling."""
    OTHER = 0
    CLEAN = 1


@ioprepped
@dataclass
class ErrorResponse(Response):
    """Message saying some error has occurred on the other end.

    This type is unique in that it is not returned to the user; it
    instead results in a local exception being raised.
    """
    error_message: Annotated[str, IOAttrs('m')]
    error_type: Annotated[ErrorType, IOAttrs('e')] = ErrorType.OTHER


@ioprepped
@dataclass
class EmptyResponse(Response):
    """The response equivalent of None."""


# TODO: could allow handlers to deal in raw values for these
# types similar to how we allow None in place of EmptyResponse.
# Though not sure if they are widely used enough to warrant the
# extra code complexity.
@ioprepped
@dataclass
class BoolResponse(Response):
    """A simple bool value response."""

    value: Annotated[bool, IOAttrs('v')]


@ioprepped
@dataclass
class StringResponse(Response):
    """A simple string value response."""

    value: Annotated[str, IOAttrs('v')]
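# --- Illustrative sketch (not part of the original file) ---
# A hypothetical concrete message type showing the intended pattern:
# messages are @ioprepped dataclasses, and get_response_types()
# declares what a registered handler may return.
#
# @ioprepped
# @dataclass
# class PingMessage(Message):
#     """Example message type (hypothetical)."""
#
#     echo_text: Annotated[str, IOAttrs('t')]
#
#     @classmethod
#     def get_response_types(cls) -> list[type[Response]]:
#         return [StringResponse]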
class MessageProtocol:
    """Wrangles a set of message types, formats, and response types.

    Both endpoints must be using a compatible Protocol for communication
    to succeed. To maintain Protocol compatibility between revisions,
    all message types must retain the same id, message attr storage names
    must not change, newly added attrs must have default values, etc.
    """

    def __init__(self,
                 message_types: dict[int, type[Message]],
                 response_types: dict[int, type[Response]],
                 type_key: Optional[str] = None,
                 preserve_clean_errors: bool = True,
                 log_remote_exceptions: bool = True,
                 trusted_sender: bool = False) -> None:
        """Create a protocol with a given configuration.

        Note that common response types are automatically registered
        (with unchanging negative ids) so they don't need to be passed
        explicitly (but can be if a different id is desired).

        If 'type_key' is provided, the message type ID is stored as the
        provided key in the message dict; otherwise it will be stored as
        part of a top level dict with the message payload appearing as a
        child dict. This is mainly for backwards compatibility.

        If 'preserve_clean_errors' is True, efro.error.CleanError errors
        on the remote end will result in the same error raised locally.
        All other Exception types come across as efro.error.RemoteError.

        If 'trusted_sender' is True, stringified remote stack traces will
        be included in the responses if errors occur.
        """
        # pylint: disable=too-many-locals
        self.message_types_by_id: dict[int, type[Message]] = {}
        self.message_ids_by_type: dict[type[Message], int] = {}
        self.response_types_by_id: dict[int, type[Response]] = {}
        self.response_ids_by_type: dict[type[Response], int] = {}
        for m_id, m_type in message_types.items():

            # Make sure only valid message types were passed and each
            # id was assigned only once.
            assert isinstance(m_id, int)
            assert m_id >= 0
            assert (is_ioprepped_dataclass(m_type)
                    and issubclass(m_type, Message))
            assert self.message_types_by_id.get(m_id) is None
            self.message_types_by_id[m_id] = m_type
            self.message_ids_by_type[m_type] = m_id

        for r_id, r_type in response_types.items():
            assert isinstance(r_id, int)
            assert r_id >= 0
            assert (is_ioprepped_dataclass(r_type)
                    and issubclass(r_type, Response))
            assert self.response_types_by_id.get(r_id) is None
            self.response_types_by_id[r_id] = r_type
            self.response_ids_by_type[r_type] = r_id

        # Go ahead and auto-register a few common response types
        # if the user has not done so explicitly. Use unique IDs which
        # will never change or overlap with user ids.
        def _reg_if_not(reg_tp: type[Response], reg_id: int) -> None:
            if reg_tp in self.response_ids_by_type:
                return
            assert self.response_types_by_id.get(reg_id) is None
            self.response_types_by_id[reg_id] = reg_tp
            self.response_ids_by_type[reg_tp] = reg_id

        _reg_if_not(ErrorResponse, -1)
        _reg_if_not(EmptyResponse, -2)
        # _reg_if_not(BoolResponse, -3)

        # Some extra-thorough validation in debug mode.
        if __debug__:
            # Make sure all Message types' return types are valid
            # and have been assigned an ID as well.
            all_response_types: set[type[Response]] = set()
            for m_id, m_type in message_types.items():
                m_rtypes = m_type.get_response_types()
                assert isinstance(m_rtypes, list)
                assert m_rtypes, (
                    f'Message type {m_type} specifies no return types.')
                assert len(set(m_rtypes)) == len(m_rtypes)  # check dups
                all_response_types.update(m_rtypes)
            for cls in all_response_types:
                assert is_ioprepped_dataclass(cls)
                assert issubclass(cls, Response)
                if cls not in self.response_ids_by_type:
                    raise ValueError(f'Possible response type {cls}'
                                     f' needs to be included in'
                                     f' response_types for this protocol.')

            # Make sure all registered types have unique base names.
            # We can take advantage of this to generate cleaner looking
            # protocol modules. Can revisit if this is ever a problem.
            mtypenames = set(tp.__name__ for tp in self.message_ids_by_type)
            if len(mtypenames) != len(message_types):
                raise ValueError(
                    'message_types contains duplicate __name__s;'
                    ' all types are required to have unique names.')

        self._type_key = type_key
        self.preserve_clean_errors = preserve_clean_errors
        self.log_remote_exceptions = log_remote_exceptions
        self.trusted_sender = trusted_sender

    def encode_message(self, message: Message) -> str:
        """Encode a message to a json string for transport."""
        return self._encode(message, self.message_ids_by_type, 'message')

    def encode_response(self, response: Response) -> str:
        """Encode a response to a json string for transport."""
        return self._encode(response, self.response_ids_by_type, 'response')

    def _encode(self, message: Any, ids_by_type: dict[type, int],
                opname: str) -> str:
        """Encode a message to a json string for transport."""

        m_id: Optional[int] = ids_by_type.get(type(message))
        if m_id is None:
            raise TypeError(f'{opname} type is not registered in protocol:'
                            f' {type(message)}')
        msgdict = dataclass_to_dict(message)

        # Encode type as part of the message/response dict if desired
        # (for legacy compatibility).
        if self._type_key is not None:
            if self._type_key in msgdict:
                raise RuntimeError(f'Type-key {self._type_key}'
                                   f' found in msg of type {type(message)}')
            msgdict[self._type_key] = m_id
            out = msgdict
        else:
            out = {'m': msgdict, 't': m_id}
        return json.dumps(out, separators=(',', ':'))
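    # --- Illustrative note (not part of the original file) ---
    # Wire-format example for _encode() above: with no type_key, a
    # message registered as id 0 whose dataclass stores one attr under
    # 'v' would encode as:
    #     {"m":{"v":"hello"},"t":0}
    # With type_key='mtype', the id is folded into the payload instead:
    #     {"v":"hello","mtype":0}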
    def decode_message(self, data: str) -> Message:
        """Decode a message from a json string."""
        out = self._decode(data, self.message_types_by_id, 'message')
        assert isinstance(out, Message)
        return out

    def decode_response(self, data: str) -> Optional[Response]:
        """Decode a response from a json string."""
        out = self._decode(data, self.response_types_by_id, 'response')
        assert isinstance(out, (Response, type(None)))
        return out

    # Weird; we get mypy errors returning dict[int, type] but
    # dict[int, typing.Type] or dict[int, type[Any]] works..
    def _decode(self, data: str, types_by_id: dict[int, type[Any]],
                opname: str) -> Any:
        """Decode a message from a json string."""
        msgfull = json.loads(data)
        assert isinstance(msgfull, dict)
        msgdict: Optional[dict]
        if self._type_key is not None:
            m_id = msgfull.pop(self._type_key)
            msgdict = msgfull
            assert isinstance(m_id, int)
        else:
            m_id = msgfull.get('t')
            msgdict = msgfull.get('m')
            assert isinstance(m_id, int)
            assert isinstance(msgdict, dict)

        # Decode this particular type.
        msgtype = types_by_id.get(m_id)
        if msgtype is None:
            raise TypeError(f'Got unregistered {opname} type id of {m_id}.')
        out = dataclass_from_dict(msgtype, msgdict)

        # Special case: if we get EmptyResponse, we simply return None.
        if isinstance(out, EmptyResponse):
            return None

        # Special case: a remote error occurred. Raise a local Exception
        # instead of returning the message.
        if isinstance(out, ErrorResponse):
            assert opname == 'response'
            if (self.preserve_clean_errors
                    and out.error_type is ErrorType.CLEAN):
                raise CleanError(out.error_message)
            raise RemoteError(out.error_message)

        return out
    def _get_module_header(self, part: str) -> str:
        """Return common parts of generated modules."""
        # pylint: disable=too-many-locals, too-many-branches
        import textwrap
        tpimports: dict[str, list[str]] = {}
        imports: dict[str, list[str]] = {}

        single_message_type = len(self.message_ids_by_type) == 1

        # Always import messages.
        for msgtype in list(self.message_ids_by_type) + [Message]:
            tpimports.setdefault(msgtype.__module__,
                                 []).append(msgtype.__name__)
        for rsp_tp in list(self.response_ids_by_type) + [Response]:
            # Skip these as they don't actually show up in code.
            if rsp_tp is EmptyResponse or rsp_tp is ErrorResponse:
                continue
            if (single_message_type and part == 'sender'
                    and rsp_tp is not Response):
                # We need to cast to the single supported response type
                # in this case so need response types at runtime.
                imports.setdefault(rsp_tp.__module__,
                                   []).append(rsp_tp.__name__)
            else:
                tpimports.setdefault(rsp_tp.__module__,
                                     []).append(rsp_tp.__name__)

        import_lines = ''
        tpimport_lines = ''

        for module, names in sorted(imports.items()):
            jnames = ', '.join(names)
            line = f'from {module} import {jnames}'
            if len(line) > 79:
                # Recreate in a wrapping-friendly form.
                line = f'from {module} import ({jnames})'
            import_lines += f'{line}\n'
        for module, names in sorted(tpimports.items()):
            jnames = ', '.join(names)
            line = f'from {module} import {jnames}'
            if len(line) > 75:  # Account for indent.
                # Recreate in a wrapping-friendly form.
                line = f'from {module} import ({jnames})'
            tpimport_lines += f'{line}\n'

        if part == 'sender':
            import_lines += ('from efro.message import MessageSender,'
                             ' BoundMessageSender')
            tpimport_typing_extras = ''
        else:
            if single_message_type:
                import_lines += ('from efro.message import (MessageReceiver,'
                                 ' BoundMessageReceiver, Message, Response)')
            else:
                import_lines += ('from efro.message import MessageReceiver,'
                                 ' BoundMessageReceiver')
            tpimport_typing_extras = ', Awaitable'

        ovld = ', overload' if not single_message_type else ''
        tpimport_lines = textwrap.indent(tpimport_lines, '    ')
        out = ('# Released under the MIT License. See LICENSE for details.\n'
               f'#\n'
               f'"""Auto-generated {part} module. Do not edit by hand."""\n'
               f'\n'
               f'from __future__ import annotations\n'
               f'\n'
               f'from typing import TYPE_CHECKING{ovld}\n'
               f'\n'
               f'{import_lines}\n'
               f'\n'
               f'if TYPE_CHECKING:\n'
               f'    from typing import Union, Any, Optional, Callable'
               f'{tpimport_typing_extras}\n'
               f'{tpimport_lines}'
               f'\n'
               f'\n')
        return out
    def do_create_sender_module(self,
                                basename: str,
                                protocol_create_code: str,
                                enable_sync_sends: bool,
                                enable_async_sends: bool,
                                private: bool = False) -> str:
        """Used by create_sender_module(); do not call directly."""
        # pylint: disable=too-many-locals
        import textwrap

        msgtypes = list(self.message_ids_by_type.keys())

        ppre = '_' if private else ''
        out = self._get_module_header('sender')
        ccind = textwrap.indent(protocol_create_code, '        ')
        out += (f'class {ppre}{basename}(MessageSender):\n'
                f'    """Protocol-specific sender."""\n'
                f'\n'
                f'    def __init__(self) -> None:\n'
                f'{ccind}\n'
                f'        super().__init__(protocol)\n'
                f'\n'
                f'    def __get__(self,\n'
                f'                obj: Any,\n'
                f'                type_in: Any = None)'
                f' -> {ppre}Bound{basename}:\n'
                f'        return {ppre}Bound{basename}'
                f'(obj, self)\n'
                f'\n'
                f'\n'
                f'class {ppre}Bound{basename}(BoundMessageSender):\n'
                f'    """Protocol-specific bound sender."""\n')

        def _filt_tp_name(rtype: type[Response]) -> str:
            # We accept None to equal EmptyResponse so reflect that
            # in the type annotation.
            return 'None' if rtype is EmptyResponse else rtype.__name__

        # Define send() overloads for all registered message types.
        if msgtypes:
            for async_pass in False, True:
                if async_pass and not enable_async_sends:
                    continue
                if not async_pass and not enable_sync_sends:
                    continue
                pfx = 'async ' if async_pass else ''
                sfx = '_async' if async_pass else ''
                awt = 'await ' if async_pass else ''
                how = 'asynchronously' if async_pass else 'synchronously'

                if len(msgtypes) == 1:
                    # Special case: with a single message type we don't
                    # use overloads.
                    msgtype = msgtypes[0]
                    msgtypevar = msgtype.__name__
                    rtypes = msgtype.get_response_types()
                    if len(rtypes) > 1:
                        tps = ', '.join(_filt_tp_name(t) for t in rtypes)
                        rtypevar = f'Union[{tps}]'
                    else:
                        rtypevar = _filt_tp_name(rtypes[0])
                    out += (f'\n'
                            f'    {pfx}def send{sfx}(self,'
                            f' message: {msgtypevar})'
                            f' -> {rtypevar}:\n'
                            f'        """Send a message {how}."""\n'
                            f'        out = {awt}self._sender.'
                            f'send{sfx}(self._obj, message)\n'
                            f'        assert isinstance(out, {rtypevar})\n'
                            f'        return out\n')
                else:

                    for msgtype in msgtypes:
                        msgtypevar = msgtype.__name__
                        rtypes = msgtype.get_response_types()
                        if len(rtypes) > 1:
                            tps = ', '.join(_filt_tp_name(t) for t in rtypes)
                            rtypevar = f'Union[{tps}]'
                        else:
                            rtypevar = _filt_tp_name(rtypes[0])
                        out += (f'\n'
                                f'    @overload\n'
                                f'    {pfx}def send{sfx}(self,'
                                f' message: {msgtypevar})'
                                f' -> {rtypevar}:\n'
                                f'        ...\n')
                    out += (f'\n'
                            f'    {pfx}def send{sfx}(self, message: Message)'
                            f' -> Optional[Response]:\n'
                            f'        """Send a message {how}."""\n'
                            f'        return {awt}self._sender.'
                            f'send{sfx}(self._obj, message)\n')

        return out
    def do_create_receiver_module(self,
                                  basename: str,
                                  protocol_create_code: str,
                                  is_async: bool,
                                  private: bool = False) -> str:
        """Used by create_receiver_module(); do not call directly."""
        # pylint: disable=too-many-locals
        import textwrap

        desc = 'asynchronous' if is_async else 'synchronous'
        ppre = '_' if private else ''
        msgtypes = list(self.message_ids_by_type.keys())
        out = self._get_module_header('receiver')
        ccind = textwrap.indent(protocol_create_code, '        ')
        out += (f'class {ppre}{basename}(MessageReceiver):\n'
                f'    """Protocol-specific {desc} receiver."""\n'
                f'\n'
                f'    is_async = {is_async}\n'
                f'\n'
                f'    def __init__(self) -> None:\n'
                f'{ccind}\n'
                f'        super().__init__(protocol)\n'
                f'\n'
                f'    def __get__(\n'
                f'        self,\n'
                f'        obj: Any,\n'
                f'        type_in: Any = None,\n'
                f'    ) -> {ppre}Bound{basename}:\n'
                f'        return {ppre}Bound{basename}('
                f'obj, self)\n')

        # Define handler() overloads for all registered message types.

        def _filt_tp_name(rtype: type[Response]) -> str:
            # We accept None to equal EmptyResponse so reflect that
            # in the type annotation.
            return 'None' if rtype is EmptyResponse else rtype.__name__

        if msgtypes:
            cbgn = 'Awaitable[' if is_async else ''
            cend = ']' if is_async else ''
            if len(msgtypes) == 1:
                # Special case: when we have a single message type we don't
                # use overloads.
                msgtype = msgtypes[0]
                msgtypevar = msgtype.__name__
                rtypes = msgtype.get_response_types()
                if len(rtypes) > 1:
                    tps = ', '.join(_filt_tp_name(t) for t in rtypes)
                    rtypevar = f'Union[{tps}]'
                else:
                    rtypevar = _filt_tp_name(rtypes[0])
                rtypevar = f'{cbgn}{rtypevar}{cend}'
                out += (
                    f'\n'
                    f'    def handler(\n'
                    f'        self,\n'
                    f'        call: Callable[[Any, {msgtypevar}], '
                    f'{rtypevar}],\n'
                    f'    )'
                    f' -> Callable[[Any, {msgtypevar}], {rtypevar}]:\n'
                    f'        """Decorator to register message handlers."""\n'
                    f'        from typing import cast, Callable, Any\n'
                    f'        self.register_handler(cast(Callable'
                    f'[[Any, Message], Response], call))\n'
                    f'        return call\n')
            else:
                for msgtype in msgtypes:
                    msgtypevar = msgtype.__name__
                    rtypes = msgtype.get_response_types()
                    if len(rtypes) > 1:
                        tps = ', '.join(_filt_tp_name(t) for t in rtypes)
                        rtypevar = f'Union[{tps}]'
                    else:
                        rtypevar = _filt_tp_name(rtypes[0])
                    rtypevar = f'{cbgn}{rtypevar}{cend}'
                    out += (f'\n'
                            f'    @overload\n'
                            f'    def handler(\n'
                            f'        self,\n'
                            f'        call: Callable[[Any, {msgtypevar}], '
                            f'{rtypevar}],\n'
                            f'    )'
                            f' -> Callable[[Any, {msgtypevar}], '
                            f'{rtypevar}]:\n'
                            f'        ...\n')
                out += (
                    '\n'
                    '    def handler(self, call: Callable) -> Callable:\n'
                    '        """Decorator to register message handlers."""\n'
                    '        self.register_handler(call)\n'
                    '        return call\n')

        out += (f'\n'
                f'\n'
                f'class {ppre}Bound{basename}(BoundMessageReceiver):\n'
                f'    """Protocol-specific bound receiver."""\n')
        if is_async:
            out += (
                '\n'
                '    async def handle_raw_message(self, message: str)'
                ' -> str:\n'
                '        """Asynchronously handle a raw incoming message."""\n'
                '        return await'
                ' self._receiver.handle_raw_message_async(\n'
                '            self._obj, message)\n')
        else:
            out += (
                '\n'
                '    def handle_raw_message(self, message: str) -> str:\n'
                '        """Synchronously handle a raw incoming message."""\n'
                '        return self._receiver.handle_raw_message'
                '(self._obj, message)\n')

        return out
class MessageSender:
    """Facilitates sending messages to a target and receiving responses.

    This is instantiated at the class level and used to register unbound
    class methods to handle raw message sending.

    Example:

    class MyClass:
        msg = MyMessageSender(some_protocol)

        @msg.send_method
        def send_raw_message(self, message: str) -> str:
            # Actually send the message here.

    # MyMessageSender class should provide overloads for send(), send_bg(),
    # etc. to ensure all sending happens with valid types.
    obj = MyClass()
    obj.msg.send(SomeMessageType())
    """

    def __init__(self, protocol: MessageProtocol) -> None:
        self.protocol = protocol
        self._send_raw_message_call: Optional[Callable[[Any, str],
                                                       str]] = None
        self._send_async_raw_message_call: Optional[Callable[
            [Any, str], Awaitable[str]]] = None

    def send_method(
            self, call: Callable[[Any, str],
                                 str]) -> Callable[[Any, str], str]:
        """Function decorator for setting raw send method."""
        assert self._send_raw_message_call is None
        self._send_raw_message_call = call
        return call

    def send_async_method(
        self, call: Callable[[Any, str], Awaitable[str]]
    ) -> Callable[[Any, str], Awaitable[str]]:
        """Function decorator for setting raw send-async method."""
        assert self._send_async_raw_message_call is None
        self._send_async_raw_message_call = call
        return call

    def send(self, bound_obj: Any, message: Message) -> Optional[Response]:
        """Send a message and receive a response.

        Will encode the message for transport and call
        dispatch_raw_message().
        """
        if self._send_raw_message_call is None:
            raise RuntimeError('send() is unimplemented for this type.')

        msg_encoded = self.protocol.encode_message(message)
        response_encoded = self._send_raw_message_call(bound_obj, msg_encoded)
        response = self.protocol.decode_response(response_encoded)
        assert isinstance(response, (Response, type(None)))
        assert (response is None
                or type(response) in type(message).get_response_types())
        return response

    async def send_async(self, bound_obj: Any,
                         message: Message) -> Optional[Response]:
        """Send a message asynchronously using asyncio.

        The message will be encoded for transport and passed to
        dispatch_raw_message_async().
        """
        if self._send_async_raw_message_call is None:
            raise RuntimeError('send_async() is unimplemented for this type.')

        msg_encoded = self.protocol.encode_message(message)
        response_encoded = await self._send_async_raw_message_call(
            bound_obj, msg_encoded)
        response = self.protocol.decode_response(response_encoded)
        assert isinstance(response, (Response, type(None)))
        assert (response is None
                or type(response) in type(message).get_response_types())
        return response
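# --- Illustrative sketch (not part of the original file) ---
# Hypothetical wiring of the raw transport for a MessageSender; a real
# app would use a generated subclass (with typed send() overloads and a
# __get__) instead of the base class. 'transport_send', 'some_protocol',
# and 'SomeMessage' are invented names.
#
# class Connection:
#     msg = MessageSender(some_protocol)
#
#     @msg.send_method
#     def _send_raw(self, data: str) -> str:
#         return transport_send(data)  # hypothetical blocking transport
#
# conn = Connection()
# response = Connection.msg.send(conn, SomeMessage())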
class BoundMessageSender:
    """Base class for bound senders."""

    def __init__(self, obj: Any, sender: MessageSender) -> None:
        assert obj is not None
        self._obj = obj
        self._sender = sender

    @property
    def protocol(self) -> MessageProtocol:
        """Protocol associated with this sender."""
        return self._sender.protocol

    def send_untyped(self, message: Message) -> Optional[Response]:
        """Send a message synchronously.

        Whenever possible, use the send() call provided by generated
        subclasses instead of this; it will provide better type safety.
        """
        return self._sender.send(self._obj, message)

    async def send_async_untyped(self,
                                 message: Message) -> Optional[Response]:
        """Send a message asynchronously.

        Whenever possible, use the send_async() call provided by generated
        subclasses instead of this; it will provide better type safety.
        """
        return await self._sender.send_async(self._obj, message)


class MessageReceiver:
    """Facilitates receiving & responding to messages from a remote source.

    This is instantiated at the class level with unbound methods registered
    as handlers for different message types in the protocol.

    Example:

    class MyClass:
        receiver = MyMessageReceiver()

        # MyMessageReceiver fills out handler() overloads to ensure all
        # registered handlers have valid types/return-types.
        @receiver.handler
        def handle_some_message_type(self, message: SomeMsg) -> SomeResponse:
            # Deal with this message type here.

    # This will trigger the registered handler being called.
    obj = MyClass()
    obj.receiver.handle_raw_message(some_raw_data)

    Any unhandled Exception occurring during message handling will result in
    an Exception being raised on the sending end.
    """

    is_async = False

    def __init__(self, protocol: MessageProtocol) -> None:
        self.protocol = protocol
        self._handlers: dict[type[Message], Callable] = {}

    # noinspection PyProtectedMember
    def register_handler(
            self, call: Callable[[Any, Message],
                                 Optional[Response]]) -> None:
        """Register a handler call.

        The message type handled by the call is determined by its
        type annotation.
        """
        # TODO: can use types.GenericAlias in 3.9.
        from typing import _GenericAlias  # type: ignore
        from typing import get_type_hints, get_args

        sig = inspect.getfullargspec(call)

        # The provided callable should be a method taking one 'msg' arg.
        expectedsig = ['self', 'msg']
        if sig.args != expectedsig:
            raise ValueError(f'Expected callable signature of {expectedsig};'
                             f' got {sig.args}')

        # Make sure we are only given async methods if we are an async
        # handler and sync ones otherwise.
        is_async = inspect.iscoroutinefunction(call)
        if self.is_async != is_async:
            msg = ('Expected a sync method; found an async one.' if is_async
                   else 'Expected an async method; found a sync one.')
            raise ValueError(msg)

        # Check annotation types to determine what message types we handle.
        # Return-type annotation can be a Union, but we probably don't
        # have it available at runtime. Explicitly pull it in.
        # UPDATE: we've updated our pylint filter to where we should
        # have all annotations available.
        # anns = get_type_hints(call, localns={'Union': Union})
        anns = get_type_hints(call)

        msgtype = anns.get('msg')
        if not isinstance(msgtype, type):
            raise TypeError(
                f'expected a type for "msg" annotation; got {type(msgtype)}.')
        assert issubclass(msgtype, Message)

        ret = anns.get('return')
        responsetypes: tuple[Union[type[Any], type[None]], ...]

        # Return types can be a single type or a union of types.
        if isinstance(ret, _GenericAlias):
            targs = get_args(ret)
            if not all(isinstance(a, type) for a in targs):
                raise TypeError(f'expected only types for "return"'
                                f' annotation; got {targs}.')
            responsetypes = targs
        else:
            if not isinstance(ret, type):
                raise TypeError(f'expected one or more types for'
                                f' "return" annotation; got a {type(ret)}.')
            responsetypes = (ret, )

        # Return type of None translates to EmptyResponse.
        responsetypes = tuple(EmptyResponse if r is type(None) else r
                              for r in responsetypes)  # noqa

        # Make sure our protocol has this message type registered and our
        # return types exactly match. (Technically we could return a subset
        # of the supported types; can allow this in the future if it makes
        # sense).
        registered_types = self.protocol.message_ids_by_type.keys()

        if msgtype not in registered_types:
            raise TypeError(f'Message type {msgtype} is not registered'
                            f' in this Protocol.')

        if msgtype in self._handlers:
            raise TypeError(f'Message type {msgtype} already has a registered'
                            f' handler.')

        # Make sure the response types exactly match what the message
        # expects.
        if set(responsetypes) != set(msgtype.get_response_types()):
            raise TypeError(
                f'Provided response types {responsetypes} do not'
                f' match the set expected by message type {msgtype}: '
                f'({msgtype.get_response_types()})')

        # Ok; we're good!
        self._handlers[msgtype] = call
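    # --- Illustrative note (not part of the original file) ---
    # A handler accepted by register_handler() above must be a method
    # taking exactly (self, msg), where the 'msg' annotation selects the
    # message type and the return annotation matches that type's
    # get_response_types() exactly; e.g. (hypothetical types):
    #
    #     def handle_ping(self, msg: PingMessage) -> StringResponse:
    #         return StringResponse(value=msg.echo_text)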
    def validate(self, warn_only: bool = False) -> None:
        """Check for handler completeness, valid types, etc."""
        for msgtype in self.protocol.message_ids_by_type.keys():
            if issubclass(msgtype, Response):
                continue
            if msgtype not in self._handlers:
                msg = (f'Protocol message type {msgtype} is not handled'
                       f' by receiver type {type(self)}.')
                if warn_only:
                    logging.warning(msg)
                else:
                    raise TypeError(msg)

    def _decode_incoming_message(self,
                                 msg: str) -> tuple[Message, type[Message]]:
        # Decode the incoming message.
        msg_decoded = self.protocol.decode_message(msg)
        msgtype = type(msg_decoded)
        assert issubclass(msgtype, Message)
        return msg_decoded, msgtype

    def _encode_response(self, response: Optional[Response],
                         msgtype: type[Message]) -> str:

        # A return value of None equals EmptyResponse.
        if response is None:
            response = EmptyResponse()

        # Re-encode the response.
        assert isinstance(response, Response)
        # (user should never explicitly return these)
        assert not isinstance(response, ErrorResponse)
        assert type(response) in msgtype.get_response_types()
        return self.protocol.encode_response(response)

    def raw_response_for_error(self, exc: Exception) -> str:
        """Return a raw response for an error that occurred during
        handling."""
        if self.protocol.log_remote_exceptions:
            logging.exception('Error handling message.')

        # If anything goes wrong, return an ErrorResponse instead.
        if (isinstance(exc, CleanError)
                and self.protocol.preserve_clean_errors):
            err_response = ErrorResponse(error_message=str(exc),
                                         error_type=ErrorType.CLEAN)
        else:
            err_response = ErrorResponse(
                error_message=(traceback.format_exc()
                               if self.protocol.trusted_sender else
                               'An unknown error has occurred.'),
                error_type=ErrorType.OTHER)
        return self.protocol.encode_response(err_response)

    def handle_raw_message(self, bound_obj: Any, msg: str) -> str:
        """Decode, handle, and return a response for a message."""
        assert not self.is_async, "can't call sync handler on async receiver"
        try:
            msg_decoded, msgtype = self._decode_incoming_message(msg)
            handler = self._handlers.get(msgtype)
            if handler is None:
                raise RuntimeError(f'Got unhandled message type: {msgtype}.')
            result = handler(bound_obj, msg_decoded)
            return self._encode_response(result, msgtype)

        except Exception as exc:
            return self.raw_response_for_error(exc)

    async def handle_raw_message_async(self, bound_obj: Any,
                                       msg: str) -> str:
        """Should be called when the receiver gets a message.

        The return value is the raw response to the message.
        """
        assert self.is_async, "can't call async handler on sync receiver"
        try:
            msg_decoded, msgtype = self._decode_incoming_message(msg)
            handler = self._handlers.get(msgtype)
            if handler is None:
                raise RuntimeError(f'Got unhandled message type: {msgtype}.')
            result = await handler(bound_obj, msg_decoded)
            return self._encode_response(result, msgtype)

        except Exception as exc:
            return self.raw_response_for_error(exc)


class BoundMessageReceiver:
    """Base bound receiver class."""

    def __init__(
        self,
        obj: Any,
        receiver: MessageReceiver,
    ) -> None:
        assert obj is not None
        self._obj = obj
        self._receiver = receiver

    @property
    def protocol(self) -> MessageProtocol:
        """Protocol associated with this receiver."""
        return self._receiver.protocol

    def raw_response_for_error(self, exc: Exception) -> str:
        """Return a raw response for an error that occurred during handling.

        This is automatically called from standard handle_raw_message_x()
        calls but can be manually invoked if errors occur outside of there.
        This gives clients a better idea of what went wrong vs simply
        returning invalid data, which they might dismiss as a connection
        related error.
        """
        return self._receiver.raw_response_for_error(exc)


def create_sender_module(basename: str,
                         protocol_create_code: str,
                         enable_sync_sends: bool,
                         enable_async_sends: bool,
                         private: bool = False) -> str:
    """Create a Python module defining a MessageSender subclass.

    This class is primarily for type checking and will contain overrides
    for the varieties of send calls for message/response types defined
    in the protocol.

    Code passed for 'protocol_create_code' should import necessary
    modules and assign an instance of the Protocol to a 'protocol'
    variable.

    Class names are based on basename; a basename 'FooSender' will
    result in classes FooSender and BoundFooSender.

    If 'private' is True, class-names will be prefixed with an '_'.

    Note that line lengths are not clipped, so output may need to be
    run through a formatter to prevent lint warnings about excessive
    line lengths.
    """

    # Exec the passed code to get a protocol which we then use to
    # generate module code. The user could simply call
    # MessageProtocol.do_create_sender_module() directly, but this allows
    # us to verify that the create code works and yields the protocol used
    # to generate the code.
    protocol = _protocol_from_code(protocol_create_code)
    return protocol.do_create_sender_module(
        basename=basename,
        protocol_create_code=protocol_create_code,
        enable_sync_sends=enable_sync_sends,
        enable_async_sends=enable_async_sends,
        private=private)


def create_receiver_module(basename: str,
                           protocol_create_code: str,
                           is_async: bool,
                           private: bool = False) -> str:
    """Create a Python module defining a MessageReceiver subclass.

    This class is primarily for type checking and will contain overrides
    for the register method for message/response types defined in
    the protocol.

    Class names are based on basename; a basename 'FooReceiver' will
    result in FooReceiver and BoundFooReceiver.

    If 'is_async' is True, handle_raw_message() will be an async method
    and the @handler decorator will expect async methods.

    If 'private' is True, class-names will be prefixed with an '_'.

    Note that line lengths are not clipped, so output may need to be
    run through a formatter to prevent lint warnings about excessive
    line lengths.
    """
    # Exec the passed code to get a protocol which we then use to
    # generate module code. The user could simply call
    # MessageProtocol.do_create_receiver_module() directly, but this allows
    # us to verify that the create code works and yields the protocol used
    # to generate the code.
    protocol = _protocol_from_code(protocol_create_code)
    return protocol.do_create_receiver_module(
        basename=basename,
        protocol_create_code=protocol_create_code,
        is_async=is_async,
        private=private)


def _protocol_from_code(protocol_create_code: str) -> MessageProtocol:
    env: dict = {}
    exec(protocol_create_code, env)  # pylint: disable=exec-used
    protocol = env.get('protocol')
    if not isinstance(protocol, MessageProtocol):
        raise RuntimeError(
            f'protocol_create_code yielded'
            f' a {type(protocol)}; expected a MessageProtocol instance.')
    return protocol
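# --- Illustrative round-trip sketch (not part of the original file) ---
# How the pieces above fit together, using a hypothetical TestMessage
# type whose get_response_types() returns [BoolResponse]. Generated
# sender/receiver modules normally wrap this with full static typing.
#
# protocol = MessageProtocol(message_types={0: TestMessage},
#                            response_types={0: BoolResponse})
#
# class Server:
#     receiver = MessageReceiver(protocol)
#
#     def handle_test(self, msg: TestMessage) -> BoolResponse:
#         return BoolResponse(value=True)
#
#     receiver.register_handler(handle_test)
#
# server = Server()
# raw_reply = Server.receiver.handle_raw_message(
#     server, protocol.encode_message(TestMessage()))
# assert isinstance(protocol.decode_response(raw_reply), BoolResponse)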
45
dist/ba_data/python/efro/message/__init__.py
vendored
Normal file

@@ -0,0 +1,45 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for sending and responding to messages.

Supports static typing for message types and possible return types.
"""

from __future__ import annotations

from efro.util import set_canonical_module
from efro.message._protocol import MessageProtocol
from efro.message._sender import MessageSender, BoundMessageSender
from efro.message._receiver import MessageReceiver, BoundMessageReceiver
from efro.message._module import create_sender_module, create_receiver_module
from efro.message._message import (
    Message,
    Response,
    SysResponse,
    EmptySysResponse,
    ErrorSysResponse,
    StringResponse,
    BoolResponse,
    UnregisteredMessageIDError,
)

__all__ = [
    'Message',
    'Response',
    'SysResponse',
    'EmptySysResponse',
    'ErrorSysResponse',
    'StringResponse',
    'BoolResponse',
    'MessageProtocol',
    'MessageSender',
    'BoundMessageSender',
    'MessageReceiver',
    'BoundMessageReceiver',
    'create_sender_module',
    'create_receiver_module',
    'UnregisteredMessageIDError',
]

# Have these things present themselves cleanly as 'thismodule.SomeClass'
# instead of 'thismodule._internalmodule.SomeClass'.
set_canonical_module(module_globals=globals(), names=__all__)
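# --- Illustrative note (not part of the original file) ---
# With the package layout above, consumers import everything from the
# package root rather than from the private submodules, e.g.:
#
# from efro.message import MessageProtocol, Message, BoolResponse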
BIN
dist/ba_data/python/efro/message/__pycache__/__init__.cpython-310.opt-1.pyc
vendored
Normal file
Binary file not shown.
BIN
dist/ba_data/python/efro/message/__pycache__/_message.cpython-310.opt-1.pyc
vendored
Normal file
Binary file not shown.
BIN
dist/ba_data/python/efro/message/__pycache__/_module.cpython-310.opt-1.pyc
vendored
Normal file
Binary file not shown.
BIN
dist/ba_data/python/efro/message/__pycache__/_protocol.cpython-310.opt-1.pyc
vendored
Normal file
Binary file not shown.
BIN
dist/ba_data/python/efro/message/__pycache__/_receiver.cpython-310.opt-1.pyc
vendored
Normal file
Binary file not shown.
BIN
dist/ba_data/python/efro/message/__pycache__/_sender.cpython-310.opt-1.pyc
vendored
Normal file
Binary file not shown.
108
dist/ba_data/python/efro/message/_message.py
vendored
Normal file

@@ -0,0 +1,108 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for sending and responding to messages.

Supports static typing for message types and possible return types.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Annotated
from dataclasses import dataclass
from enum import Enum

from efro.dataclassio import ioprepped, IOAttrs

if TYPE_CHECKING:
    pass


class UnregisteredMessageIDError(Exception):
    """A message or response id is not covered by our protocol."""


class Message:
    """Base class for messages."""

    @classmethod
    def get_response_types(cls) -> list[type[Response] | None]:
        """Return all Response types this Message can return when sent.

        The default implementation specifies a None return type.
        """
        return [None]


class Response:
    """Base class for responses to messages."""


class SysResponse:
    """Base class for system-responses to messages.

    These are only sent/handled by the messaging system itself;
    users of the api never see them.
    """

    def set_local_exception(self, exc: Exception) -> None:
        """Attach a local exception to facilitate better logging/handling.

        Be aware that this data does not get serialized and only
        exists on the local object.
        """
        setattr(self, '_sr_local_exception', exc)

    def get_local_exception(self) -> Exception | None:
        """Fetch a locally attached exception."""
        value = getattr(self, '_sr_local_exception', None)
        assert isinstance(value, Exception | None)
        return value


# Some standard response types:


@ioprepped
@dataclass
class ErrorSysResponse(SysResponse):
    """SysResponse saying some error has occurred for the send.

    This generally results in an Exception being raised for the caller.
    """

    class ErrorType(Enum):
        """Type of error that occurred while sending a message."""

        REMOTE = 0
        REMOTE_CLEAN = 1
        LOCAL = 2
        COMMUNICATION = 3
        REMOTE_COMMUNICATION = 4

    error_message: Annotated[str, IOAttrs('m')]
    error_type: Annotated[ErrorType, IOAttrs('e')] = ErrorType.REMOTE


@ioprepped
@dataclass
class EmptySysResponse(SysResponse):
    """The response equivalent of None."""


# TODO: could allow handlers to deal in raw values for these
# types similar to how we allow None in place of EmptySysResponse.
# Though not sure if they are widely used enough to warrant the
# extra code complexity.
@ioprepped
@dataclass
class BoolResponse(Response):
    """A simple bool value response."""

    value: Annotated[bool, IOAttrs('v')]


@ioprepped
@dataclass
class StringResponse(Response):
    """A simple string value response."""

    value: Annotated[str, IOAttrs('v')]
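# --- Illustrative sketch (not part of the original file) ---
# Under this newer API, a plain None return type stands in for
# EmptySysResponse; a hypothetical message allowing either no reply or
# a bool reply would declare:
#
# @ioprepped
# @dataclass
# class ToggleMessage(Message):
#     """Example message type (hypothetical)."""
#
#     @classmethod
#     def get_response_types(cls) -> list[type[Response] | None]:
#         return [BoolResponse, None]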
109
dist/ba_data/python/efro/message/_module.py
vendored
Normal file

@@ -0,0 +1,109 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for sending and responding to messages.

Supports static typing for message types and possible return types.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

from efro.message._protocol import MessageProtocol

if TYPE_CHECKING:
    pass


def create_sender_module(
    basename: str,
    protocol_create_code: str,
    enable_sync_sends: bool,
    enable_async_sends: bool,
    private: bool = False,
    protocol_module_level_import_code: str | None = None,
    build_time_protocol_create_code: str | None = None,
) -> str:
    """Create a Python module defining a MessageSender subclass.

    This class is primarily for type checking and will contain overrides
    for the varieties of send calls for message/response types defined
    in the protocol.

    Code passed for 'protocol_create_code' should import necessary
    modules and assign an instance of the Protocol to a 'protocol'
    variable.

    Class names are based on basename; a basename 'FooSender' will
    result in classes FooSender and BoundFooSender.

    If 'private' is True, class-names will be prefixed with an '_'.

    Note: output code may have long lines and should generally be run
    through a formatter. We should perhaps move this functionality to
    efrotools so we can include that functionality inline.
    """
    protocol = _protocol_from_code(
        build_time_protocol_create_code
        if build_time_protocol_create_code is not None
        else protocol_create_code
    )
    return protocol.do_create_sender_module(
        basename=basename,
        protocol_create_code=protocol_create_code,
        enable_sync_sends=enable_sync_sends,
        enable_async_sends=enable_async_sends,
        private=private,
        protocol_module_level_import_code=protocol_module_level_import_code,
    )


def create_receiver_module(
    basename: str,
    protocol_create_code: str,
    is_async: bool,
    private: bool = False,
    protocol_module_level_import_code: str | None = None,
    build_time_protocol_create_code: str | None = None,
) -> str:
    """Create a Python module defining a MessageReceiver subclass.

    This class is primarily for type checking and will contain overrides
    for the register method for message/response types defined in
    the protocol.

    Class names are based on basename; a basename 'FooReceiver' will
    result in FooReceiver and BoundFooReceiver.

    If 'is_async' is True, handle_raw_message() will be an async method
    and the @handler decorator will expect async methods.

    If 'private' is True, class-names will be prefixed with an '_'.

    Note that line lengths are not clipped, so output may need to be
    run through a formatter to prevent lint warnings about excessive
    line lengths.
    """
    protocol = _protocol_from_code(
        build_time_protocol_create_code
        if build_time_protocol_create_code is not None
        else protocol_create_code
    )
    return protocol.do_create_receiver_module(
        basename=basename,
        protocol_create_code=protocol_create_code,
        is_async=is_async,
        private=private,
        protocol_module_level_import_code=protocol_module_level_import_code,
    )


def _protocol_from_code(protocol_create_code: str) -> MessageProtocol:
    env: dict = {}
    exec(protocol_create_code, env)  # pylint: disable=exec-used
    protocol = env.get('protocol')
    if not isinstance(protocol, MessageProtocol):
        raise RuntimeError(
            f'protocol_create_code yielded'
            f' a {type(protocol)}; expected a MessageProtocol instance.'
        )
    return protocol
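# --- Illustrative usage sketch (not part of the original file) ---
# Generating a typed sender module at build time; the create-code
# string and module names here are hypothetical.
#
# module_src = create_sender_module(
#     'FooSender',
#     protocol_create_code='from myapp.protocol import protocol',
#     enable_sync_sends=True,
#     enable_async_sends=False,
# )
# # 'module_src' is Python source for FooSender/BoundFooSender classes,
# # typically written to a file and run through a formatter.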
653
dist/ba_data/python/efro/message/_protocol.py
vendored
Normal file

@@ -0,0 +1,653 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for sending and responding to messages.

Supports static typing for message types and possible return types.
"""

from __future__ import annotations

from typing import TYPE_CHECKING
import traceback
import json

from efro.error import CleanError, CommunicationError
from efro.dataclassio import (
    is_ioprepped_dataclass,
    dataclass_to_dict,
    dataclass_from_dict,
)
from efro.message._message import (
    Message,
    Response,
    SysResponse,
    ErrorSysResponse,
    EmptySysResponse,
    UnregisteredMessageIDError,
)

if TYPE_CHECKING:
    from typing import Any, Literal


class MessageProtocol:
    """Wrangles a set of message types, formats, and response types.

    Both endpoints must be using a compatible Protocol for communication
    to succeed. To maintain Protocol compatibility between revisions,
    all message types must retain the same id, message attr storage
    names must not change, newly added attrs must have default values,
    etc.
    """

    def __init__(
        self,
        message_types: dict[int, type[Message]],
        response_types: dict[int, type[Response]],
        forward_communication_errors: bool = False,
        forward_clean_errors: bool = False,
        remote_errors_include_stack_traces: bool = False,
        log_remote_errors: bool = True,
    ) -> None:
        """Create a protocol with a given configuration.

        If 'forward_communication_errors' is True,
        efro.error.CommunicationErrors raised on the receiver end will
        result in a matching error raised back on the sender. This can
        be useful if the receiver will be in some way forwarding
        messages along and the sender doesn't need to know where
        communication breakdowns occurred; only that they did.

        If 'forward_clean_errors' is True, efro.error.CleanError
        exceptions raised on the receiver end will result in a matching
        CleanError raised back on the sender.

        When an exception is not covered by the optional forwarding
        mechanisms above, it will come across as efro.error.RemoteError
        and the exception will be logged on the receiver
        end - at least by default (see details below).

        If 'remote_errors_include_stack_traces' is True, stringified
        stack traces will be returned with efro.error.RemoteError
        exceptions. This is useful for debugging but should only be
        enabled in cases where the sender is trusted to see internal
        details of the receiver.

        By default, when a message-handling exception will result in an
        efro.error.RemoteError being returned to the sender, the
        exception will be logged on the receiver. This is because the
        goal is usually to avoid returning opaque RemoteErrors and to
        instead return something meaningful as part of the expected
        response type (even if that value itself represents a logical
        error state). If 'log_remote_errors' is False, however, such
        exceptions will not be logged on the receiver. This can be
        useful in combination with 'remote_errors_include_stack_traces'
        and 'forward_clean_errors' in situations where all error
        logging/management will be happening on the sender end. Be
        aware, however, that in that case it may be possible for
        communication errors to prevent such error messages from
        ever being seen.
        """
        # pylint: disable=too-many-locals
        self.message_types_by_id: dict[int, type[Message]] = {}
        self.message_ids_by_type: dict[type[Message], int] = {}
        self.response_types_by_id: dict[
            int, type[Response] | type[SysResponse]
        ] = {}
        self.response_ids_by_type: dict[
            type[Response] | type[SysResponse], int
        ] = {}
        for m_id, m_type in message_types.items():

            # Make sure only valid message types were passed and each
            # id was assigned only once.
            assert isinstance(m_id, int)
            assert m_id >= 0
            assert is_ioprepped_dataclass(m_type) and issubclass(
                m_type, Message
            )
            assert self.message_types_by_id.get(m_id) is None
            self.message_types_by_id[m_id] = m_type
            self.message_ids_by_type[m_type] = m_id

        for r_id, r_type in response_types.items():
            assert isinstance(r_id, int)
            assert r_id >= 0
            assert is_ioprepped_dataclass(r_type) and issubclass(
                r_type, Response
            )
            assert self.response_types_by_id.get(r_id) is None
            self.response_types_by_id[r_id] = r_type
            self.response_ids_by_type[r_type] = r_id

        # Register our SysResponse types. These use negative
        # IDs so as to never overlap with user Response types.
        def _reg_sys(reg_tp: type[SysResponse], reg_id: int) -> None:
            assert self.response_types_by_id.get(reg_id) is None
            self.response_types_by_id[reg_id] = reg_tp
            self.response_ids_by_type[reg_tp] = reg_id

        _reg_sys(ErrorSysResponse, -1)
        _reg_sys(EmptySysResponse, -2)

        # Some extra-thorough validation in debug mode.
        if __debug__:
            # Make sure all Message types' return types are valid
            # and have been assigned an ID as well.
            all_response_types: set[type[Response] | None] = set()
            for m_id, m_type in message_types.items():
                m_rtypes = m_type.get_response_types()

                assert isinstance(m_rtypes, list)
                assert (
                    m_rtypes
                ), f'Message type {m_type} specifies no return types.'
                assert len(set(m_rtypes)) == len(m_rtypes)  # check dups
                for m_rtype in m_rtypes:
                    all_response_types.add(m_rtype)
            for cls in all_response_types:
                if cls is None:
                    continue
                assert is_ioprepped_dataclass(cls)
                assert issubclass(cls, Response)
                if cls not in self.response_ids_by_type:
                    raise ValueError(
                        f'Possible response type {cls} needs to be included'
                        f' in response_types for this protocol.'
                    )

            # Make sure all registered types have unique base names.
            # We can take advantage of this to generate cleaner looking
            # protocol modules. Can revisit if this is ever a problem.
            mtypenames = set(tp.__name__ for tp in self.message_ids_by_type)
            if len(mtypenames) != len(message_types):
                raise ValueError(
                    'message_types contains duplicate __name__s;'
                    ' all types are required to have unique names.'
                )

        self.forward_clean_errors = forward_clean_errors
        self.forward_communication_errors = forward_communication_errors
        self.remote_errors_include_stack_traces = (
            remote_errors_include_stack_traces
        )
        self.log_remote_errors = log_remote_errors
@staticmethod
|
||||
def encode_dict(obj: dict) -> str:
|
||||
"""Json-encode a provided dict."""
|
||||
return json.dumps(obj, separators=(',', ':'))
|
||||
|
||||
def message_to_dict(self, message: Message) -> dict:
|
||||
"""Encode a message to a json ready dict."""
|
||||
return self._to_dict(message, self.message_ids_by_type, 'message')
|
||||
|
||||
def response_to_dict(self, response: Response | SysResponse) -> dict:
|
||||
"""Encode a response to a json ready dict."""
|
||||
return self._to_dict(response, self.response_ids_by_type, 'response')
|
||||
|
||||
def error_to_response(self, exc: Exception) -> tuple[SysResponse, bool]:
|
||||
"""Translate an Exception to a SysResponse.
|
||||
|
||||
Also returns whether the error should be logged if this happened
|
||||
within handle_raw_message().
|
||||
"""
|
||||
|
||||
# If anything goes wrong, return a ErrorSysResponse instead.
|
||||
# (either CLEAN or generic REMOTE)
|
||||
if self.forward_clean_errors and isinstance(exc, CleanError):
|
||||
return (
|
||||
ErrorSysResponse(
|
||||
error_message=str(exc),
|
||||
error_type=ErrorSysResponse.ErrorType.REMOTE_CLEAN,
|
||||
),
|
||||
False,
|
||||
)
|
||||
if self.forward_communication_errors and isinstance(
|
||||
exc, CommunicationError
|
||||
):
|
||||
return (
|
||||
ErrorSysResponse(
|
||||
error_message=str(exc),
|
||||
error_type=ErrorSysResponse.ErrorType.REMOTE_COMMUNICATION,
|
||||
),
|
||||
False,
|
||||
)
|
||||
return (
|
||||
ErrorSysResponse(
|
||||
error_message=(
|
||||
traceback.format_exc()
|
||||
if self.remote_errors_include_stack_traces
|
||||
else 'An internal error has occurred.'
|
||||
),
|
||||
error_type=ErrorSysResponse.ErrorType.REMOTE,
|
||||
),
|
||||
self.log_remote_errors,
|
||||
)
|
||||
|
||||
def _to_dict(
|
||||
self, message: Any, ids_by_type: dict[type, int], opname: str
|
||||
) -> dict:
|
||||
"""Encode a message to a json string for transport."""
|
||||
|
||||
m_id: int | None = ids_by_type.get(type(message))
|
||||
if m_id is None:
|
||||
raise TypeError(
|
||||
f'{opname} type is not registered in protocol:'
|
||||
f' {type(message)}'
|
||||
)
|
||||
out = {'t': m_id, 'm': dataclass_to_dict(message)}
|
||||
return out
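
    # Wire-format note (illustrative; the values below are hypothetical):
    # a message encodes as {'t': <registered type id>, 'm': <dataclassio
    # body>}, e.g. {'t': 0, 'm': {'value': 123}}. The decode path in
    # _from_dict() below tolerates a missing 'm' for empty bodies.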

    @staticmethod
    def decode_dict(data: str) -> dict:
        """Decode data to a dict."""
        out = json.loads(data)
        assert isinstance(out, dict)
        return out

    def message_from_dict(self, data: dict) -> Message:
        """Decode a message from a json-ready dict."""
        out = self._from_dict(data, self.message_types_by_id, 'message')
        assert isinstance(out, Message)
        return out

    def response_from_dict(self, data: dict) -> Response | SysResponse:
        """Decode a response from a json-ready dict."""
        out = self._from_dict(data, self.response_types_by_id, 'response')
        assert isinstance(out, Response | SysResponse)
        return out

    # Weeeird; we get mypy errors returning dict[int, type] but
    # dict[int, typing.Type] or dict[int, type[Any]] works..
    def _from_dict(
        self, data: dict, types_by_id: dict[int, type[Any]], opname: str
    ) -> Any:
        """Decode a message from a json-ready dict."""
        msgdict: dict | None

        m_id = data.get('t')
        # Allow omitting the 'm' dict if it's empty.
        msgdict = data.get('m', {})

        assert isinstance(m_id, int)
        assert isinstance(msgdict, dict)

        # Decode this particular type.
        msgtype = types_by_id.get(m_id)
        if msgtype is None:
            raise UnregisteredMessageIDError(
                f'Got unregistered {opname} id of {m_id}.'
            )
        return dataclass_from_dict(msgtype, msgdict)

    def _get_module_header(
        self,
        part: Literal['sender', 'receiver'],
        extra_import_code: str | None,
        enable_async_sends: bool,
    ) -> str:
        """Return common parts of generated modules."""
        # pylint: disable=too-many-locals
        # pylint: disable=too-many-branches
        # pylint: disable=too-many-statements
        import textwrap

        tpimports: dict[str, list[str]] = {}
        imports: dict[str, list[str]] = {}

        single_message_type = len(self.message_ids_by_type) == 1

        msgtypes = list(self.message_ids_by_type)
        if part == 'sender':
            msgtypes.append(Message)
        for msgtype in msgtypes:
            tpimports.setdefault(msgtype.__module__, []).append(
                msgtype.__name__
            )
        rsptypes = list(self.response_ids_by_type)
        if part == 'sender':
            rsptypes.append(Response)
        for rsp_tp in rsptypes:
            # Skip these as they don't actually show up in code.
            if rsp_tp is EmptySysResponse or rsp_tp is ErrorSysResponse:
                continue
            if (
                single_message_type
                and part == 'sender'
                and rsp_tp is not Response
            ):
                # We need to cast to the single supported response type
                # in this case, so we need response types at runtime.
                imports.setdefault(rsp_tp.__module__, []).append(
                    rsp_tp.__name__
                )
            else:
                tpimports.setdefault(rsp_tp.__module__, []).append(
                    rsp_tp.__name__
                )

        import_lines = ''
        tpimport_lines = ''

        for module, names in sorted(imports.items()):
            jnames = ', '.join(names)
            line = f'from {module} import {jnames}'
            if len(line) > 79:
                # Recreate in a wrapping-friendly form.
                line = f'from {module} import ({jnames})'
            import_lines += f'{line}\n'
        for module, names in sorted(tpimports.items()):
            jnames = ', '.join(names)
            line = f'from {module} import {jnames}'
            if len(line) > 75:  # Account for indent.
                # Recreate in a wrapping-friendly form.
                line = f'from {module} import ({jnames})'
            tpimport_lines += f'{line}\n'

        if part == 'sender':
            import_lines += (
                'from efro.message import MessageSender, BoundMessageSender'
            )
            tpimport_typing_extras = ''
        else:
            if single_message_type:
                import_lines += (
                    'from efro.message import (MessageReceiver,'
                    ' BoundMessageReceiver, Message, Response)'
                )
            else:
                import_lines += (
                    'from efro.message import MessageReceiver,'
                    ' BoundMessageReceiver'
                )
            tpimport_typing_extras = ', Awaitable'

        if extra_import_code is not None:
            import_lines += f'\n{extra_import_code}\n'

        ovld = ', overload' if not single_message_type else ''
        ovld2 = (
            ', cast, Awaitable'
            if (single_message_type and part == 'sender' and enable_async_sends)
            else ''
        )
        tpimport_lines = textwrap.indent(tpimport_lines, '    ')

        baseimps = ['Any']
        if part == 'receiver':
            baseimps.append('Callable')
        if part == 'sender' and enable_async_sends:
            baseimps.append('Awaitable')
        baseimps_s = ', '.join(baseimps)
        out = (
            '# Released under the MIT License. See LICENSE for details.\n'
            f'#\n'
            f'"""Auto-generated {part} module. Do not edit by hand."""\n'
            f'\n'
            f'from __future__ import annotations\n'
            f'\n'
            f'from typing import TYPE_CHECKING{ovld}{ovld2}\n'
            f'\n'
            f'{import_lines}\n'
            f'\n'
            f'if TYPE_CHECKING:\n'
            f'    from typing import {baseimps_s}'
            f'{tpimport_typing_extras}\n'
            f'{tpimport_lines}'
            f'\n'
            f'\n'
        )
        return out

    def do_create_sender_module(
        self,
        basename: str,
        protocol_create_code: str,
        enable_sync_sends: bool,
        enable_async_sends: bool,
        private: bool = False,
        protocol_module_level_import_code: str | None = None,
    ) -> str:
        """Used by create_sender_module(); do not call directly."""
        # pylint: disable=too-many-locals
        # pylint: disable=too-many-branches
        import textwrap

        msgtypes = list(self.message_ids_by_type.keys())

        ppre = '_' if private else ''
        out = self._get_module_header(
            'sender',
            extra_import_code=protocol_module_level_import_code,
            enable_async_sends=enable_async_sends,
        )
        ccind = textwrap.indent(protocol_create_code, '        ')
        out += (
            f'class {ppre}{basename}(MessageSender):\n'
            f'    """Protocol-specific sender."""\n'
            f'\n'
            f'    def __init__(self) -> None:\n'
            f'{ccind}\n'
            f'        super().__init__(protocol)\n'
            f'\n'
            f'    def __get__(\n'
            f'        self, obj: Any, type_in: Any = None\n'
            f'    ) -> {ppre}Bound{basename}:\n'
            f'        return {ppre}Bound{basename}(obj, self)\n'
            f'\n'
            f'\n'
            f'class {ppre}Bound{basename}(BoundMessageSender):\n'
            f'    """Protocol-specific bound sender."""\n'
        )

        def _filt_tp_name(rtype: type[Response] | None) -> str:
            return 'None' if rtype is None else rtype.__name__

        # Define send() overloads for all registered message types.
        if msgtypes:
            for async_pass in False, True:
                if async_pass and not enable_async_sends:
                    continue
                if not async_pass and not enable_sync_sends:
                    continue
                pfx = 'async ' if async_pass else ''
                sfx = '_async' if async_pass else ''
                # awt = 'await ' if async_pass else ''
                awt = ''
                how = 'asynchronously' if async_pass else 'synchronously'

                if len(msgtypes) == 1:
                    # Special case: with a single message type we don't
                    # use overloads.
                    msgtype = msgtypes[0]
                    msgtypevar = msgtype.__name__
                    rtypes = msgtype.get_response_types()
                    if len(rtypes) > 1:
                        rtypevar = ' | '.join(_filt_tp_name(t) for t in rtypes)
                    else:
                        rtypevar = _filt_tp_name(rtypes[0])
                    if async_pass:
                        rtypevar = f'Awaitable[{rtypevar}]'
                    out += (
                        f'\n'
                        f'    def send{sfx}(self,'
                        f' message: {msgtypevar})'
                        f' -> {rtypevar}:\n'
                        f'        """Send a message {how}."""\n'
                        f'        out = {awt}self._sender.'
                        f'send{sfx}(self._obj, message)\n'
                    )
                    if not async_pass:
                        out += (
                            f'        assert isinstance(out, {rtypevar})\n'
                            '        return out\n'
                        )
                    else:
                        out += f'        return cast({rtypevar}, out)\n'

                else:

                    for msgtype in msgtypes:
                        msgtypevar = msgtype.__name__
                        rtypes = msgtype.get_response_types()
                        if len(rtypes) > 1:
                            rtypevar = ' | '.join(
                                _filt_tp_name(t) for t in rtypes
                            )
                        else:
                            rtypevar = _filt_tp_name(rtypes[0])
                        out += (
                            f'\n'
                            f'    @overload\n'
                            f'    {pfx}def send{sfx}(self,'
                            f' message: {msgtypevar})'
                            f' -> {rtypevar}:\n'
                            f'        ...\n'
                        )
                    rtypevar = 'Response | None'
                    if async_pass:
                        rtypevar = f'Awaitable[{rtypevar}]'
                    out += (
                        f'\n'
                        f'    def send{sfx}(self, message: Message)'
                        f' -> {rtypevar}:\n'
                        f'        """Send a message {how}."""\n'
                        f'        return {awt}self._sender.'
                        f'send{sfx}(self._obj, message)\n'
                    )

        return out

    def do_create_receiver_module(
        self,
        basename: str,
        protocol_create_code: str,
        is_async: bool,
        private: bool = False,
        protocol_module_level_import_code: str | None = None,
    ) -> str:
        """Used by create_receiver_module(); do not call directly."""
        # pylint: disable=too-many-locals
        import textwrap

        desc = 'asynchronous' if is_async else 'synchronous'
        ppre = '_' if private else ''
        msgtypes = list(self.message_ids_by_type.keys())
        out = self._get_module_header(
            'receiver',
            extra_import_code=protocol_module_level_import_code,
            enable_async_sends=False,
        )
        ccind = textwrap.indent(protocol_create_code, '        ')
        out += (
            f'class {ppre}{basename}(MessageReceiver):\n'
            f'    """Protocol-specific {desc} receiver."""\n'
            f'\n'
            f'    is_async = {is_async}\n'
            f'\n'
            f'    def __init__(self) -> None:\n'
            f'{ccind}\n'
            f'        super().__init__(protocol)\n'
            f'\n'
            f'    def __get__(\n'
            f'        self,\n'
            f'        obj: Any,\n'
            f'        type_in: Any = None,\n'
            f'    ) -> {ppre}Bound{basename}:\n'
            f'        return {ppre}Bound{basename}('
            f'obj, self)\n'
        )

        # Define handler() overloads for all registered message types.

        def _filt_tp_name(rtype: type[Response] | None) -> str:
            return 'None' if rtype is None else rtype.__name__

        if msgtypes:
            cbgn = 'Awaitable[' if is_async else ''
            cend = ']' if is_async else ''
            if len(msgtypes) == 1:
                # Special case: when we have a single message type we don't
                # use overloads.
                msgtype = msgtypes[0]
                msgtypevar = msgtype.__name__
                rtypes = msgtype.get_response_types()
                if len(rtypes) > 1:
                    rtypevar = ' | '.join(_filt_tp_name(t) for t in rtypes)
                else:
                    rtypevar = _filt_tp_name(rtypes[0])
                rtypevar = f'{cbgn}{rtypevar}{cend}'
                out += (
                    f'\n'
                    f'    def handler(\n'
                    f'        self,\n'
                    f'        call: Callable[[Any, {msgtypevar}], '
                    f'{rtypevar}],\n'
                    f'    )'
                    f' -> Callable[[Any, {msgtypevar}], {rtypevar}]:\n'
                    f'        """Decorator to register message handlers."""\n'
                    f'        from typing import cast, Callable, Any\n'
                    f'\n'
                    f'        self.register_handler(cast(Callable'
                    f'[[Any, Message], Response], call))\n'
                    f'        return call\n'
                )
            else:
                for msgtype in msgtypes:
                    msgtypevar = msgtype.__name__
                    rtypes = msgtype.get_response_types()
                    if len(rtypes) > 1:
                        rtypevar = ' | '.join(_filt_tp_name(t) for t in rtypes)
                    else:
                        rtypevar = _filt_tp_name(rtypes[0])
                    rtypevar = f'{cbgn}{rtypevar}{cend}'
                    out += (
                        f'\n'
                        f'    @overload\n'
                        f'    def handler(\n'
                        f'        self,\n'
                        f'        call: Callable[[Any, {msgtypevar}], '
                        f'{rtypevar}],\n'
                        f'    )'
                        f' -> Callable[[Any, {msgtypevar}], {rtypevar}]:\n'
                        f'        ...\n'
                    )
                out += (
                    '\n'
                    '    def handler(self, call: Callable) -> Callable:\n'
                    '        """Decorator to register message handlers."""\n'
                    '        self.register_handler(call)\n'
                    '        return call\n'
                )

        out += (
            f'\n'
            f'\n'
            f'class {ppre}Bound{basename}(BoundMessageReceiver):\n'
            f'    """Protocol-specific bound receiver."""\n'
        )
        if is_async:
            out += (
                '\n'
                '    def handle_raw_message(\n'
                '        self, message: str, raise_unregistered: bool = False\n'
                '    ) -> Awaitable[str]:\n'
                '        """Asynchronously handle a raw incoming message."""\n'
                '        return self._receiver.'
                'handle_raw_message_async(\n'
                '            self._obj, message, raise_unregistered\n'
                '        )\n'
            )

        else:
            out += (
                '\n'
                '    def handle_raw_message(\n'
                '        self, message: str, raise_unregistered: bool = False\n'
                '    ) -> str:\n'
                '        """Synchronously handle a raw incoming message."""\n'
                '        return self._receiver.handle_raw_message(\n'
                '            self._obj, message, raise_unregistered\n'
                '        )\n'
            )

        return out
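

# ---------------------------------------------------------------------
# Illustrative sketch only (not part of the original module): shows the
# intended round trip through the dict wire format using two
# hypothetical ioprepped message/response types.
# ---------------------------------------------------------------------
def _example_protocol_round_trip() -> None:
    """Sketch: encode a message and decode it back (hypothetical types)."""
    from dataclasses import dataclass

    from efro.dataclassio import ioprepped

    @ioprepped
    @dataclass
    class _PongResponse(Response):
        value: int = 0

    @ioprepped
    @dataclass
    class _PingMessage(Message):
        @classmethod
        def get_response_types(cls) -> list[type[Response] | None]:
            return [_PongResponse]

    protocol = MessageProtocol(
        message_types={0: _PingMessage},
        response_types={0: _PongResponse},
    )
    wire = protocol.encode_dict(protocol.message_to_dict(_PingMessage()))
    decoded = protocol.message_from_dict(protocol.decode_dict(wire))
    assert isinstance(decoded, _PingMessage)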

420
dist/ba_data/python/efro/message/_receiver.py
vendored
Normal file
@@ -0,0 +1,420 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for sending and responding to messages.

Supports static typing for message types and possible return types.
"""

from __future__ import annotations

import types
import inspect
import logging
from typing import TYPE_CHECKING

from efro.message._message import (
    Message,
    Response,
    EmptySysResponse,
    UnregisteredMessageIDError,
)

if TYPE_CHECKING:
    from typing import Any, Callable, Awaitable

    from efro.message._protocol import MessageProtocol
    from efro.message._message import SysResponse


class MessageReceiver:
    """Facilitates receiving & responding to messages from a remote source.

    This is instantiated at the class level with unbound methods registered
    as handlers for different message types in the protocol.

    Example:

    class MyClass:
        receiver = MyMessageReceiver()

        # MyMessageReceiver fills out handler() overloads to ensure all
        # registered handlers have valid types/return-types.
        @receiver.handler
        def handle_some_message_type(self, message: SomeMsg) -> SomeResponse:
            # Deal with this message type here.

    # This will trigger the registered handler being called.
    obj = MyClass()
    obj.receiver.handle_raw_message(some_raw_data)

    Any unhandled Exception occurring during message handling will result in
    an Exception being raised on the sending end.
    """

    is_async = False

    def __init__(self, protocol: MessageProtocol) -> None:
        self.protocol = protocol
        self._handlers: dict[type[Message], Callable] = {}
        self._decode_filter_call: Callable[
            [Any, dict, Message], None
        ] | None = None
        self._encode_filter_call: Callable[
            [Any, Message | None, Response | SysResponse, dict], None
        ] | None = None

    # noinspection PyProtectedMember
    def register_handler(
        self, call: Callable[[Any, Message], Response | None]
    ) -> None:
        """Register a handler call.

        The message type handled by the call is determined by its
        type annotation.
        """
        # TODO: can use types.GenericAlias in 3.9.
        # (hmm, though now that we're there, it seems a drop-in
        # replacement gives us errors. Should re-test in 3.10 as it seems
        # that typing_extensions handles it differently in that case.)
        from typing import _GenericAlias  # type: ignore
        from typing import get_type_hints, get_args

        sig = inspect.getfullargspec(call)

        # The provided callable should be a method taking one 'msg' arg.
        expectedsig = ['self', 'msg']
        if sig.args != expectedsig:
            raise ValueError(
                f'Expected callable signature of {expectedsig};'
                f' got {sig.args}'
            )

        # Make sure we are only given async methods if we are an async
        # handler and sync ones otherwise.
        # UPDATE - can't do this anymore since we now sometimes use
        # regular functions which return awaitables instead of having
        # the entire function be async.
        # is_async = inspect.iscoroutinefunction(call)
        # if self.is_async != is_async:
        #     msg = (
        #         'Expected a sync method; found an async one.'
        #         if is_async
        #         else 'Expected an async method; found a sync one.'
        #     )
        #     raise ValueError(msg)

        # Check annotation types to determine what message types we handle.
        # Return-type annotation can be a Union, but we probably don't
        # have it available at runtime. Explicitly pull it in.
        # UPDATE: we've updated our pylint filter to where we should
        # have all annotations available.
        # anns = get_type_hints(call, localns={'Union': Union})
        anns = get_type_hints(call)

        msgtype = anns.get('msg')
        if not isinstance(msgtype, type):
            raise TypeError(
                f'expected a type for "msg" annotation; got {type(msgtype)}.'
            )
        assert issubclass(msgtype, Message)

        ret = anns.get('return')
        responsetypes: tuple[type[Any] | None, ...]

        # Return types can be a single type or a union of types.
        if isinstance(ret, (_GenericAlias, types.UnionType)):
            targs = get_args(ret)
            if not all(isinstance(a, (type, type(None))) for a in targs):
                raise TypeError(
                    f'expected only types for "return" annotation;'
                    f' got {targs}.'
                )
            responsetypes = targs
        else:
            if not isinstance(ret, (type, type(None))):
                raise TypeError(
                    f'expected one or more types for'
                    f' "return" annotation; got a {type(ret)}.'
                )
            # This seems like maybe a mypy bug. Appeared after adding
            # types.UnionType above.
            responsetypes = (ret,)

        # This will contain NoneType for empty return cases, but
        # we expect it to be None.
        responsetypes = tuple(
            None if r is type(None) else r for r in responsetypes
        )

        # Make sure our protocol has this message type registered and our
        # return types exactly match. (Technically we could return a subset
        # of the supported types; can allow this in the future if it makes
        # sense.)
        registered_types = self.protocol.message_ids_by_type.keys()

        if msgtype not in registered_types:
            raise TypeError(
                f'Message type {msgtype} is not registered'
                f' in this Protocol.'
            )

        if msgtype in self._handlers:
            raise TypeError(
                f'Message type {msgtype} already has a registered handler.'
            )

        # Make sure the responses exactly match what the message expects.
        if set(responsetypes) != set(msgtype.get_response_types()):
            raise TypeError(
                f'Provided response types {responsetypes} do not'
                f' match the set expected by message type {msgtype}: '
                f'({msgtype.get_response_types()})'
            )

        # Ok; we're good!
        self._handlers[msgtype] = call
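
    # For reference, a conforming handler (with hypothetical 'SomeMsg'
    # and 'SomeResponse' types) looks like:
    #
    #     def handle_some_msg(self, msg: SomeMsg) -> SomeResponse | None:
    #         ...
    #
    # The 'msg' annotation selects the message type; the return
    # annotation must exactly match that type's get_response_types().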

    def decode_filter_method(
        self, call: Callable[[Any, dict, Message], None]
    ) -> Callable[[Any, dict, Message], None]:
        """Function decorator for defining a decode filter.

        Decode filters can be used to extract extra data from incoming
        message dicts. This version will work for both handle_raw_message()
        and handle_raw_message_async().
        """
        assert self._decode_filter_call is None
        self._decode_filter_call = call
        return call

    def encode_filter_method(
        self,
        call: Callable[
            [Any, Message | None, Response | SysResponse, dict], None
        ],
    ) -> Callable[[Any, Message | None, Response, dict], None]:
        """Function decorator for defining an encode filter.

        Encode filters can be used to add extra data to the message
        dict before it is encoded to a string and sent out.
        """
        assert self._encode_filter_call is None
        self._encode_filter_call = call
        return call
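
    # Filter sketch (illustration only; 'receiver' and the timestamp
    # field are hypothetical): an encode filter might piggyback extra
    # metadata onto every outgoing response dict, e.g.:
    #
    #     @receiver.encode_filter_method
    #     def _stamp_response(self, msg, response, out_dict) -> None:
    #         out_dict['ts'] = time.time()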

    def validate(self, log_only: bool = False) -> None:
        """Check for handler completeness, valid types, etc."""
        for msgtype in self.protocol.message_ids_by_type.keys():
            if issubclass(msgtype, Response):
                continue
            if msgtype not in self._handlers:
                msg = (
                    f'Protocol message type {msgtype} is not handled'
                    f' by receiver type {type(self)}.'
                )
                if log_only:
                    logging.error(msg)
                else:
                    raise TypeError(msg)

    def _decode_incoming_message_base(
        self, bound_obj: Any, msg: str
    ) -> tuple[Any, dict, Message]:
        # Decode the incoming message.
        msg_dict = self.protocol.decode_dict(msg)
        msg_decoded = self.protocol.message_from_dict(msg_dict)
        assert isinstance(msg_decoded, Message)
        if self._decode_filter_call is not None:
            self._decode_filter_call(bound_obj, msg_dict, msg_decoded)
        return bound_obj, msg_dict, msg_decoded

    def _decode_incoming_message(self, bound_obj: Any, msg: str) -> Message:
        bound_obj, _msg_dict, msg_decoded = self._decode_incoming_message_base(
            bound_obj=bound_obj, msg=msg
        )
        return msg_decoded

    def encode_user_response(
        self, bound_obj: Any, message: Message, response: Response | None
    ) -> str:
        """Encode a response provided by the user for sending."""

        assert isinstance(response, Response | None)
        # (user should never explicitly return error-responses)
        assert (
            response is None or type(response) in message.get_response_types()
        )

        # A return value of None equals EmptySysResponse.
        out_response: Response | SysResponse
        if response is None:
            out_response = EmptySysResponse()
        else:
            out_response = response

        response_dict = self.protocol.response_to_dict(out_response)
        if self._encode_filter_call is not None:
            self._encode_filter_call(
                bound_obj, message, out_response, response_dict
            )
        return self.protocol.encode_dict(response_dict)

    def encode_error_response(
        self, bound_obj: Any, message: Message | None, exc: Exception
    ) -> tuple[str, bool]:
        """Given an error, return a sysresponse str and whether to log."""
        response, dolog = self.protocol.error_to_response(exc)
        response_dict = self.protocol.response_to_dict(response)
        if self._encode_filter_call is not None:
            self._encode_filter_call(
                bound_obj, message, response, response_dict
            )
        return self.protocol.encode_dict(response_dict), dolog

    def handle_raw_message(
        self, bound_obj: Any, msg: str, raise_unregistered: bool = False
    ) -> str:
        """Decode, handle, and return a response for a message.

        If 'raise_unregistered' is True, will raise an
        efro.message.UnregisteredMessageIDError for messages not handled by
        the protocol. In all other cases local errors will translate to
        error responses returned to the sender.
        """
        assert not self.is_async, "can't call sync handler on async receiver"
        msg_decoded: Message | None = None
        msgtype: type[Message] | None = None
        try:
            msg_decoded = self._decode_incoming_message(bound_obj, msg)
            msgtype = type(msg_decoded)
            handler = self._handlers.get(msgtype)
            if handler is None:
                raise RuntimeError(f'Got unhandled message type: {msgtype}.')
            response = handler(bound_obj, msg_decoded)
            assert isinstance(response, Response | None)
            return self.encode_user_response(bound_obj, msg_decoded, response)

        except Exception as exc:
            if raise_unregistered and isinstance(
                exc, UnregisteredMessageIDError
            ):
                raise
            rstr, dolog = self.encode_error_response(
                bound_obj, msg_decoded, exc
            )
            if dolog:
                if msgtype is not None:
                    logging.exception(
                        'Error handling %s.%s message.',
                        msgtype.__module__,
                        msgtype.__qualname__,
                    )
                else:
                    logging.exception('Error in efro.message handling.')
            return rstr

    def handle_raw_message_async(
        self, bound_obj: Any, msg: str, raise_unregistered: bool = False
    ) -> Awaitable[str]:
        """Should be called when the receiver gets a message.

        The return value is the raw response to the message.
        """

        # Note: This call is synchronous so that the first part of it can
        # happen synchronously. If the whole call were async we wouldn't be
        # able to guarantee that message handlers would be called in the
        # order the messages were received.

        assert self.is_async, "can't call async handler on sync receiver"
        msg_decoded: Message | None = None
        msgtype: type[Message] | None = None
        try:
            msg_decoded = self._decode_incoming_message(bound_obj, msg)
            msgtype = type(msg_decoded)
            handler = self._handlers.get(msgtype)
            if handler is None:
                raise RuntimeError(f'Got unhandled message type: {msgtype}.')
            handler_awaitable = handler(bound_obj, msg_decoded)

        except Exception as exc:
            if raise_unregistered and isinstance(
                exc, UnregisteredMessageIDError
            ):
                raise
            return self._handle_raw_message_async_error(
                bound_obj, msg_decoded, msgtype, exc
            )

        # Return an awaitable to handle the rest asynchronously.
        return self._handle_raw_message_async(
            bound_obj, msg_decoded, msgtype, handler_awaitable
        )

    async def _handle_raw_message_async_error(
        self,
        bound_obj: Any,
        msg_decoded: Message | None,
        msgtype: type[Message] | None,
        exc: Exception,
    ) -> str:
        rstr, dolog = self.encode_error_response(bound_obj, msg_decoded, exc)
        if dolog:
            if msgtype is not None:
                logging.exception(
                    'Error handling %s.%s message.',
                    msgtype.__module__,
                    msgtype.__qualname__,
                )
            else:
                logging.exception('Error in efro.message handling.')
        return rstr

    async def _handle_raw_message_async(
        self,
        bound_obj: Any,
        msg_decoded: Message,
        msgtype: type[Message] | None,
        handler_awaitable: Awaitable[Response | None],
    ) -> str:
        """Finish handling once the handler awaitable completes.

        The return value is the raw response to the message.
        """
        try:
            response = await handler_awaitable
            assert isinstance(response, Response | None)
            return self.encode_user_response(bound_obj, msg_decoded, response)

        except Exception as exc:
            return await self._handle_raw_message_async_error(
                bound_obj, msg_decoded, msgtype, exc
            )
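

# A minimal end-to-end sketch (illustration only; 'MyMessageReceiver',
# 'SomeMsg', and 'SomeResponse' are hypothetical, with the first being a
# generated subclass providing the type-checked handler() decorator):
#
#   class Server:
#       receiver = MyMessageReceiver()
#
#       @receiver.handler
#       def handle_some_msg(self, msg: SomeMsg) -> SomeResponse:
#           return SomeResponse()
#
#   # Raise at startup if any protocol message type lacks a handler:
#   Server.receiver.validate()
#
#   raw_reply = Server().receiver.handle_raw_message(raw_request)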


class BoundMessageReceiver:
    """Base bound receiver class."""

    def __init__(
        self,
        obj: Any,
        receiver: MessageReceiver,
    ) -> None:
        assert obj is not None
        self._obj = obj
        self._receiver = receiver

    @property
    def protocol(self) -> MessageProtocol:
        """Protocol associated with this receiver."""
        return self._receiver.protocol

    def encode_error_response(self, exc: Exception) -> str:
        """Given an error, return a response ready to send.

        This should be used for any errors that happen outside of
        standard handle_raw_message calls. Any errors within those
        calls will be automatically returned as encoded strings.
        """
        # Passing None for Message here; we would only have that available
        # for things going wrong in the handler (which this is not for).
        return self._receiver.encode_error_response(self._obj, None, exc)[0]

465
dist/ba_data/python/efro/message/_sender.py
vendored
Normal file
@@ -0,0 +1,465 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for sending and responding to messages.

Supports static typing for message types and possible return types.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

from efro.error import CleanError, RemoteError, CommunicationError
from efro.message._message import EmptySysResponse, ErrorSysResponse, Response

if TYPE_CHECKING:
    from typing import Any, Callable, Awaitable

    from efro.message._message import Message, SysResponse
    from efro.message._protocol import MessageProtocol


class MessageSender:
    """Facilitates sending messages to a target and receiving responses.

    This is instantiated at the class level and used to register unbound
    class methods to handle raw message sending.

    Example:

    class MyClass:
        msg = MyMessageSender(some_protocol)

        @msg.send_method
        def send_raw_message(self, message: str) -> str:
            # Actually send the message here.

    # MyMessageSender class should provide overloads for send(),
    # send_async(), etc. to ensure all sending happens with valid types.
    obj = MyClass()
    obj.msg.send(SomeMessageType())
    """

    def __init__(self, protocol: MessageProtocol) -> None:
        self.protocol = protocol
        self._send_raw_message_call: Callable[[Any, str], str] | None = None
        self._send_async_raw_message_call: Callable[
            [Any, str], Awaitable[str]
        ] | None = None
        self._send_async_raw_message_ex_call: Callable[
            [Any, str, Message], Awaitable[str]
        ] | None = None
        self._encode_filter_call: Callable[
            [Any, Message, dict], None
        ] | None = None
        self._decode_filter_call: Callable[
            [Any, Message, dict, Response | SysResponse], None
        ] | None = None
        self._peer_desc_call: Callable[[Any], str] | None = None

    def send_method(
        self, call: Callable[[Any, str], str]
    ) -> Callable[[Any, str], str]:
        """Function decorator for setting the raw send method.

        Send methods take strings and should return strings.
        CommunicationErrors raised here will be returned to the sender
        as such; all other exceptions will result in a RuntimeError for
        the sender.
        """
        assert self._send_raw_message_call is None
        self._send_raw_message_call = call
        return call

    def send_async_method(
        self, call: Callable[[Any, str], Awaitable[str]]
    ) -> Callable[[Any, str], Awaitable[str]]:
        """Function decorator for setting the raw send-async method.

        Send methods take strings and should return strings.
        CommunicationErrors raised here will be returned to the sender
        as such; all other exceptions will result in a RuntimeError for
        the sender.

        IMPORTANT: Generally async send methods should not be implemented
        as 'async' methods, but instead should be regular methods that
        return awaitable objects. This way it can be guaranteed that
        outgoing messages are synchronously enqueued in the correct
        order, and then async calls can be returned which finish each
        send. If the entire call is async, they may be enqueued out of
        order in rare cases.
        """
        assert self._send_async_raw_message_call is None
        self._send_async_raw_message_call = call
        return call
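
    # Sketch of the recommended non-async pattern described above
    # (hypothetical queue/transport; illustration only):
    #
    #     @msg.send_async_method
    #     def _send_raw(self, data: str) -> Awaitable[str]:
    #         self._outq.append(data)       # Enqueue synchronously.
    #         return self._finish_send()    # Complete asynchronously.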

    def send_async_ex_method(
        self, call: Callable[[Any, str, Message], Awaitable[str]]
    ) -> Callable[[Any, str, Message], Awaitable[str]]:
        """Function decorator for an extended send-async method.

        Version of send_async_method which is also passed the original
        unencoded message; can be useful for cases where metadata is sent
        along with messages referring to their payloads/etc.
        """
        assert self._send_async_raw_message_ex_call is None
        self._send_async_raw_message_ex_call = call
        return call

    def encode_filter_method(
        self, call: Callable[[Any, Message, dict], None]
    ) -> Callable[[Any, Message, dict], None]:
        """Function decorator for defining an encode filter.

        Encode filters can be used to add extra data to the message
        dict before it is encoded to a string and sent out.
        """
        assert self._encode_filter_call is None
        self._encode_filter_call = call
        return call

    def decode_filter_method(
        self, call: Callable[[Any, Message, dict, Response | SysResponse], None]
    ) -> Callable[[Any, Message, dict, Response], None]:
        """Function decorator for defining a decode filter.

        Decode filters can be used to extract extra data from incoming
        message dicts.
        """
        assert self._decode_filter_call is None
        self._decode_filter_call = call
        return call

    def peer_desc_method(
        self, call: Callable[[Any], str]
    ) -> Callable[[Any], str]:
        """Function decorator for defining peer descriptions.

        These are included in error messages or other diagnostics.
        """
        assert self._peer_desc_call is None
        self._peer_desc_call = call
        return call

    def send(self, bound_obj: Any, message: Message) -> Response | None:
        """Send a message synchronously."""
        return self.unpack_raw_response(
            bound_obj=bound_obj,
            message=message,
            raw_response=self.fetch_raw_response(
                bound_obj=bound_obj,
                message=message,
            ),
        )

    def send_async(
        self, bound_obj: Any, message: Message
    ) -> Awaitable[Response | None]:
        """Send a message asynchronously."""

        # Note: This call is synchronous so that the first part of it can
        # happen synchronously. If the whole call were async we wouldn't be
        # able to guarantee that messages sent in order would actually go
        # out in order.
        raw_response_awaitable = self.fetch_raw_response_async(
            bound_obj=bound_obj,
            message=message,
        )
        # Now return an awaitable that will finish the send.
        return self._send_async_awaitable(
            bound_obj, message, raw_response_awaitable
        )

    async def _send_async_awaitable(
        self,
        bound_obj: Any,
        message: Message,
        raw_response_awaitable: Awaitable[Response | SysResponse],
    ) -> Response | None:
        return self.unpack_raw_response(
            bound_obj=bound_obj,
            message=message,
            raw_response=await raw_response_awaitable,
        )

    def fetch_raw_response(
        self, bound_obj: Any, message: Message
    ) -> Response | SysResponse:
        """Send a message synchronously.

        Generally you can just call send(); these split versions are
        for when message sending and response handling need to happen
        in different contexts/threads.
        """
        if self._send_raw_message_call is None:
            raise RuntimeError('send() is unimplemented for this type.')

        msg_encoded = self._encode_message(bound_obj, message)
        try:
            response_encoded = self._send_raw_message_call(
                bound_obj, msg_encoded
            )
        except Exception as exc:
            response = ErrorSysResponse(
                error_message='Error in MessageSender @send_method.',
                error_type=(
                    ErrorSysResponse.ErrorType.COMMUNICATION
                    if isinstance(exc, CommunicationError)
                    else ErrorSysResponse.ErrorType.LOCAL
                ),
            )
            # Can include the actual exception since we'll be looking at
            # this locally; might be helpful.
            response.set_local_exception(exc)
            return response
        return self._decode_raw_response(bound_obj, message, response_encoded)

    def fetch_raw_response_async(
        self, bound_obj: Any, message: Message
    ) -> Awaitable[Response | SysResponse]:
        """Fetch a raw message response awaitable.

        The result of this should be awaited and then passed to
        unpack_raw_response() to produce the final message result.

        Generally you can just call send(); calling fetch and unpack
        manually is for when message sending and response handling need
        to happen in different contexts/threads.
        """

        # Note: This call is synchronous so that the first part of it can
        # happen synchronously. If the whole call were async we wouldn't be
        # able to guarantee that messages sent in order would actually go
        # out in order.
        if (
            self._send_async_raw_message_call is None
            and self._send_async_raw_message_ex_call is None
        ):
            raise RuntimeError('send_async() is unimplemented for this type.')

        msg_encoded = self._encode_message(bound_obj, message)
        try:
            if self._send_async_raw_message_ex_call is not None:
                send_awaitable = self._send_async_raw_message_ex_call(
                    bound_obj, msg_encoded, message
                )
            else:
                assert self._send_async_raw_message_call is not None
                send_awaitable = self._send_async_raw_message_call(
                    bound_obj, msg_encoded
                )
        except Exception as exc:
            return self._error_awaitable(exc)

        # Now return an awaitable to finish the job.
        return self._fetch_raw_response_awaitable(
            bound_obj, message, send_awaitable
        )

    async def _error_awaitable(self, exc: Exception) -> SysResponse:
        response = ErrorSysResponse(
            error_message='Error in MessageSender @send_async_method.',
            error_type=(
                ErrorSysResponse.ErrorType.COMMUNICATION
                if isinstance(exc, CommunicationError)
                else ErrorSysResponse.ErrorType.LOCAL
            ),
        )
        # Can include the actual exception since we'll be looking at
        # this locally; might be helpful.
        response.set_local_exception(exc)
        return response

    async def _fetch_raw_response_awaitable(
        self, bound_obj: Any, message: Message, send_awaitable: Awaitable[str]
    ) -> Response | SysResponse:

        try:
            response_encoded = await send_awaitable
        except Exception as exc:
            response = ErrorSysResponse(
                error_message='Error in MessageSender @send_async_method.',
                error_type=(
                    ErrorSysResponse.ErrorType.COMMUNICATION
                    if isinstance(exc, CommunicationError)
                    else ErrorSysResponse.ErrorType.LOCAL
                ),
            )
            # Can include the actual exception since we'll be looking at
            # this locally; might be helpful.
            response.set_local_exception(exc)
            return response
        return self._decode_raw_response(bound_obj, message, response_encoded)

    def unpack_raw_response(
        self,
        bound_obj: Any,
        message: Message,
        raw_response: Response | SysResponse,
    ) -> Response | None:
        """Convert a raw fetched response into a final response/error/etc.

        Generally you can just call send(); calling fetch and unpack
        manually is for when message sending and response handling need
        to happen in different contexts/threads.
        """
        response = self._unpack_raw_response(bound_obj, raw_response)
        assert (
            response is None
            or type(response) in type(message).get_response_types()
        )
        return response

    def _encode_message(self, bound_obj: Any, message: Message) -> str:
        """Encode a message for sending."""
        msg_dict = self.protocol.message_to_dict(message)
        if self._encode_filter_call is not None:
            self._encode_filter_call(bound_obj, message, msg_dict)
        return self.protocol.encode_dict(msg_dict)

    def _decode_raw_response(
        self, bound_obj: Any, message: Message, response_encoded: str
    ) -> Response | SysResponse:
        """Create a Response from returned data.

        These Responses may encapsulate things like remote errors and
        should not be handed directly to users. _unpack_raw_response()
        should be used to translate to special values like None or raise
        Exceptions. This function itself should never raise Exceptions.
        """
        response: Response | SysResponse
        try:
            response_dict = self.protocol.decode_dict(response_encoded)
            response = self.protocol.response_from_dict(response_dict)
            if self._decode_filter_call is not None:
                self._decode_filter_call(
                    bound_obj, message, response_dict, response
                )
        except Exception as exc:
            response = ErrorSysResponse(
                error_message='Error decoding raw response.',
                error_type=ErrorSysResponse.ErrorType.LOCAL,
            )
            # Since we'll be looking at this locally, we can include
            # extra info for logging/etc.
            response.set_local_exception(exc)
        return response

    def _unpack_raw_response(
        self, bound_obj: Any, raw_response: Response | SysResponse
    ) -> Response | None:
        """Given a raw Response, unpack to special values or Exceptions.

        The result of this call is what should be passed to users.
        For complex messaging situations such as response callbacks
        operating across different threads, this last stage should be
        run such that any raised Exception is active when the callback
        fires; not on the thread where the message was sent.
        """
        # EmptySysResponse translates to None.
        if isinstance(raw_response, EmptySysResponse):
            return None

        # Some error occurred. Raise a local Exception for it.
        if isinstance(raw_response, ErrorSysResponse):

            # Errors that happened locally can attach their exceptions
            # here for extra logging goodness.
            local_exception = raw_response.get_local_exception()

            if (
                raw_response.error_type
                is ErrorSysResponse.ErrorType.COMMUNICATION
            ):
                raise CommunicationError(
                    raw_response.error_message
                ) from local_exception

            # If something went wrong on *our* end of the connection,
            # don't say it was a remote error.
            if raw_response.error_type is ErrorSysResponse.ErrorType.LOCAL:
                raise RuntimeError(
                    raw_response.error_message
                ) from local_exception

            # If they want to support clean errors, do those.
            if (
                self.protocol.forward_clean_errors
                and raw_response.error_type
                is ErrorSysResponse.ErrorType.REMOTE_CLEAN
            ):
                raise CleanError(
                    raw_response.error_message
                ) from local_exception

            if (
                self.protocol.forward_communication_errors
                and raw_response.error_type
                is ErrorSysResponse.ErrorType.REMOTE_COMMUNICATION
            ):
                raise CommunicationError(
                    raw_response.error_message
                ) from local_exception

            # Everything else gets lumped in as a remote error.
            raise RemoteError(
                raw_response.error_message,
                peer_desc=(
                    'peer'
                    if self._peer_desc_call is None
                    else self._peer_desc_call(bound_obj)
                ),
            ) from local_exception

        assert isinstance(raw_response, Response)
        return raw_response
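

# Wiring sketch (illustration only; 'my_protocol' and the transport
# round-trip call are hypothetical stand-ins):
#
#   class Connection:
#       msg = MessageSender(my_protocol)
#
#       @msg.send_method
#       def _send_raw_message(self, data: str) -> str:
#           return my_transport_round_trip(data)  # Blocking round trip.
#
#   conn = Connection()
#   response = Connection.msg.send(conn, SomeMessage())
#
# (Generated subclasses add __get__ and typed send() wrappers, making
# the usage 'conn.msg.send(SomeMessage())' as in the docstring above.)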


class BoundMessageSender:
    """Base class for bound senders."""

    def __init__(self, obj: Any, sender: MessageSender) -> None:
        # Note: not checking obj here since we want to support
        # at least our protocol property when accessed via type.
        self._obj = obj
        self._sender = sender

    @property
    def protocol(self) -> MessageProtocol:
        """Protocol associated with this sender."""
        return self._sender.protocol

    def send_untyped(self, message: Message) -> Response | None:
        """Send a message synchronously.

        Whenever possible, use the send() call provided by generated
        subclasses instead of this; it will provide better type safety.
        """
        assert self._obj is not None
        return self._sender.send(bound_obj=self._obj, message=message)

    def send_async_untyped(
        self, message: Message
    ) -> Awaitable[Response | None]:
        """Send a message asynchronously.

        Whenever possible, use the send_async() call provided by generated
        subclasses instead of this; it will provide better type safety.
        """
        assert self._obj is not None
        return self._sender.send_async(bound_obj=self._obj, message=message)

    def fetch_raw_response_async_untyped(
        self, message: Message
    ) -> Awaitable[Response | SysResponse]:
        """Split send (part 1 of 2)."""
        assert self._obj is not None
        return self._sender.fetch_raw_response_async(
            bound_obj=self._obj, message=message
        )

    def unpack_raw_response_untyped(
        self, message: Message, raw_response: Response | SysResponse
    ) -> Response | None:
        """Split send (part 2 of 2)."""
        return self._sender.unpack_raw_response(
            bound_obj=self._obj, message=message, raw_response=raw_response
        )

75
dist/ba_data/python/efro/net.py
vendored
Normal file
@@ -0,0 +1,75 @@
# Released under the MIT License. See LICENSE for details.
#
"""Network related functionality."""
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    pass


def is_urllib_network_error(exc: BaseException) -> bool:
    """Is the provided exception a network-related error?

    This should be passed an exception which resulted from opening or
    reading a urllib Request. It should return True for any errors that
    could conceivably arise due to unavailable/poor network connections,
    firewall/connectivity issues, etc. These issues can often be safely
    ignored or presented to the user as general 'network-unavailable'
    states.
    """
    import urllib.request
    import urllib.error
    import http.client
    import errno
    import socket
    if isinstance(
            exc,
        (urllib.error.URLError, ConnectionError, http.client.IncompleteRead,
         http.client.BadStatusLine, socket.timeout)):
        return True
    if isinstance(exc, OSError):
        if exc.errno == 10051:  # Windows unreachable network error.
            return True
        if exc.errno in {
                errno.ETIMEDOUT,
                errno.EHOSTUNREACH,
                errno.ENETUNREACH,
        }:
            return True
    return False
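

# Hedged usage sketch (not part of the original module): treat transient
# network trouble as a soft failure and let anything unexpected raise.
def _example_fetch(url: str) -> bytes | None:
    """Fetch a url, returning None on plausible network unavailability."""
    import urllib.request

    try:
        with urllib.request.urlopen(url) as resp:
            return resp.read()
    except Exception as exc:
        if is_urllib_network_error(exc):
            return None  # Safe to treat as 'network unavailable'.
        raise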


def is_udp_network_error(exc: BaseException) -> bool:
    """Is the provided exception a network-related error?

    This should be passed an exception which resulted from creating and
    using a socket.SOCK_DGRAM type socket. It should return True for any
    errors that could conceivably arise due to unavailable/poor network
    connections, firewall/connectivity issues, etc. These issues can often
    be safely ignored or presented to the user as general
    'network-unavailable' states.
    """
    import errno
    if isinstance(exc, ConnectionRefusedError):
        return True
    if isinstance(exc, OSError):
        if exc.errno == 10051:  # Windows unreachable network error.
            return True
        if exc.errno in {
                errno.EADDRNOTAVAIL,
                errno.ETIMEDOUT,
                errno.EHOSTUNREACH,
                errno.ENETUNREACH,
                errno.EINVAL,
                errno.EPERM,
                errno.EACCES,
                # Windows 'invalid argument' error.
                10022,
                # Windows 'a socket operation was attempted to
                # an unreachable network' error.
                10051,
        }:
            return True
    return False
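

# Hedged usage sketch (not part of the original module), mirroring the
# urllib example above for datagram sockets.
def _example_udp_send(data: bytes, addr: tuple[str, int]) -> bool:
    """Send a datagram, returning False on plausible network trouble."""
    import socket

    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        sock.sendto(data, addr)
        return True
    except Exception as exc:
        if is_udp_network_error(exc):
            return False  # Safe to treat as 'network unavailable'.
        raise
    finally:
        sock.close()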

945
dist/ba_data/python/efro/rpc.py
vendored
Normal file
@@ -0,0 +1,945 @@
# Released under the MIT License. See LICENSE for details.
#
"""Remote procedure call related functionality."""

from __future__ import annotations

import time
import asyncio
import logging
import weakref
from enum import Enum
from collections import deque
from dataclasses import dataclass
from threading import current_thread
from typing import TYPE_CHECKING, Annotated

from efro.util import assert_never
from efro.error import (
    CommunicationError,
    is_asyncio_streams_communication_error,
)
from efro.dataclassio import (
    dataclass_to_json,
    dataclass_from_json,
    ioprepped,
    IOAttrs,
)

if TYPE_CHECKING:
    from typing import Literal, Awaitable, Callable

# Terminology:
# Packet: A chunk of data consisting of a type and some type-dependent
# payload. Even though we use streams we organize our transmission
# into 'packets'.
# Message: User data which we transmit using one or more packets.


class _PacketType(Enum):
    HANDSHAKE = 0
    KEEPALIVE = 1
    MESSAGE = 2
    RESPONSE = 3
    MESSAGE_BIG = 4
    RESPONSE_BIG = 5


_BYTE_ORDER: Literal['big'] = 'big'
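
# Framing sketch (illustration only; the actual packing lives inside
# RPCEndpoint internals not shown in this excerpt): a plausible MESSAGE
# packet is a one-byte packet type followed by a big-endian payload
# length and then the payload itself, e.g.:
#
#   header = bytes([_PacketType.MESSAGE.value])
#   header += len(payload).to_bytes(2, _BYTE_ORDER)
#   writer.write(header + payload)
#
# The *_BIG variants exist for payloads whose length requires a 32-bit
# length field (see the protocol history note below).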


@ioprepped
@dataclass
class _PeerInfo:

    # So we can gracefully evolve how we communicate in the future.
    protocol: Annotated[int, IOAttrs('p')]

    # How often we'll be sending out keepalives (in seconds).
    keepalive_interval: Annotated[float, IOAttrs('k')]


# Note: we are expected to be forward and backward compatible; we can
# increment protocol freely and expect everyone else to still talk to us.
# Likewise we should retain logic to communicate with older protocols.
# Protocol history:
# 1 - initial release
# 2 - gained big (32-bit length) message/response packets
OUR_PROTOCOL = 2
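

# Handshake sketch (illustration only; the real exchange happens inside
# RPCEndpoint): peer info travels as dataclassio JSON keyed by the short
# IOAttrs names above, e.g. {"p": 2, "k": 10.73}.
def _example_encode_peer_info() -> str:
    """Encode a _PeerInfo blob the way a handshake packet might carry it."""
    return dataclass_to_json(
        _PeerInfo(
            protocol=OUR_PROTOCOL,
            keepalive_interval=RPCEndpoint.DEFAULT_KEEPALIVE_INTERVAL,
        )
    )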
|
||||
|
||||
|
||||
def ssl_stream_writer_underlying_transport_info(
|
||||
writer: asyncio.StreamWriter,
|
||||
) -> str:
|
||||
"""For debugging SSL Stream connections; returns raw transport info."""
|
||||
# Note: accessing internals here so just returning info and not
|
||||
# actual objs to reduce potential for breakage.
|
||||
transport = getattr(writer, '_transport', None)
|
||||
if transport is not None:
|
||||
sslproto = getattr(transport, '_ssl_protocol', None)
|
||||
if sslproto is not None:
|
||||
raw_transport = getattr(sslproto, '_transport', None)
|
||||
if raw_transport is not None:
|
||||
return str(raw_transport)
|
||||
return '(not found)'
|
||||
|
||||
|
||||
def ssl_stream_writer_force_close_check(writer: asyncio.StreamWriter) -> None:
|
||||
"""Ensure a writer is closed; hacky workaround for odd hang."""
|
||||
from efro.call import tpartial
|
||||
from threading import Thread
|
||||
|
||||
# Disabling for now..
|
||||
if bool(True):
|
||||
return
|
||||
|
||||
# Hopefully can remove this in Python 3.11?...
|
||||
# see issue with is_closing() below for more details.
|
||||
transport = getattr(writer, '_transport', None)
|
||||
if transport is not None:
|
||||
sslproto = getattr(transport, '_ssl_protocol', None)
|
||||
if sslproto is not None:
|
||||
raw_transport = getattr(sslproto, '_transport', None)
|
||||
if raw_transport is not None:
|
||||
Thread(
|
||||
target=tpartial(
|
||||
_do_writer_force_close_check,
|
||||
weakref.ref(raw_transport),
|
||||
),
|
||||
daemon=True,
|
||||
).start()
|
||||
|
||||
|
||||
def _do_writer_force_close_check(transport_weak: weakref.ref) -> None:
|
||||
try:
|
||||
# Attempt to bail as soon as the obj dies.
|
||||
# If it hasn't done so by our timeout, force-kill it.
|
||||
starttime = time.monotonic()
|
||||
while time.monotonic() - starttime < 10.0:
|
||||
time.sleep(0.1)
|
||||
if transport_weak() is None:
|
||||
return
|
||||
transport = transport_weak()
|
||||
if transport is not None:
|
||||
logging.info('Forcing abort on stuck transport %s.', transport)
|
||||
transport.abort()
|
||||
except Exception:
|
||||
logging.warning('Error in writer-force-close-check', exc_info=True)
|
||||
|
||||
|
||||
class _InFlightMessage:
|
||||
"""Represents a message that is out on the wire."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._response: bytes | None = None
|
||||
self._got_response = asyncio.Event()
|
||||
self.wait_task = asyncio.create_task(
|
||||
self._wait(), name='rpc in flight msg wait'
|
||||
)
|
||||
|
||||
async def _wait(self) -> bytes:
|
||||
await self._got_response.wait()
|
||||
assert self._response is not None
|
||||
return self._response
|
||||
|
||||
def set_response(self, data: bytes) -> None:
|
||||
"""Set response data."""
|
||||
assert self._response is None
|
||||
self._response = data
|
||||
self._got_response.set()
|
||||
|
||||
|
||||
class _KeepaliveTimeoutError(Exception):
|
||||
"""Raised if we time out due to not receiving keepalives."""
|
||||
|
||||
|
||||
class RPCEndpoint:
|
||||
"""Facilitates asynchronous multiplexed remote procedure calls.
|
||||
|
||||
Be aware that, while multiple calls can be in flight in either direction
|
||||
simultaneously, packets are still sent serially in a single
|
||||
stream. So excessively long messages/responses will delay all other
|
||||
communication. If/when this becomes an issue we can look into breaking up
|
||||
long messages into multiple packets.
|
||||
"""
|
||||
|
||||
# Set to True on an instance to test keepalive failures.
|
||||
test_suppress_keepalives: bool = False
|
||||
|
||||
# How long we should wait before giving up on a message by default.
|
||||
# Note this includes processing time on the other end.
|
||||
DEFAULT_MESSAGE_TIMEOUT = 60.0
|
||||
|
||||
# How often we send out keepalive packets by default.
|
||||
DEFAULT_KEEPALIVE_INTERVAL = 10.73  # (avoid overly regular values)
|
||||
|
||||
# How long we can go without receiving a keepalive packet before we
|
||||
# disconnect.
|
||||
DEFAULT_KEEPALIVE_TIMEOUT = 30.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handle_raw_message_call: Callable[[bytes], Awaitable[bytes]],
|
||||
reader: asyncio.StreamReader,
|
||||
writer: asyncio.StreamWriter,
|
||||
label: str,
|
||||
debug_print: bool = False,
|
||||
debug_print_io: bool = False,
|
||||
debug_print_call: Callable[[str], None] | None = None,
|
||||
keepalive_interval: float = DEFAULT_KEEPALIVE_INTERVAL,
|
||||
keepalive_timeout: float = DEFAULT_KEEPALIVE_TIMEOUT,
|
||||
) -> None:
|
||||
self._handle_raw_message_call = handle_raw_message_call
|
||||
self._reader = reader
|
||||
self._writer = writer
|
||||
self.debug_print = debug_print
|
||||
self.debug_print_io = debug_print_io
|
||||
if debug_print_call is None:
|
||||
debug_print_call = print
|
||||
self.debug_print_call: Callable[[str], None] = debug_print_call
|
||||
self._label = label
|
||||
self._thread = current_thread()
|
||||
self._closing = False
|
||||
self._did_wait_closed = False
|
||||
self._event_loop = asyncio.get_running_loop()
|
||||
self._out_packets = deque[bytes]()
|
||||
self._have_out_packets = asyncio.Event()
|
||||
self._run_called = False
|
||||
self._peer_info: _PeerInfo | None = None
|
||||
self._keepalive_interval = keepalive_interval
|
||||
self._keepalive_timeout = keepalive_timeout
|
||||
self._did_close_writer = False
|
||||
self._did_wait_closed_writer = False
|
||||
self._did_out_packets_buildup_warning = False
|
||||
self._total_bytes_read = 0
|
||||
self._create_time = time.monotonic()
|
||||
|
||||
# Need to hold weak-refs to these otherwise it creates dep-loops
|
||||
# which keeps us alive.
|
||||
self._tasks: list[asyncio.Task] = []
|
||||
|
||||
# When we last got a keepalive or equivalent (time.monotonic value)
|
||||
self._last_keepalive_receive_time: float | None = None
|
||||
|
||||
# (Start near the end to make sure our looping logic is sound).
|
||||
self._next_message_id = 65530
|
||||
|
||||
self._in_flight_messages: dict[int, _InFlightMessage] = {}
|
||||
|
||||
if self.debug_print:
|
||||
peername = self._writer.get_extra_info('peername')
|
||||
self.debug_print_call(
|
||||
f'{self._label}: connected to {peername} at {self._tm()}.'
|
||||
)
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self._run_called:
|
||||
if not self._did_close_writer:
|
||||
logging.warning(
|
||||
'RPCEndpoint %d dying with run'
|
||||
' called but writer not closed (transport=%s).',
|
||||
id(self),
|
||||
ssl_stream_writer_underlying_transport_info(self._writer),
|
||||
)
|
||||
elif not self._did_wait_closed_writer:
|
||||
logging.warning(
|
||||
'RPCEndpoint %d dying with run called'
|
||||
' but writer not wait-closed (transport=%s).',
|
||||
id(self),
|
||||
ssl_stream_writer_underlying_transport_info(self._writer),
|
||||
)
|
||||
|
||||
# Currently seeing rare issue where sockets don't go down;
|
||||
# let's add a timer to force the issue until we can figure it out.
|
||||
ssl_stream_writer_force_close_check(self._writer)
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Run the endpoint until the connection is lost or closed.
|
||||
|
||||
Handles closing the provided reader/writer on close.
|
||||
"""
|
||||
try:
|
||||
await self._do_run()
|
||||
except asyncio.CancelledError:
|
||||
# We aren't really designed to be cancelled so let's warn
|
||||
# if it happens.
|
||||
logging.warning(
|
||||
'RPCEndpoint.run got CancelledError;'
|
||||
' want to try and avoid this.'
|
||||
)
|
||||
raise
|
||||
|
||||
async def _do_run(self) -> None:
|
||||
|
||||
self._check_env()
|
||||
|
||||
if self._run_called:
|
||||
raise RuntimeError('Run can be called only once per endpoint.')
|
||||
self._run_called = True
|
||||
|
||||
core_tasks = [
|
||||
asyncio.create_task(
|
||||
self._run_core_task('keepalive', self._run_keepalive_task()),
|
||||
name='rpc keepalive',
|
||||
),
|
||||
asyncio.create_task(
|
||||
self._run_core_task('read', self._run_read_task()),
|
||||
name='rpc read',
|
||||
),
|
||||
asyncio.create_task(
|
||||
self._run_core_task('write', self._run_write_task()),
|
||||
name='rpc write',
|
||||
),
|
||||
]
|
||||
self._tasks += core_tasks
|
||||
|
||||
# Run our core tasks until they all complete.
|
||||
results = await asyncio.gather(*core_tasks, return_exceptions=True)
|
||||
|
||||
# Core tasks should handle their own errors; the only ones
|
||||
# we expect to bubble up are CancelledError.
|
||||
for result in results:
|
||||
# We want to know if any errors happened aside from CancelledError
|
||||
# (which are BaseExceptions, not Exception).
|
||||
if isinstance(result, Exception):
|
||||
logging.warning(
|
||||
'Got unexpected error from %s core task: %s',
|
||||
self._label,
|
||||
result,
|
||||
)
|
||||
|
||||
if not all(task.done() for task in core_tasks):
|
||||
logging.warning(
|
||||
'RPCEndpoint %d: not all core tasks marked done after gather.',
|
||||
id(self),
|
||||
)
|
||||
|
||||
# Shut ourself down.
|
||||
try:
|
||||
self.close()
|
||||
await self.wait_closed()
|
||||
except Exception:
|
||||
logging.exception('Error closing %s.', self._label)
|
||||
|
||||
if self.debug_print:
|
||||
self.debug_print_call(f'{self._label}: finished.')
|
||||
|
||||
def send_message(
|
||||
self,
|
||||
message: bytes,
|
||||
timeout: float | None = None,
|
||||
close_on_error: bool = True,
|
||||
) -> Awaitable[bytes]:
|
||||
"""Send a message to the peer and return a response.
|
||||
|
||||
If timeout is not provided, the default will be used.
|
||||
Raises a CommunicationError if the round trip is not completed
|
||||
for any reason.
|
||||
|
||||
By default, the entire endpoint will go down in the case of
|
||||
errors. This allows messages to be treated as 'reliable' with
|
||||
respect to a given endpoint. Pass close_on_error=False to
|
||||
override this for a particular message.
|
||||
"""
|
||||
# Note: This call is synchronous so that the first part of it
|
||||
# (enqueueing outgoing messages) happens synchronously. If it were
|
||||
# a pure async call it could be possible for send order to vary
|
||||
# based on how the async tasks get processed.
|
||||
|
||||
if self.debug_print_io:
|
||||
self.debug_print_call(
|
||||
f'{self._label}: sending message of size {len(message)}'
|
||||
f' at {self._tm()}.'
|
||||
)
|
||||
|
||||
self._check_env()
|
||||
|
||||
if self._closing:
|
||||
raise CommunicationError('Endpoint is closed.')
|
||||
|
||||
if self.debug_print_io:
|
||||
self.debug_print_call(
|
||||
f'{self._label}: have peerinfo? {self._peer_info is not None}.'
|
||||
)
|
||||
|
||||
# message_id is a 16 bit looping value.
|
||||
message_id = self._next_message_id
|
||||
self._next_message_id = (self._next_message_id + 1) % 65536
|
||||
|
||||
if self.debug_print_io:
|
||||
self.debug_print_call(
|
||||
f'{self._label}: will enqueue at {self._tm()}.'
|
||||
)
|
||||
|
||||
# FIXME - should handle backpressure (waiting here if there are
|
||||
# enough packets already enqueued).
|
||||
|
||||
if len(message) > 65535:
|
||||
# Payload consists of type (1b), message_id (2b),
|
||||
# len (4b), and data.
|
||||
self._enqueue_outgoing_packet(
|
||||
_PacketType.MESSAGE_BIG.value.to_bytes(1, _BYTE_ORDER)
|
||||
+ message_id.to_bytes(2, _BYTE_ORDER)
|
||||
+ len(message).to_bytes(4, _BYTE_ORDER)
|
||||
+ message
|
||||
)
|
||||
else:
|
||||
# Payload consists of type (1b), message_id (2b),
|
||||
# len (2b), and data.
|
||||
self._enqueue_outgoing_packet(
|
||||
_PacketType.MESSAGE.value.to_bytes(1, _BYTE_ORDER)
|
||||
+ message_id.to_bytes(2, _BYTE_ORDER)
|
||||
+ len(message).to_bytes(2, _BYTE_ORDER)
|
||||
+ message
|
||||
)
|
||||
|
||||
if self.debug_print_io:
|
||||
self.debug_print_call(
|
||||
f'{self._label}: enqueued message of size {len(message)}'
|
||||
f' at {self._tm()}.'
|
||||
)
|
||||
|
||||
# Make an entry so we know this message is out there.
|
||||
assert message_id not in self._in_flight_messages
|
||||
msgobj = self._in_flight_messages[message_id] = _InFlightMessage()
|
||||
|
||||
# Also add its task to our list so we properly cancel it if we die.
|
||||
self._prune_tasks() # Keep our list from filling with dead tasks.
|
||||
self._tasks.append(msgobj.wait_task)
|
||||
|
||||
# Note: we always want to incorporate a timeout. Individual
|
||||
# messages may hang or error on the other end and this ensures
|
||||
# we won't build up lots of zombie tasks waiting around for
|
||||
# responses that will never arrive.
|
||||
if timeout is None:
|
||||
timeout = self.DEFAULT_MESSAGE_TIMEOUT
|
||||
assert timeout is not None
|
||||
|
||||
bytes_awaitable = msgobj.wait_task
|
||||
|
||||
# Now complete the send asynchronously.
|
||||
return self._send_message(
|
||||
message, timeout, close_on_error, bytes_awaitable, message_id
|
||||
)
|
||||
|
||||
async def _send_message(
|
||||
self,
|
||||
message: bytes,
|
||||
timeout: float | None,
|
||||
close_on_error: bool,
|
||||
bytes_awaitable: asyncio.Task[bytes],
|
||||
message_id: int,
|
||||
) -> bytes:
|
||||
|
||||
# We need to know their protocol, so if we haven't gotten a handshake
|
||||
# from them yet, just wait.
|
||||
while self._peer_info is None:
|
||||
await asyncio.sleep(0.01)
|
||||
assert self._peer_info is not None
|
||||
|
||||
if self._peer_info.protocol == 1:
|
||||
if len(message) > 65535:
|
||||
raise RuntimeError('Message cannot be larger than 65535 bytes')
|
||||
|
||||
try:
|
||||
return await asyncio.wait_for(bytes_awaitable, timeout=timeout)
|
||||
except asyncio.CancelledError as exc:
|
||||
# Question: we assume this means the above wait_for() was
|
||||
# cancelled; how do we distinguish between this and *us* being
|
||||
# cancelled though?
|
||||
if self.debug_print:
|
||||
self.debug_print_call(
|
||||
f'{self._label}: message {message_id} was cancelled.'
|
||||
)
|
||||
if close_on_error:
|
||||
self.close()
|
||||
|
||||
raise CommunicationError() from exc
|
||||
except Exception as exc:
|
||||
|
||||
# If our timer timed-out or anything else went wrong with
|
||||
# the stream, lump it in as a communication error.
|
||||
if isinstance(
|
||||
exc, asyncio.TimeoutError
|
||||
) or is_asyncio_streams_communication_error(exc):
|
||||
|
||||
if self.debug_print:
|
||||
self.debug_print_call(
|
||||
f'{self._label}: got {type(exc)} sending message'
|
||||
f' {message_id}; raising CommunicationError.'
|
||||
)
|
||||
|
||||
# Stop waiting on the response.
|
||||
bytes_awaitable.cancel()
|
||||
|
||||
# Remove the record of this message.
|
||||
del self._in_flight_messages[message_id]
|
||||
|
||||
if close_on_error:
|
||||
self.close()
|
||||
|
||||
# Let the user know something went wrong.
|
||||
raise CommunicationError() from exc
|
||||
|
||||
# Some unexpected error; let it bubble up.
|
||||
raise
|
||||
|
||||
def close(self) -> None:
|
||||
"""I said seagulls; mmmm; stop it now."""
|
||||
self._check_env()
|
||||
|
||||
if self._closing:
|
||||
return
|
||||
|
||||
if self.debug_print:
|
||||
self.debug_print_call(f'{self._label}: closing...')
|
||||
|
||||
self._closing = True
|
||||
|
||||
# Kill all of our in-flight tasks.
|
||||
if self.debug_print:
|
||||
self.debug_print_call(f'{self._label}: cancelling tasks...')
|
||||
for task in self._get_live_tasks():
|
||||
task.cancel()
|
||||
|
||||
# Close our writer.
|
||||
assert not self._did_close_writer
|
||||
if self.debug_print:
|
||||
self.debug_print_call(f'{self._label}: closing writer...')
|
||||
self._writer.close()
|
||||
self._did_close_writer = True
|
||||
|
||||
# We don't need this anymore and it is likely to be creating a
|
||||
# dependency loop.
|
||||
del self._handle_raw_message_call
|
||||
|
||||
def is_closing(self) -> bool:
|
||||
"""Have we begun the process of closing?"""
|
||||
return self._closing
|
||||
|
||||
async def wait_closed(self) -> None:
|
||||
"""I said seagulls; mmmm; stop it now.
|
||||
|
||||
Wait for the endpoint to finish closing. This is called by run()
|
||||
so generally does not need to be explicitly called.
|
||||
"""
|
||||
# pylint: disable=too-many-branches
|
||||
self._check_env()
|
||||
|
||||
# Make sure we only *enter* this call once.
|
||||
if self._did_wait_closed:
|
||||
return
|
||||
self._did_wait_closed = True
|
||||
|
||||
if not self._closing:
|
||||
raise RuntimeError('Must be called after close()')
|
||||
|
||||
if not self._did_close_writer:
|
||||
logging.warning(
|
||||
'RPCEndpoint wait_closed() called but never'
|
||||
' explicitly closed writer.'
|
||||
)
|
||||
|
||||
live_tasks = self._get_live_tasks()
|
||||
|
||||
# Don't need our task list anymore; this should
|
||||
# break any cyclical refs from tasks referring to us.
|
||||
self._tasks = []
|
||||
|
||||
if self.debug_print:
|
||||
self.debug_print_call(
|
||||
f'{self._label}: waiting for tasks to finish: '
|
||||
f' ({live_tasks=})...'
|
||||
)
|
||||
|
||||
# Wait for all of our in-flight tasks to wrap up.
|
||||
results = await asyncio.gather(*live_tasks, return_exceptions=True)
|
||||
for result in results:
|
||||
# We want to know if any errors happened aside from CancelledError
|
||||
# (which are BaseExceptions, not Exception).
|
||||
if isinstance(result, Exception):
|
||||
logging.warning(
|
||||
'Got unexpected error cleaning up %s task: %s',
|
||||
self._label,
|
||||
result,
|
||||
)
|
||||
|
||||
if not all(task.done() for task in live_tasks):
|
||||
logging.warning(
|
||||
'RPCEndpoint %d: not all live tasks marked done after gather.',
|
||||
id(self),
|
||||
)
|
||||
|
||||
if self.debug_print:
|
||||
self.debug_print_call(
|
||||
f'{self._label}: tasks finished; waiting for writer close...'
|
||||
)
|
||||
|
||||
# Now wait for our writer to finish going down.
|
||||
# When we close our writer it generally triggers errors
|
||||
# in our current blocked read/writes. However that same
|
||||
# error is also sometimes returned from _writer.wait_closed().
|
||||
# See connection_lost() in asyncio/streams.py to see why.
|
||||
# So let's silently ignore it when that happens.
|
||||
assert self._writer.is_closing()
|
||||
try:
|
||||
# It seems that as of Python 3.9.x it is possible for this to hang
|
||||
# indefinitely. See https://github.com/python/cpython/issues/83939
|
||||
# It sounds like this should be fixed in 3.11 but for now just
|
||||
# forcing the issue with a timeout here.
|
||||
await asyncio.wait_for(
|
||||
self._writer.wait_closed(),
|
||||
# timeout=60.0 * 6.0,
|
||||
timeout=30.0,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logging.info(
|
||||
'Timeout on _writer.wait_closed() for %s rpc (transport=%s).',
|
||||
self._label,
|
||||
ssl_stream_writer_underlying_transport_info(self._writer),
|
||||
)
|
||||
if self.debug_print:
|
||||
self.debug_print_call(
|
||||
f'{self._label}: got timeout in _writer.wait_closed();'
|
||||
' This should be fixed in future Python versions.'
|
||||
)
|
||||
except Exception as exc:
|
||||
if not self._is_expected_connection_error(exc):
|
||||
logging.exception('Error closing _writer for %s.', self._label)
|
||||
else:
|
||||
if self.debug_print:
|
||||
self.debug_print_call(
|
||||
f'{self._label}: silently ignoring error in'
|
||||
f' _writer.wait_closed(): {exc}.'
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
logging.warning(
|
||||
'RPCEndpoint.wait_closed()'
|
||||
' got asyncio.CancelledError; not expected.'
|
||||
)
|
||||
raise
|
||||
assert not self._did_wait_closed_writer
|
||||
self._did_wait_closed_writer = True
|
||||
|
||||
def _tm(self) -> str:
|
||||
"""Simple readable time value for debugging."""
|
||||
tval = time.monotonic() % 100.0
|
||||
return f'{tval:.2f}'
|
||||
|
||||
async def _run_read_task(self) -> None:
|
||||
"""Read from the peer."""
|
||||
self._check_env()
|
||||
assert self._peer_info is None
|
||||
|
||||
# Bug fix: if we don't have this set we will never time out
|
||||
# if we never receive any data from the other end.
|
||||
self._last_keepalive_receive_time = time.monotonic()
|
||||
|
||||
# The first thing they should send us is their handshake; then
|
||||
# we'll know if/how we can talk to them.
|
||||
mlen = await self._read_int_32()
|
||||
message = await self._reader.readexactly(mlen)
|
||||
self._total_bytes_read += mlen
|
||||
self._peer_info = dataclass_from_json(_PeerInfo, message.decode())
|
||||
self._last_keepalive_receive_time = time.monotonic()
|
||||
if self.debug_print:
|
||||
self.debug_print_call(
|
||||
f'{self._label}: received handshake at {self._tm()}.'
|
||||
)
|
||||
|
||||
# Now just sit and handle stuff as it comes in.
|
||||
while True:
|
||||
if self._closing:
|
||||
return
|
||||
|
||||
# Read message type.
|
||||
mtype = _PacketType(await self._read_int_8())
|
||||
if mtype is _PacketType.HANDSHAKE:
|
||||
raise RuntimeError('Got multiple handshakes')
|
||||
|
||||
if mtype is _PacketType.KEEPALIVE:
|
||||
if self.debug_print_io:
|
||||
self.debug_print_call(
|
||||
f'{self._label}: received keepalive'
|
||||
f' at {self._tm()}.'
|
||||
)
|
||||
self._last_keepalive_receive_time = time.monotonic()
|
||||
|
||||
elif mtype is _PacketType.MESSAGE:
|
||||
await self._handle_message_packet(big=False)
|
||||
|
||||
elif mtype is _PacketType.MESSAGE_BIG:
|
||||
await self._handle_message_packet(big=True)
|
||||
|
||||
elif mtype is _PacketType.RESPONSE:
|
||||
await self._handle_response_packet(big=False)
|
||||
|
||||
elif mtype is _PacketType.RESPONSE_BIG:
|
||||
await self._handle_response_packet(big=True)
|
||||
|
||||
else:
|
||||
assert_never(mtype)
|
||||
|
||||
async def _handle_message_packet(self, big: bool) -> None:
|
||||
assert self._peer_info is not None
|
||||
msgid = await self._read_int_16()
|
||||
if big:
|
||||
msglen = await self._read_int_32()
|
||||
else:
|
||||
msglen = await self._read_int_16()
|
||||
msg = await self._reader.readexactly(msglen)
|
||||
self._total_bytes_read += msglen
|
||||
if self.debug_print_io:
|
||||
self.debug_print_call(
|
||||
f'{self._label}: received message {msgid}'
|
||||
f' of size {msglen} at {self._tm()}.'
|
||||
)
|
||||
|
||||
# Create a message-task to handle this message and return
|
||||
# a response (we don't want to block while that happens).
|
||||
assert not self._closing
|
||||
self._prune_tasks() # Keep from filling with dead tasks.
|
||||
self._tasks.append(
|
||||
asyncio.create_task(
|
||||
self._handle_raw_message(message_id=msgid, message=msg),
|
||||
name='efro rpc message handle',
|
||||
)
|
||||
)
|
||||
if self.debug_print:
|
||||
self.debug_print_call(
|
||||
f'{self._label}: done handling message at {self._tm()}.'
|
||||
)
|
||||
|
||||
async def _handle_response_packet(self, big: bool) -> None:
|
||||
assert self._peer_info is not None
|
||||
msgid = await self._read_int_16()
|
||||
# Protocol 2 gained 32 bit data lengths.
|
||||
if big:
|
||||
rsplen = await self._read_int_32()
|
||||
else:
|
||||
rsplen = await self._read_int_16()
|
||||
if self.debug_print_io:
|
||||
self.debug_print_call(
|
||||
f'{self._label}: received response {msgid}'
|
||||
f' of size {rsplen} at {self._tm()}.'
|
||||
)
|
||||
rsp = await self._reader.readexactly(rsplen)
|
||||
self._total_bytes_read += rsplen
|
||||
msgobj = self._in_flight_messages.get(msgid)
|
||||
if msgobj is None:
|
||||
# It's possible for us to get a response to a message
|
||||
# that has timed out. In this case we will have no local
|
||||
# record of it.
|
||||
if self.debug_print:
|
||||
self.debug_print_call(
|
||||
f'{self._label}: got response for nonexistent'
|
||||
f' message id {msgid}; perhaps it timed out?'
|
||||
)
|
||||
else:
|
||||
msgobj.set_response(rsp)
|
||||
|
||||
async def _run_write_task(self) -> None:
|
||||
"""Write to the peer."""
|
||||
|
||||
self._check_env()
|
||||
|
||||
# Introduce ourself so our peer knows how it can talk to us.
|
||||
data = dataclass_to_json(
|
||||
_PeerInfo(
|
||||
protocol=OUR_PROTOCOL,
|
||||
keepalive_interval=self._keepalive_interval,
|
||||
)
|
||||
).encode()
|
||||
self._writer.write(len(data).to_bytes(4, _BYTE_ORDER) + data)
|
||||
|
||||
# Now just write out-messages as they come in.
|
||||
while True:
|
||||
|
||||
# Wait until some data comes in.
|
||||
await self._have_out_packets.wait()
|
||||
|
||||
assert self._out_packets
|
||||
data = self._out_packets.popleft()
|
||||
|
||||
# Important: only clear this once all packets are sent.
|
||||
if not self._out_packets:
|
||||
self._have_out_packets.clear()
|
||||
|
||||
self._writer.write(data)
|
||||
|
||||
# This should keep our writer from buffering huge amounts
|
||||
# of outgoing data. We must remember though that we also
|
||||
# need to prevent _out_packets from growing too large and
|
||||
# that part's on us.
|
||||
await self._writer.drain()
|
||||
|
||||
# For now we're not applying backpressure, but let's make
|
||||
# noise if this gets out of hand.
|
||||
if len(self._out_packets) > 200:
|
||||
if not self._did_out_packets_buildup_warning:
|
||||
logging.warning(
|
||||
'_out_packets building up too'
|
||||
' much on RPCEndpoint %s.',
|
||||
id(self),
|
||||
)
|
||||
self._did_out_packets_buildup_warning = True
|
||||
|
||||
async def _run_keepalive_task(self) -> None:
|
||||
"""Send periodic keepalive packets."""
|
||||
self._check_env()
|
||||
|
||||
# We explicitly send our own keepalive packets so we can stay
|
||||
# more on top of the connection state and possibly decide to
|
||||
# kill it when contact is lost more quickly than the OS would
|
||||
# do itself (or at least keep the user informed that the
|
||||
# connection is lagging). It sounds like we could have the TCP
|
||||
# layer do this sort of thing itself but that might be
|
||||
# OS-specific so gonna go this way for now.
|
||||
while True:
|
||||
assert not self._closing
|
||||
await asyncio.sleep(self._keepalive_interval)
|
||||
if not self.test_suppress_keepalives:
|
||||
self._enqueue_outgoing_packet(
|
||||
_PacketType.KEEPALIVE.value.to_bytes(1, _BYTE_ORDER)
|
||||
)
|
||||
|
||||
# Also go ahead and handle dropping the connection if we
|
||||
# haven't heard from the peer in a while.
|
||||
# NOTE: perhaps we want to do something more exact than
|
||||
# this which only checks once per keepalive-interval?..
|
||||
now = time.monotonic()
|
||||
if (
|
||||
self._last_keepalive_receive_time is not None
|
||||
and now - self._last_keepalive_receive_time
|
||||
> self._keepalive_timeout
|
||||
):
|
||||
if self.debug_print:
|
||||
since = now - self._last_keepalive_receive_time
|
||||
self.debug_print_call(
|
||||
f'{self._label}: reached keepalive time-out'
|
||||
f' ({since:.1f}s).'
|
||||
)
|
||||
raise _KeepaliveTimeoutError()
|
||||
|
||||
async def _run_core_task(self, tasklabel: str, call: Awaitable) -> None:
|
||||
try:
|
||||
await call
|
||||
except Exception as exc:
|
||||
# We expect connection errors to put us here, but make noise
|
||||
# if something else does.
|
||||
if not self._is_expected_connection_error(exc):
|
||||
logging.exception(
|
||||
'Unexpected error in rpc %s %s task'
|
||||
' (age=%.1f, total_bytes_read=%d).',
|
||||
self._label,
|
||||
tasklabel,
|
||||
time.monotonic() - self._create_time,
|
||||
self._total_bytes_read,
|
||||
)
|
||||
else:
|
||||
if self.debug_print:
|
||||
self.debug_print_call(
|
||||
f'{self._label}: {tasklabel} task will exit cleanly'
|
||||
f' due to {exc!r}.'
|
||||
)
|
||||
finally:
|
||||
# Any core task exiting triggers shutdown.
|
||||
if self.debug_print:
|
||||
self.debug_print_call(
|
||||
f'{self._label}: {tasklabel} task exiting...'
|
||||
)
|
||||
self.close()
|
||||
|
||||
async def _handle_raw_message(
|
||||
self, message_id: int, message: bytes
|
||||
) -> None:
|
||||
try:
|
||||
response = await self._handle_raw_message_call(message)
|
||||
except Exception:
|
||||
# We expect local message handler to always succeed.
|
||||
# If that doesn't happen, make a fuss so we know to fix it.
|
||||
# The other end will simply never get a response to this
|
||||
# message.
|
||||
logging.exception('Error handling raw rpc message')
|
||||
return
|
||||
|
||||
assert self._peer_info is not None
|
||||
|
||||
if self._peer_info.protocol == 1:
|
||||
if len(response) > 65535:
|
||||
raise RuntimeError('Response cannot be larger than 65535 bytes')
|
||||
|
||||
# Now send back our response.
|
||||
# Payload consists of type (1b), msgid (2b), len (2b), and data.
|
||||
if len(response) > 65535:
|
||||
self._enqueue_outgoing_packet(
|
||||
_PacketType.RESPONSE_BIG.value.to_bytes(1, _BYTE_ORDER)
|
||||
+ message_id.to_bytes(2, _BYTE_ORDER)
|
||||
+ len(response).to_bytes(4, _BYTE_ORDER)
|
||||
+ response
|
||||
)
|
||||
else:
|
||||
self._enqueue_outgoing_packet(
|
||||
_PacketType.RESPONSE.value.to_bytes(1, _BYTE_ORDER)
|
||||
+ message_id.to_bytes(2, _BYTE_ORDER)
|
||||
+ len(response).to_bytes(2, _BYTE_ORDER)
|
||||
+ response
|
||||
)
|
||||
|
||||
async def _read_int_8(self) -> int:
|
||||
out = int.from_bytes(await self._reader.readexactly(1), _BYTE_ORDER)
|
||||
self._total_bytes_read += 1
|
||||
return out
|
||||
|
||||
async def _read_int_16(self) -> int:
|
||||
out = int.from_bytes(await self._reader.readexactly(2), _BYTE_ORDER)
|
||||
self._total_bytes_read += 2
|
||||
return out
|
||||
|
||||
async def _read_int_32(self) -> int:
|
||||
out = int.from_bytes(await self._reader.readexactly(4), _BYTE_ORDER)
|
||||
self._total_bytes_read += 4
|
||||
return out
|
||||
|
||||
@classmethod
|
||||
def _is_expected_connection_error(cls, exc: Exception) -> bool:
|
||||
"""Stuff we expect to end our connection in normal circumstances."""
|
||||
|
||||
if isinstance(exc, _KeepaliveTimeoutError):
|
||||
return True
|
||||
|
||||
return is_asyncio_streams_communication_error(exc)
|
||||
|
||||
def _check_env(self) -> None:
|
||||
# I was seeing that asyncio stuff wasn't working as expected if
|
||||
# created in one thread and used in another (and have verified
|
||||
# that this is part of the design), so let's enforce a single
|
||||
# thread for all use of an instance.
|
||||
if current_thread() is not self._thread:
|
||||
raise RuntimeError(
|
||||
'This must be called from the same thread'
|
||||
' that the endpoint was created in.'
|
||||
)
|
||||
|
||||
# This should always be the case if thread is the same.
|
||||
assert asyncio.get_running_loop() is self._event_loop
|
||||
|
||||
def _enqueue_outgoing_packet(self, data: bytes) -> None:
|
||||
"""Enqueue a raw packet to be sent. Must be called from our loop."""
|
||||
self._check_env()
|
||||
|
||||
if self.debug_print_io:
|
||||
self.debug_print_call(
|
||||
f'{self._label}: enqueueing outgoing packet'
|
||||
f' {data[:50]!r} at {self._tm()}.'
|
||||
)
|
||||
|
||||
# Add the data and let our write task know about it.
|
||||
self._out_packets.append(data)
|
||||
self._have_out_packets.set()
|
||||
|
||||
def _prune_tasks(self) -> None:
|
||||
self._tasks = self._get_live_tasks()
|
||||
|
||||
def _get_live_tasks(self) -> list[asyncio.Task]:
|
||||
return [t for t in self._tasks if not t.done()]
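A minimal wiring sketch for RPCEndpoint over asyncio streams (the
host/port values and the echo handler are hypothetical; production code
would add error handling):

import asyncio
from efro.rpc import RPCEndpoint

async def _handle_raw_message(message: bytes) -> bytes:
    return message  # Echo the payload back as the response.

async def run_client() -> None:
    reader, writer = await asyncio.open_connection('localhost', 43210)
    endpoint = RPCEndpoint(
        _handle_raw_message, reader, writer, label='demo-client'
    )
    run_task = asyncio.create_task(endpoint.run())

    # send_message() enqueues synchronously and returns an awaitable
    # that resolves to the peer's response bytes.
    response = await endpoint.send_message(b'ping')
    print(response)

    endpoint.close()
    await endpoint.wait_closed()
    await run_task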
321
dist/ba_data/python/efro/terminal.py
vendored
Normal file
@ -0,0 +1,321 @@
|
|||
# Released under the MIT License. See LICENSE for details.
|
||||
#
|
||||
"""Functionality related to terminal IO."""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import os
|
||||
from enum import Enum, unique
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, ClassVar
|
||||
|
||||
|
||||
@unique
|
||||
class TerminalColor(Enum):
|
||||
"""Color codes for printing to terminals.
|
||||
|
||||
Generally the Clr class should be used when incorporating color into
|
||||
terminal output, as it handles non-color-supporting terminals/etc.
|
||||
"""
|
||||
|
||||
# Styles
|
||||
RESET = '\033[0m'
|
||||
BOLD = '\033[1m'
|
||||
UNDERLINE = '\033[4m'
|
||||
INVERSE = '\033[7m'
|
||||
|
||||
# Normal foreground colors
|
||||
BLACK = '\033[30m'
|
||||
RED = '\033[31m'
|
||||
GREEN = '\033[32m'
|
||||
YELLOW = '\033[33m'
|
||||
BLUE = '\033[34m'
|
||||
MAGENTA = '\033[35m'
|
||||
CYAN = '\033[36m'
|
||||
WHITE = '\033[37m'
|
||||
|
||||
# Normal background colors.
|
||||
BG_BLACK = '\033[40m'
|
||||
BG_RED = '\033[41m'
|
||||
BG_GREEN = '\033[42m'
|
||||
BG_YELLOW = '\033[43m'
|
||||
BG_BLUE = '\033[44m'
|
||||
BG_MAGENTA = '\033[45m'
|
||||
BG_CYAN = '\033[46m'
|
||||
BG_WHITE = '\033[47m'
|
||||
|
||||
# Strong foreground colors
|
||||
STRONG_BLACK = '\033[90m'
|
||||
STRONG_RED = '\033[91m'
|
||||
STRONG_GREEN = '\033[92m'
|
||||
STRONG_YELLOW = '\033[93m'
|
||||
STRONG_BLUE = '\033[94m'
|
||||
STRONG_MAGENTA = '\033[95m'
|
||||
STRONG_CYAN = '\033[96m'
|
||||
STRONG_WHITE = '\033[97m'
|
||||
|
||||
# Strong background colors.
|
||||
STRONG_BG_BLACK = '\033[100m'
|
||||
STRONG_BG_RED = '\033[101m'
|
||||
STRONG_BG_GREEN = '\033[102m'
|
||||
STRONG_BG_YELLOW = '\033[103m'
|
||||
STRONG_BG_BLUE = '\033[104m'
|
||||
STRONG_BG_MAGENTA = '\033[105m'
|
||||
STRONG_BG_CYAN = '\033[106m'
|
||||
STRONG_BG_WHITE = '\033[107m'
|
||||
|
||||
|
||||
def _default_color_enabled() -> bool:
|
||||
"""Return whether we should enable ANSI color codes by default."""
|
||||
import platform
|
||||
|
||||
# If we're not attached to a terminal, go with no-color.
|
||||
if not sys.__stdout__.isatty():
|
||||
return False
|
||||
|
||||
# Another common way to say the terminal can't do fancy stuff like color:
|
||||
if os.environ.get('TERM') == 'dumb':
|
||||
return False
|
||||
|
||||
# On windows, try to enable ANSI color mode.
|
||||
if platform.system() == 'Windows':
|
||||
return _windows_enable_color()
|
||||
|
||||
# We seem to be a terminal with color support; let's do it!
|
||||
return True
|
||||
|
||||
|
||||
# noinspection PyPep8Naming
|
||||
def _windows_enable_color() -> bool:
|
||||
"""Attempt to enable ANSI color on windows terminal; return success."""
|
||||
# pylint: disable=invalid-name, import-error, undefined-variable
|
||||
# Pulled from: https://bugs.python.org/issue30075
|
||||
import msvcrt
|
||||
import ctypes
|
||||
from ctypes import wintypes
|
||||
|
||||
kernel32 = ctypes.WinDLL('kernel32', use_last_error=True) # type: ignore
|
||||
|
||||
ERROR_INVALID_PARAMETER = 0x0057
|
||||
ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x0004
|
||||
|
||||
def _check_bool(result: Any, _func: Any, args: Any) -> Any:
|
||||
if not result:
|
||||
raise ctypes.WinError(ctypes.get_last_error()) # type: ignore
|
||||
return args
|
||||
|
||||
LPDWORD = ctypes.POINTER(wintypes.DWORD)
|
||||
kernel32.GetConsoleMode.errcheck = _check_bool
|
||||
kernel32.GetConsoleMode.argtypes = (wintypes.HANDLE, LPDWORD)
|
||||
kernel32.SetConsoleMode.errcheck = _check_bool
|
||||
kernel32.SetConsoleMode.argtypes = (wintypes.HANDLE, wintypes.DWORD)
|
||||
|
||||
def set_conout_mode(new_mode: int, mask: int = 0xFFFFFFFF) -> int:
|
||||
# don't assume StandardOutput is a console.
|
||||
# open CONOUT$ instead
|
||||
fdout = os.open('CONOUT$', os.O_RDWR)
|
||||
try:
|
||||
hout = msvcrt.get_osfhandle(fdout) # type: ignore
|
||||
# pylint: disable=useless-suppression
|
||||
# pylint: disable=no-value-for-parameter
|
||||
old_mode = wintypes.DWORD()
|
||||
# pylint: enable=useless-suppression
|
||||
kernel32.GetConsoleMode(hout, ctypes.byref(old_mode))
|
||||
mode = (new_mode & mask) | (old_mode.value & ~mask)
|
||||
kernel32.SetConsoleMode(hout, mode)
|
||||
return old_mode.value
|
||||
finally:
|
||||
os.close(fdout)
|
||||
|
||||
def enable_vt_mode() -> int:
|
||||
mode = mask = ENABLE_VIRTUAL_TERMINAL_PROCESSING
|
||||
try:
|
||||
return set_conout_mode(mode, mask)
|
||||
except WindowsError as exc: # type: ignore
|
||||
if exc.winerror == ERROR_INVALID_PARAMETER:
|
||||
raise NotImplementedError from exc
|
||||
raise
|
||||
|
||||
try:
|
||||
enable_vt_mode()
|
||||
return True
|
||||
except NotImplementedError:
|
||||
return False
|
||||
|
||||
|
||||
class ClrBase:
|
||||
"""Base class for color convenience class."""
|
||||
|
||||
RST: ClassVar[str]
|
||||
BLD: ClassVar[str]
|
||||
UND: ClassVar[str]
|
||||
INV: ClassVar[str]
|
||||
|
||||
# Normal foreground colors
|
||||
BLK: ClassVar[str]
|
||||
RED: ClassVar[str]
|
||||
GRN: ClassVar[str]
|
||||
YLW: ClassVar[str]
|
||||
BLU: ClassVar[str]
|
||||
MAG: ClassVar[str]
|
||||
CYN: ClassVar[str]
|
||||
WHT: ClassVar[str]
|
||||
|
||||
# Normal background colors.
|
||||
BBLK: ClassVar[str]
|
||||
BRED: ClassVar[str]
|
||||
BGRN: ClassVar[str]
|
||||
BYLW: ClassVar[str]
|
||||
BBLU: ClassVar[str]
|
||||
BMAG: ClassVar[str]
|
||||
BCYN: ClassVar[str]
|
||||
BWHT: ClassVar[str]
|
||||
|
||||
# Strong foreground colors
|
||||
SBLK: ClassVar[str]
|
||||
SRED: ClassVar[str]
|
||||
SGRN: ClassVar[str]
|
||||
SYLW: ClassVar[str]
|
||||
SBLU: ClassVar[str]
|
||||
SMAG: ClassVar[str]
|
||||
SCYN: ClassVar[str]
|
||||
SWHT: ClassVar[str]
|
||||
|
||||
# Strong background colors.
|
||||
SBBLK: ClassVar[str]
|
||||
SBRED: ClassVar[str]
|
||||
SBGRN: ClassVar[str]
|
||||
SBYLW: ClassVar[str]
|
||||
SBBLU: ClassVar[str]
|
||||
SBMAG: ClassVar[str]
|
||||
SBCYN: ClassVar[str]
|
||||
SBWHT: ClassVar[str]
|
||||
|
||||
|
||||
class ClrAlways(ClrBase):
|
||||
"""Convenience class for color terminal output.
|
||||
|
||||
This version has colors always enabled. Generally you should use Clr which
|
||||
points to the correct enabled/disabled class depending on the environment.
|
||||
"""
|
||||
|
||||
color_enabled = True
|
||||
|
||||
# Styles
|
||||
RST = TerminalColor.RESET.value
|
||||
BLD = TerminalColor.BOLD.value
|
||||
UND = TerminalColor.UNDERLINE.value
|
||||
INV = TerminalColor.INVERSE.value
|
||||
|
||||
# Normal foreground colors
|
||||
BLK = TerminalColor.BLACK.value
|
||||
RED = TerminalColor.RED.value
|
||||
GRN = TerminalColor.GREEN.value
|
||||
YLW = TerminalColor.YELLOW.value
|
||||
BLU = TerminalColor.BLUE.value
|
||||
MAG = TerminalColor.MAGENTA.value
|
||||
CYN = TerminalColor.CYAN.value
|
||||
WHT = TerminalColor.WHITE.value
|
||||
|
||||
# Normal background colors.
|
||||
BBLK = TerminalColor.BG_BLACK.value
|
||||
BRED = TerminalColor.BG_RED.value
|
||||
BGRN = TerminalColor.BG_GREEN.value
|
||||
BYLW = TerminalColor.BG_YELLOW.value
|
||||
BBLU = TerminalColor.BG_BLUE.value
|
||||
BMAG = TerminalColor.BG_MAGENTA.value
|
||||
BCYN = TerminalColor.BG_CYAN.value
|
||||
BWHT = TerminalColor.BG_WHITE.value
|
||||
|
||||
# Strong foreground colors
|
||||
SBLK = TerminalColor.STRONG_BLACK.value
|
||||
SRED = TerminalColor.STRONG_RED.value
|
||||
SGRN = TerminalColor.STRONG_GREEN.value
|
||||
SYLW = TerminalColor.STRONG_YELLOW.value
|
||||
SBLU = TerminalColor.STRONG_BLUE.value
|
||||
SMAG = TerminalColor.STRONG_MAGENTA.value
|
||||
SCYN = TerminalColor.STRONG_CYAN.value
|
||||
SWHT = TerminalColor.STRONG_WHITE.value
|
||||
|
||||
# Strong background colors.
|
||||
SBBLK = TerminalColor.STRONG_BG_BLACK.value
|
||||
SBRED = TerminalColor.STRONG_BG_RED.value
|
||||
SBGRN = TerminalColor.STRONG_BG_GREEN.value
|
||||
SBYLW = TerminalColor.STRONG_BG_YELLOW.value
|
||||
SBBLU = TerminalColor.STRONG_BG_BLUE.value
|
||||
SBMAG = TerminalColor.STRONG_BG_MAGENTA.value
|
||||
SBCYN = TerminalColor.STRONG_BG_CYAN.value
|
||||
SBWHT = TerminalColor.STRONG_BG_WHITE.value
|
||||
|
||||
|
||||
class ClrNever(ClrBase):
|
||||
"""Convenience class for color terminal output.
|
||||
|
||||
This version has colors disabled. Generally you should use Clr which
|
||||
points to the correct enabled/disabled class depending on the environment.
|
||||
"""
|
||||
|
||||
color_enabled = False
|
||||
|
||||
# Styles
|
||||
RST = ''
|
||||
BLD = ''
|
||||
UND = ''
|
||||
INV = ''
|
||||
|
||||
# Normal foreground colors
|
||||
BLK = ''
|
||||
RED = ''
|
||||
GRN = ''
|
||||
YLW = ''
|
||||
BLU = ''
|
||||
MAG = ''
|
||||
CYN = ''
|
||||
WHT = ''
|
||||
|
||||
# Normal background colors.
|
||||
BBLK = ''
|
||||
BRED = ''
|
||||
BGRN = ''
|
||||
BYLW = ''
|
||||
BBLU = ''
|
||||
BMAG = ''
|
||||
BCYN = ''
|
||||
BWHT = ''
|
||||
|
||||
# Strong foreground colors
|
||||
SBLK = ''
|
||||
SRED = ''
|
||||
SGRN = ''
|
||||
SYLW = ''
|
||||
SBLU = ''
|
||||
SMAG = ''
|
||||
SCYN = ''
|
||||
SWHT = ''
|
||||
|
||||
# Strong background colors.
|
||||
SBBLK = ''
|
||||
SBRED = ''
|
||||
SBGRN = ''
|
||||
SBYLW = ''
|
||||
SBBLU = ''
|
||||
SBMAG = ''
|
||||
SBCYN = ''
|
||||
SBWHT = ''
|
||||
|
||||
|
||||
_envval = os.environ.get('EFRO_TERMCOLORS')
|
||||
_color_enabled: bool = (
|
||||
True
|
||||
if _envval == '1'
|
||||
else False
|
||||
if _envval == '0'
|
||||
else _default_color_enabled()
|
||||
)
|
||||
Clr: type[ClrBase]
|
||||
if _color_enabled:
|
||||
Clr = ClrAlways
|
||||
else:
|
||||
Clr = ClrNever
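A minimal usage sketch for the Clr alias defined above (the message text
is arbitrary; since Clr resolves to ClrAlways or ClrNever at import
time, the same format strings work on color-capable and colorless
terminals alike):

from efro.terminal import Clr

print(f'{Clr.BLD}{Clr.RED}error:{Clr.RST} something went wrong')
print(f'{Clr.GRN}ok:{Clr.RST} all checks passed')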
735
dist/ba_data/python/efro/util.py
vendored
Normal file
@ -0,0 +1,735 @@
|
|||
# Released under the MIT License. See LICENSE for details.
|
||||
#
|
||||
"""Small handy bits of functionality."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
import weakref
|
||||
import datetime
|
||||
import functools
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, cast, TypeVar, Generic
|
||||
|
||||
_pytz_utc: Any
|
||||
|
||||
# We don't *require* pytz, but we want to support it for tzinfos if available.
|
||||
try:
|
||||
import pytz
|
||||
|
||||
_pytz_utc = pytz.utc
|
||||
except ModuleNotFoundError:
|
||||
_pytz_utc = None # pylint: disable=invalid-name
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import asyncio
|
||||
from efro.call import Call as Call # 'as Call' so we re-export.
|
||||
from typing import Any, Callable, NoReturn
|
||||
|
||||
T = TypeVar('T')
|
||||
ValT = TypeVar('ValT')
|
||||
ArgT = TypeVar('ArgT')
|
||||
SelfT = TypeVar('SelfT')
|
||||
RetT = TypeVar('RetT')
|
||||
EnumT = TypeVar('EnumT', bound=Enum)
|
||||
|
||||
|
||||
class _EmptyObj:
|
||||
pass
|
||||
|
||||
|
||||
# TODO: kill this and just use efro.call.tpartial
|
||||
if TYPE_CHECKING:
|
||||
Call = Call
|
||||
else:
|
||||
Call = functools.partial
|
||||
|
||||
|
||||
def enum_by_value(cls: type[EnumT], value: Any) -> EnumT:
|
||||
"""Create an enum from a value.
|
||||
|
||||
This is basically the same as doing 'obj = EnumType(value)' except
|
||||
that it works around an issue where a reference loop is created
|
||||
if an exception is thrown due to an invalid value. Since we disable
|
||||
the cyclic garbage collector for most of the time, such loops can lead
|
||||
to our objects sticking around longer than we want.
|
||||
This issue has been submitted to Python as a bug so hopefully we can
|
||||
remove this eventually if it gets fixed: https://bugs.python.org/issue42248
|
||||
UPDATE: This has been fixed as of later 3.8 builds, so we can kill this
|
||||
off once we are 3.9+ across the board.
|
||||
"""
|
||||
|
||||
# Note: we don't recreate *ALL* the functionality of the Enum constructor
|
||||
# such as the _missing_ hook; but this should cover our basic needs.
|
||||
value2member_map = getattr(cls, '_value2member_map_')
|
||||
assert value2member_map is not None
|
||||
try:
|
||||
out = value2member_map[value]
|
||||
assert isinstance(out, cls)
|
||||
return out
|
||||
except KeyError:
|
||||
# pylint: disable=consider-using-f-string
|
||||
raise ValueError(
|
||||
'%r is not a valid %s' % (value, cls.__name__)
|
||||
) from None
|
||||
|
||||
|
||||
def check_utc(value: datetime.datetime) -> None:
|
||||
"""Ensure a datetime value is timezone-aware utc."""
|
||||
if value.tzinfo is not datetime.timezone.utc and (
|
||||
_pytz_utc is None or value.tzinfo is not _pytz_utc
|
||||
):
|
||||
raise ValueError(
|
||||
'datetime value does not have timezone set as'
|
||||
' datetime.timezone.utc'
|
||||
)
|
||||
|
||||
|
||||
def utc_now() -> datetime.datetime:
|
||||
"""Get offset-aware current utc time.
|
||||
|
||||
This should be used for all datetimes getting sent over the network,
|
||||
used with the entity system, etc.
|
||||
(datetime.utcnow() gives a utc time value, but it is not timezone-aware
|
||||
which makes it less safe to use)
|
||||
"""
|
||||
return datetime.datetime.now(datetime.timezone.utc)
|
||||
|
||||
|
||||
def utc_today() -> datetime.datetime:
|
||||
"""Get offset-aware midnight in the utc time zone."""
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
return datetime.datetime(
|
||||
year=now.year, month=now.month, day=now.day, tzinfo=now.tzinfo
|
||||
)
|
||||
|
||||
|
||||
def utc_this_hour() -> datetime.datetime:
|
||||
"""Get offset-aware beginning of the current hour in the utc time zone."""
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
return datetime.datetime(
|
||||
year=now.year,
|
||||
month=now.month,
|
||||
day=now.day,
|
||||
hour=now.hour,
|
||||
tzinfo=now.tzinfo,
|
||||
)
|
||||
|
||||
|
||||
def utc_this_minute() -> datetime.datetime:
|
||||
"""Get offset-aware beginning of current minute in the utc time zone."""
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
return datetime.datetime(
|
||||
year=now.year,
|
||||
month=now.month,
|
||||
day=now.day,
|
||||
hour=now.hour,
|
||||
minute=now.minute,
|
||||
tzinfo=now.tzinfo,
|
||||
)
|
||||
|
||||
|
||||
def empty_weakref(objtype: type[T]) -> weakref.ref[T]:
|
||||
"""Return an invalidated weak-reference for the specified type."""
|
||||
# At runtime, all weakrefs are the same; our type arg is just
|
||||
# for the static type checker.
|
||||
del objtype # Unused.
|
||||
# Just create an object and let it die. Is there a cleaner way to do this?
|
||||
return weakref.ref(_EmptyObj()) # type: ignore
|
||||
|
||||
|
||||
def data_size_str(bytecount: int) -> str:
|
||||
"""Given a size in bytes, returns a short human readable string.
|
||||
|
||||
This should be 6 or fewer chars for most all sane file sizes.
|
||||
"""
|
||||
# pylint: disable=too-many-return-statements
|
||||
if bytecount <= 999:
|
||||
return f'{bytecount} B'
|
||||
kbytecount = bytecount / 1024
|
||||
if round(kbytecount, 1) < 10.0:
|
||||
return f'{kbytecount:.1f} KB'
|
||||
if round(kbytecount, 0) < 999:
|
||||
return f'{kbytecount:.0f} KB'
|
||||
mbytecount = bytecount / (1024 * 1024)
|
||||
if round(mbytecount, 1) < 10.0:
|
||||
return f'{mbytecount:.1f} MB'
|
||||
if round(mbytecount, 0) < 999:
|
||||
return f'{mbytecount:.0f} MB'
|
||||
gbytecount = bytecount / (1024 * 1024 * 1024)
|
||||
if round(gbytecount, 1) < 10.0:
|
||||
return f'{gbytecount:.1f} GB'
|
||||
return f'{gbytecount:.0f} GB'
|
||||
|
||||
|
||||
class DirtyBit:
|
||||
"""Manages whether a thing is dirty and regulates attempts to clean it.
|
||||
|
||||
To use, simply set the 'dirty' value on this object to True when some
|
||||
action is needed, and then check the 'should_update' value to regulate
|
||||
when attempts to clean it should be made. Set 'dirty' back to False after
|
||||
a successful update.
|
||||
If 'use_lock' is True, an asyncio Lock will be created and incorporated
|
||||
into update attempts to prevent simultaneous updates (should_update will
|
||||
only return True when the lock is unlocked). Note that it is up to the user
|
||||
to lock/unlock the lock during the actual update attempt.
|
||||
If a value is passed for 'auto_dirty_seconds', the dirtybit will flip
|
||||
itself back to dirty after being clean for the given amount of time.
|
||||
'min_update_interval' can be used to enforce a minimum update
|
||||
interval even when updates are successful (retry_interval only applies
|
||||
when updates fail)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dirty: bool = False,
|
||||
retry_interval: float = 5.0,
|
||||
use_lock: bool = False,
|
||||
auto_dirty_seconds: float | None = None,
|
||||
min_update_interval: float | None = None,
|
||||
):
|
||||
curtime = time.time()
|
||||
self._retry_interval = retry_interval
|
||||
self._auto_dirty_seconds = auto_dirty_seconds
|
||||
self._min_update_interval = min_update_interval
|
||||
self._dirty = dirty
|
||||
self._next_update_time: float | None = curtime if dirty else None
|
||||
self._last_update_time: float | None = None
|
||||
self._next_auto_dirty_time: float | None = (
|
||||
(curtime + self._auto_dirty_seconds)
|
||||
if (not dirty and self._auto_dirty_seconds is not None)
|
||||
else None
|
||||
)
|
||||
self._use_lock = use_lock
|
||||
self.lock: asyncio.Lock
|
||||
if self._use_lock:
|
||||
import asyncio
|
||||
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def dirty(self) -> bool:
|
||||
"""Whether the target is currently dirty.
|
||||
|
||||
This should be set to False once an update is successful.
|
||||
"""
|
||||
return self._dirty
|
||||
|
||||
@dirty.setter
|
||||
def dirty(self, value: bool) -> None:
|
||||
|
||||
# If we're freshly clean, set our next auto-dirty time (if we have
|
||||
# one).
|
||||
if self._dirty and not value and self._auto_dirty_seconds is not None:
|
||||
self._next_auto_dirty_time = time.time() + self._auto_dirty_seconds
|
||||
|
||||
# If we're freshly dirty, schedule an immediate update.
|
||||
if not self._dirty and value:
|
||||
self._next_update_time = time.time()
|
||||
|
||||
# If they want to enforce a minimum update interval,
|
||||
# push out the next update time if it hasn't been long enough.
|
||||
if (
|
||||
self._min_update_interval is not None
|
||||
and self._last_update_time is not None
|
||||
):
|
||||
self._next_update_time = max(
|
||||
self._next_update_time,
|
||||
self._last_update_time + self._min_update_interval,
|
||||
)
|
||||
|
||||
self._dirty = value
|
||||
|
||||
@property
|
||||
def should_update(self) -> bool:
|
||||
"""Whether an attempt should be made to clean the target now.
|
||||
|
||||
Always returns False if the target is not dirty.
|
||||
Takes into account the amount of time passed since the target
|
||||
was marked dirty or since should_update last returned True.
|
||||
"""
|
||||
curtime = time.time()
|
||||
|
||||
# Auto-dirty ourself if we're into that.
|
||||
if (
|
||||
self._next_auto_dirty_time is not None
|
||||
and curtime > self._next_auto_dirty_time
|
||||
):
|
||||
self.dirty = True
|
||||
self._next_auto_dirty_time = None
|
||||
if not self._dirty:
|
||||
return False
|
||||
if self._use_lock and self.lock.locked():
|
||||
return False
|
||||
assert self._next_update_time is not None
|
||||
if curtime > self._next_update_time:
|
||||
self._next_update_time = curtime + self._retry_interval
|
||||
self._last_update_time = curtime
|
||||
return True
|
||||
return False
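# A minimal usage sketch (not part of the original file; the
# _hypothetical_sync() call stands in for whatever expensive update
# the bit is regulating):
def _example_dirtybit_loop() -> None:
    state = DirtyBit(dirty=True, retry_interval=5.0)
    while state.dirty:
        if state.should_update:
            try:
                _hypothetical_sync()  # Hypothetical update attempt.
                state.dirty = False  # Mark clean on success.
            except OSError:
                # Stay dirty; should_update fires again after
                # retry_interval seconds.
                pass
        time.sleep(0.1)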
|
||||
|
||||
|
||||
class DispatchMethodWrapper(Generic[ArgT, RetT]):
|
||||
"""Type-aware standin for the dispatch func returned by dispatchmethod."""
|
||||
|
||||
def __call__(self, arg: ArgT) -> RetT:
|
||||
raise RuntimeError('Should not get here')
|
||||
|
||||
@staticmethod
|
||||
def register(
|
||||
func: Callable[[Any, Any], RetT]
|
||||
) -> Callable[[Any, Any], RetT]:
|
||||
"""Register a new dispatch handler for this dispatch-method."""
|
||||
raise RuntimeError('Should not get here')
|
||||
|
||||
registry: dict[Any, Callable]
|
||||
|
||||
|
||||
# noinspection PyProtectedMember,PyTypeHints
|
||||
def dispatchmethod(
|
||||
func: Callable[[Any, ArgT], RetT]
|
||||
) -> DispatchMethodWrapper[ArgT, RetT]:
|
||||
"""A variation of functools.singledispatch for methods.
|
||||
|
||||
Note: as of Python 3.9 there is now functools.singledispatchmethod,
|
||||
but it currently (as of Jan 2021) is not type-aware (at least in mypy),
|
||||
which gives us a reason to keep this one around for now.
|
||||
"""
|
||||
from functools import singledispatch, update_wrapper
|
||||
|
||||
origwrapper: Any = singledispatch(func)
|
||||
|
||||
# Pull this out so hopefully origwrapper can die,
|
||||
# otherwise we reference origwrapper in our wrapper.
|
||||
dispatch = origwrapper.dispatch
|
||||
|
||||
# All we do here is recreate the end of functools.singledispatch
|
||||
# where it returns a wrapper except instead of the wrapper using the
|
||||
# first arg to the function ours uses the second (to skip 'self').
|
||||
# This was made against Python 3.7; we should probably check up on
|
||||
# this in later versions in case anything has changed.
|
||||
# (or hopefully they'll add this functionality to their version)
|
||||
# NOTE: sounds like we can use functools singledispatchmethod in 3.8
|
||||
def wrapper(*args: Any, **kw: Any) -> Any:
|
||||
if not args or len(args) < 2:
|
||||
raise TypeError(
|
||||
f'{funcname} requires at least 2 positional arguments'
|
||||
)
|
||||
|
||||
return dispatch(args[1].__class__)(*args, **kw)
|
||||
|
||||
funcname = getattr(func, '__name__', 'dispatchmethod method')
|
||||
wrapper.register = origwrapper.register # type: ignore
|
||||
wrapper.dispatch = dispatch # type: ignore
|
||||
wrapper.registry = origwrapper.registry # type: ignore
|
||||
# pylint: disable=protected-access
|
||||
wrapper._clear_cache = origwrapper._clear_cache # type: ignore
|
||||
update_wrapper(wrapper, func)
|
||||
# pylint: enable=protected-access
|
||||
return cast(DispatchMethodWrapper, wrapper)
|
||||
|
||||
|
||||
def valuedispatch(call: Callable[[ValT], RetT]) -> ValueDispatcher[ValT, RetT]:
|
||||
"""Decorator for functions to allow dispatching based on a value.
|
||||
|
||||
This differs from functools.singledispatch in that it dispatches based
|
||||
on the value of an argument, not based on its type.
|
||||
The 'register' method of a value-dispatch function can be used
|
||||
to assign new functions to handle particular values.
|
||||
Unhandled values wind up in the original dispatch function."""
|
||||
return ValueDispatcher(call)
|
||||
|
||||
|
||||
class ValueDispatcher(Generic[ValT, RetT]):
|
||||
"""Used by the valuedispatch decorator"""
|
||||
|
||||
def __init__(self, call: Callable[[ValT], RetT]) -> None:
|
||||
self._base_call = call
|
||||
self._handlers: dict[ValT, Callable[[], RetT]] = {}
|
||||
|
||||
def __call__(self, value: ValT) -> RetT:
|
||||
handler = self._handlers.get(value)
|
||||
if handler is not None:
|
||||
return handler()
|
||||
return self._base_call(value)
|
||||
|
||||
def _add_handler(
|
||||
self, value: ValT, call: Callable[[], RetT]
|
||||
) -> Callable[[], RetT]:
|
||||
if value in self._handlers:
|
||||
raise RuntimeError(f'Duplicate handlers added for {value}')
|
||||
self._handlers[value] = call
|
||||
return call
|
||||
|
||||
def register(
|
||||
self, value: ValT
|
||||
) -> Callable[[Callable[[], RetT]], Callable[[], RetT]]:
|
||||
"""Add a handler to the dispatcher."""
|
||||
from functools import partial
|
||||
|
||||
return partial(self._add_handler, value)
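# A minimal usage sketch (not part of the original file; the Color
# enum and handler bodies are hypothetical):
def _example_valuedispatch() -> None:
    class Color(Enum):
        RED = 0
        GREEN = 1

    @valuedispatch
    def describe(value: Color) -> str:
        del value  # Unused in the fallback.
        return 'some color'  # Fallback for unregistered values.

    @describe.register(Color.RED)
    def _describe_red() -> str:
        return 'warm'  # Registered handlers take no arguments.

    assert describe(Color.RED) == 'warm'
    assert describe(Color.GREEN) == 'some color'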
|
||||
|
||||
|
||||
def valuedispatch1arg(
|
||||
call: Callable[[ValT, ArgT], RetT]
|
||||
) -> ValueDispatcher1Arg[ValT, ArgT, RetT]:
|
||||
"""Like valuedispatch but for functions taking an extra argument."""
|
||||
return ValueDispatcher1Arg(call)
|
||||
|
||||
|
||||
class ValueDispatcher1Arg(Generic[ValT, ArgT, RetT]):
|
||||
"""Used by the valuedispatch1arg decorator"""
|
||||
|
||||
def __init__(self, call: Callable[[ValT, ArgT], RetT]) -> None:
|
||||
self._base_call = call
|
||||
self._handlers: dict[ValT, Callable[[ArgT], RetT]] = {}
|
||||
|
||||
def __call__(self, value: ValT, arg: ArgT) -> RetT:
|
||||
handler = self._handlers.get(value)
|
||||
if handler is not None:
|
||||
return handler(arg)
|
||||
return self._base_call(value, arg)
|
||||
|
||||
def _add_handler(
|
||||
self, value: ValT, call: Callable[[ArgT], RetT]
|
||||
) -> Callable[[ArgT], RetT]:
|
||||
if value in self._handlers:
|
||||
raise RuntimeError(f'Duplicate handlers added for {value}')
|
||||
self._handlers[value] = call
|
||||
return call
|
||||
|
||||
def register(
|
||||
self, value: ValT
|
||||
) -> Callable[[Callable[[ArgT], RetT]], Callable[[ArgT], RetT]]:
|
||||
"""Add a handler to the dispatcher."""
|
||||
from functools import partial
|
||||
|
||||
return partial(self._add_handler, value)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class ValueDispatcherMethod(Generic[ValT, RetT]):
|
||||
"""Used by the valuedispatchmethod decorator."""
|
||||
|
||||
def __call__(self, value: ValT) -> RetT:
|
||||
...
|
||||
|
||||
def register(
|
||||
self, value: ValT
|
||||
) -> Callable[[Callable[[SelfT], RetT]], Callable[[SelfT], RetT]]:
|
||||
"""Add a handler to the dispatcher."""
|
||||
...
|
||||
|
||||
|
||||
def valuedispatchmethod(
    call: Callable[[SelfT, ValT], RetT]
) -> ValueDispatcherMethod[ValT, RetT]:
    """Like valuedispatch but works with methods instead of functions."""

    # NOTE: It seems that to wrap a method with a decorator and have self
    # dispatching do the right thing, we must return a function and not
    # an executable object. So for this version we store our data here
    # in the function call dict and simply return a call.

    _base_call = call
    _handlers: dict[ValT, Callable[[SelfT], RetT]] = {}

    def _add_handler(value: ValT, addcall: Callable[[SelfT], RetT]) -> None:
        if value in _handlers:
            raise RuntimeError(f'Duplicate handlers added for {value}')
        _handlers[value] = addcall

    def _register(value: ValT) -> Callable[[Callable[[SelfT], RetT]], None]:
        from functools import partial

        return partial(_add_handler, value)

    def _call_wrapper(self: SelfT, value: ValT) -> RetT:
        handler = _handlers.get(value)
        if handler is not None:
            return handler(self)
        return _base_call(self, value)

    # We still want to use our returned object to register handlers, but we're
    # actually just returning a function. So manually stuff the call onto it.
    setattr(_call_wrapper, 'register', _register)

    # To the type checker's eyes we return a ValueDispatcherMethod instance;
    # this lets it know about our register func and type-check its usage.
    # In reality we just return a raw function (for the reasons listed above).
    # pylint: disable=undefined-variable, no-else-return
    if TYPE_CHECKING:
        return ValueDispatcherMethod[ValT, RetT]()
    else:
        return _call_wrapper


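# A minimal usage sketch for valuedispatchmethod (illustrative comment only;
# the 'Widget' class is hypothetical, again reusing the Color enum from the
# sketches above). Registered handlers receive only self, and register()
# here returns None, so the handler's name in the class body is not meant
# to be called directly.
#
#   class Widget:
#       @valuedispatchmethod
#       def label(self, color: Color) -> str:
#           return 'plain widget'  # Fallback for unregistered values.
#
#       @label.register(Color.RED)
#       def _label_red(self) -> str:
#           return 'red widget'
#
#   Widget().label(Color.RED)    # -> 'red widget'
#   Widget().label(Color.GREEN)  # -> 'plain widget'

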
def make_hash(obj: Any) -> int:
    """Makes a hash from a dictionary, list, tuple, or set, recursing to
    any depth, as long as it contains only other hashable types (including
    any nested lists, tuples, sets, and dictionaries).

    Note that this uses Python's hash() function internally so collisions/etc.
    may be more common than with fancy cryptographic hashes.

    Also be aware that Python's hash() output varies across processes, so
    this should only be used for values that will remain in a single process.
    """
    import copy

    if isinstance(obj, (set, tuple, list)):
        return hash(tuple(make_hash(e) for e in obj))
    if not isinstance(obj, dict):
        return hash(obj)

    new_obj = copy.deepcopy(obj)
    for k, v in new_obj.items():
        new_obj[k] = make_hash(v)

    # NOTE: sorted() works correctly here because it compares only
    # unique first values (i.e. dict keys).
    return hash(tuple(frozenset(sorted(new_obj.items()))))


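# A quick illustration (comment only): equal nested structures give equal
# hashes within a single process, regardless of dict insertion order.
#
#   a = {'x': [1, 2, {3, 4}], 'y': ('z', {'k': 'v'})}
#   b = {'y': ('z', {'k': 'v'}), 'x': [1, 2, {3, 4}]}
#   assert make_hash(a) == make_hash(b)

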
def asserttype(obj: Any, typ: type[T]) -> T:
    """Return an object typed as a given type.

    Assert is used to check its actual type, so only use this when
    failures are not expected. Otherwise use checktype.
    """
    assert isinstance(typ, type), 'only actual types accepted'
    assert isinstance(obj, typ)
    return obj


def asserttype_o(obj: Any, typ: type[T]) -> T | None:
    """Return an object typed as a given optional type.

    Assert is used to check its actual type, so only use this when
    failures are not expected. Otherwise use checktype_o.
    """
    assert isinstance(typ, type), 'only actual types accepted'
    assert isinstance(obj, (typ, type(None)))
    return obj


def checktype(obj: Any, typ: type[T]) -> T:
    """Return an object typed as a given type.

    Always checks the type at runtime with isinstance and throws a TypeError
    on failure. Use asserttype for a more efficient (but less safe)
    equivalent.
    """
    assert isinstance(typ, type), 'only actual types accepted'
    if not isinstance(obj, typ):
        raise TypeError(f'Expected a {typ}; got a {type(obj)}.')
    return obj


def checktype_o(obj: Any, typ: type[T]) -> T | None:
    """Return an object typed as a given optional type.

    Always checks the type at runtime with isinstance and throws a TypeError
    on failure. Use asserttype_o for a more efficient (but less safe)
    equivalent.
    """
    assert isinstance(typ, type), 'only actual types accepted'
    if not isinstance(obj, (typ, type(None))):
        raise TypeError(f'Expected a {typ} or None; got a {type(obj)}.')
    return obj


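# A quick illustration (comment only; get_config_value and compute_size are
# hypothetical). The assert* variants vanish under 'python -O', so reserve
# them for cases that should never fail; the check* variants always validate
# at runtime.
#
#   port = checktype(get_config_value('port'), int)    # TypeError if not int.
#   name = checktype_o(get_config_value('name'), str)  # str or None allowed.
#   size = asserttype(compute_size(), int)             # Assert-only check.

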
def warntype(obj: Any, typ: type[T]) -> T:
    """Return an object typed as a given type.

    Always checks the type at runtime and simply logs a warning if it is
    not what is expected.
    """
    assert isinstance(typ, type), 'only actual types accepted'
    if not isinstance(obj, typ):
        import logging

        logging.warning('warntype: expected a %s, got a %s', typ, type(obj))
    return obj  # type: ignore


def warntype_o(obj: Any, typ: type[T]) -> T | None:
    """Return an object typed as a given optional type.

    Always checks the type at runtime and simply logs a warning if it is
    not what is expected.
    """
    assert isinstance(typ, type), 'only actual types accepted'
    if not isinstance(obj, (typ, type(None))):
        import logging

        logging.warning(
            'warntype: expected a %s or None, got a %s', typ, type(obj)
        )
    return obj  # type: ignore


def assert_non_optional(obj: T | None) -> T:
    """Return an object with Optional typing removed.

    Assert is used to check its actual type, so only use this when
    failures are not expected. Use check_non_optional otherwise.
    """
    assert obj is not None
    return obj


def check_non_optional(obj: T | None) -> T:
    """Return an object with Optional typing removed.

    Always checks the actual type and throws a TypeError on failure.
    Use assert_non_optional for a more efficient (but less safe) equivalent.
    """
    if obj is None:
        raise TypeError('Got None value in check_non_optional.')
    return obj


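# A quick illustration (comment only; find_widget is hypothetical). Handy
# for collapsing 'T | None' values where None indicates a real error:
#
#   widget = check_non_optional(find_widget('status-bar'))
#   widget.draw()  # Type-checks as Widget, not Widget | None.

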
def smoothstep(edge0: float, edge1: float, x: float) -> float:
    """A smooth transition function.

    Returns a value that smoothly moves from 0 to 1 as we go between edges.
    Values outside of the range return 0 or 1.
    """
    y = min(1.0, max(0.0, (x - edge0) / (edge1 - edge0)))
    return y * y * (3.0 - 2.0 * y)


def linearstep(edge0: float, edge1: float, x: float) -> float:
    """A linear transition function.

    Returns a value that linearly moves from 0 to 1 as we go between edges.
    Values outside of the range return 0 or 1.
    """
    return max(0.0, min(1.0, (x - edge0) / (edge1 - edge0)))


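# A few sample values (comment only) showing how the two differ in the
# middle of the range and how both clamp outside it:
#
#   linearstep(0.0, 1.0, 0.25)  # -> 0.25
#   smoothstep(0.0, 1.0, 0.25)  # -> 0.15625 (eases in near the edges)
#   smoothstep(0.0, 1.0, 0.5)   # -> 0.5
#   smoothstep(0.0, 1.0, 2.0)   # -> 1.0 (clamped)

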
def _compact_id(num: int, chars: str) -> str:
    if num < 0:
        raise ValueError('Negative integers not allowed.')

    # Chars must be in sorted order for sorting to work correctly
    # on our output.
    assert ''.join(sorted(list(chars))) == chars

    base = len(chars)
    out = ''
    while num:
        out += chars[num % base]
        num //= base
    return out[::-1] or '0'


def human_readable_compact_id(num: int) -> str:
    """Given a positive int, return a compact string representation for it.

    Handy for visualizing unique numeric ids using as few as possible chars.
    This representation uses only lowercase letters and numbers (minus the
    following letters for readability):
    's' is excluded due to similarity to '5'.
    'l' is excluded due to similarity to '1'.
    'i' is excluded due to similarity to '1'.
    'o' is excluded due to similarity to '0'.
    'z' is excluded due to similarity to '2'.

    Therefore for n chars this can store values of 31^n (10 digits plus
    21 letters).

    When reading human input consisting of these IDs, it may be desirable
    to map the disallowed chars to their corresponding allowed ones
    ('o' -> '0', etc).

    Sort order for these ids is the same as the original numbers.

    If more compactness is desired at the expense of readability,
    use compact_id() instead.
    """
    return _compact_id(num, '0123456789abcdefghjkmnpqrtuvwxy')


def compact_id(num: int) -> str:
    """Given a positive int, return a compact string representation for it.

    Handy for visualizing unique numeric ids using as few as possible chars.
    This version is more compact than human_readable_compact_id() but less
    friendly to humans due to using both capital and lowercase letters,
    both 'O' and '0', etc.

    Therefore for n chars this can store values of 62^n.

    Sort order for these ids is the same as the original numbers.
    """
    return _compact_id(
        num, '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
    )


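# A few sample outputs (comment only, derived from the encodings above):
#
#   compact_id(0)                     # -> '0'
#   compact_id(12345)                 # -> '3D7' (base 62)
#   human_readable_compact_id(12345)  # -> 'cu7' (base 31; no s/l/i/o/z)

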
# NOTE: Even though this is available as part of typing_extensions, keeping
# it in here for now so we don't require typing_extensions as a dependency.
# Once 3.11 rolls around we can kill this and use typing.assert_never.
def assert_never(value: NoReturn) -> NoReturn:
    """Trick for checking exhaustive handling of Enums, etc.

    See https://github.com/python/typing/issues/735
    """
    assert False, f'Unhandled value: {value} ({type(value).__name__})'


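# A minimal exhaustiveness-check sketch (comment only; 'Mode' is
# hypothetical). If a new enum value is added and a branch is missed, the
# type checker flags the assert_never() call (and it asserts at runtime
# in debug mode).
#
#   class Mode(Enum):
#       READ = 0
#       WRITE = 1
#
#   def flag_for(mode: Mode) -> str:
#       if mode is Mode.READ:
#           return 'r'
#       if mode is Mode.WRITE:
#           return 'w'
#       assert_never(mode)

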
def unchanging_hostname() -> str:
    """Return an unchanging name for the local device.

    Similar to the `hostname` call (or os.uname().nodename in Python)
    except attempts to give a name that doesn't change depending on
    network conditions. (A Mac will tend to go from Foo to Foo.local,
    Foo.lan etc. throughout its various adventures)
    """
    import os
    import platform
    import subprocess

    # On Mac, this should give the computer name assigned in System Prefs.
    if platform.system() == 'Darwin':
        return (
            subprocess.run(
                ['scutil', '--get', 'ComputerName'],
                check=True,
                capture_output=True,
            )
            .stdout.decode()
            .strip()
            .replace(' ', '-')
        )
    return os.uname().nodename


def set_canonical_module(
    module_globals: dict[str, Any], names: list[str]
) -> None:
    """Override any __module__ attrs on passed classes/etc.

    This allows classes to present themselves using clean paths such as
    mymodule.MyClass instead of possibly ugly internal ones such as
    mymodule._internal._stuff.MyClass.
    """
    modulename = module_globals.get('__name__')
    if not isinstance(modulename, str):
        raise RuntimeError('Unable to get module name.')
    for name in names:
        obj = module_globals[name]
        existing = getattr(obj, '__module__', None)
        try:
            if existing is not None and existing != modulename:
                obj.__module__ = modulename
        except Exception:
            import logging

            logging.warning(
                'set_canonical_module: unable to change __module__'
                " from '%s' to '%s' on %s object at '%s'.",
                existing,
                modulename,
                type(obj),
                name,
            )


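# A minimal usage sketch (comment only; the 'mymodule' package layout is
# hypothetical). In a package __init__.py that re-exports internals:
#
#   from mymodule._internal import MyClass, my_func
#
#   set_canonical_module(globals(), ['MyClass', 'my_func'])
#
# After this, MyClass.__module__ reports 'mymodule' rather than
# 'mymodule._internal', which makes docs and reprs cleaner.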