vh-bombsquad-modded-server-.../dist/ba_data/python/efro/dataclasses.py

296 lines
12 KiB
Python
Raw Normal View History

2024-02-26 00:17:10 +05:30
# Released under the MIT License. See LICENSE for details.
#
"""Custom functionality for dealing with dataclasses."""
# Note: We do lots of comparing of exact types here which is normally
# frowned upon (stuff like isinstance() is usually encouraged).
# pylint: disable=unidiomatic-typecheck
from __future__ import annotations
import dataclasses
import inspect
from enum import Enum
from typing import TYPE_CHECKING, TypeVar, Generic
from efro.util import enum_by_value
if TYPE_CHECKING:
from typing import Any, Dict, Type, Tuple, Optional
T = TypeVar('T')

# Flat builtin types that are matched exactly (no subclasses), keyed by
# the name they appear as in a stringified field annotation.
SIMPLE_NAMES_TO_TYPES: Dict[str, Type] = {
    tp.__name__: tp
    for tp in (int, bool, str, float)
}

# Reverse lookup of the table above: builtin type -> annotation name.
SIMPLE_TYPES_TO_NAMES = dict(
    zip(SIMPLE_NAMES_TO_TYPES.values(), SIMPLE_NAMES_TO_TYPES))
def dataclass_to_dict(obj: Any, coerce_to_float: bool = True) -> dict:
    """Export a dataclass instance as a json-friendly dict.

    Each field value is checked against the type declared on its field
    while exporting; only a limited set of field types is supported.

    Args:
        obj: The dataclass instance to export.
        coerce_to_float: When True, int values found on float-typed fields
            are written out as floats. When False they raise a TypeError.

    Returns:
        A new dict mirroring the dataclass's fields.
    """
    outputter = _Outputter(obj, create=True, coerce_to_float=coerce_to_float)
    exported = outputter.run()
    # In 'create' mode the outputter always hands back a dict.
    assert isinstance(exported, dict)
    return exported
def dataclass_from_dict(cls: Type[T],
                        values: dict,
                        coerce_to_float: bool = True) -> T:
    """Build an instance of dataclass `cls` from a json-friendly dict.

    `values` must use the format produced by dataclass_to_dict:
    - sequence values such as tuples or sets are passed as lists;
    - enums are passed as their associated values;
    - nested dataclasses are passed as dicts.

    Args:
        cls: The dataclass type to instantiate.
        values: The input dict.
        coerce_to_float: When True, int values given for float-typed fields
            are converted to floats. When False a TypeError is raised.
    """
    inputter = _Inputter(cls, coerce_to_float=coerce_to_float)
    return inputter.run(values)
def dataclass_validate(obj: Any, coerce_to_float: bool = True) -> None:
    """Ensure that current values in a dataclass are the correct types.

    Walks the same checks as dataclass_to_dict without building any
    output. Mismatches are raised as exceptions; nothing is returned.
    """
    validator = _Outputter(obj, create=False, coerce_to_float=coerce_to_float)
    validator.run()
def _field_type_str(cls: Type, field: dataclasses.Field) -> str:
    """Return a dataclass field's type annotation as a string.

    This module expects dataclasses to be defined under
    'from __future__ import annotations', so field types are strings
    rather than actual type objects.

    Raises:
        RuntimeError: if the field's type is not a string, i.e. `cls` was
            defined without that future import.
    """
    annotation = field.type
    if isinstance(annotation, str):
        return annotation
    raise RuntimeError(
        f'Dataclass {cls.__name__} seems to have'
        f' been created without "from __future__ import annotations";'
        f' those dataclasses are unsupported here.')
def _raise_type_error(fieldpath: str, valuetype: Type,
                      expected: Tuple[Type, ...]) -> None:
    """Raise a TypeError describing a field value of the wrong type.

    Args:
        fieldpath: Dotted path of the offending field.
        valuetype: The type of the value actually found.
        expected: Tuple of acceptable types. A single type is named
            directly; several are rendered as 'Union[a, b, ...]'.

    Raises:
        TypeError: always.
    """
    assert isinstance(expected, tuple)
    assert all(isinstance(e, type) for e in expected)
    type_names = [tp.__name__ for tp in expected]
    if len(type_names) == 1:
        expected_str = type_names[0]
    else:
        expected_str = 'Union[' + ', '.join(type_names) + ']'
    raise TypeError(f'Invalid value type for "{fieldpath}";'
                    f' expected "{expected_str}", got'
                    f' "{valuetype.__name__}".')
class _Outputter:
    """Walks a dataclass instance, type-checking and optionally exporting it.

    With `create` True, run() builds and returns a json-friendly dict that
    mirrors the dataclass. With `create` False, values are only checked and
    run() returns None.
    """

    def __init__(self, obj: Any, create: bool, coerce_to_float: bool) -> None:
        self._obj = obj
        self._create = create
        self._coerce_to_float = coerce_to_float

    def run(self) -> Any:
        """Process the root object.

        Returns the output dict when creating, otherwise None.
        """
        return self._dataclass_to_output(self._obj, '')

    def _value_to_output(self, fieldpath: str, typestr: str,
                         value: Any) -> Any:
        """Check one field value against its type string.

        Args:
            fieldpath: Dotted path of the field, used in error messages.
            typestr: The field's annotation as a string.
            value: The current field value.

        Returns:
            The json-friendly form of `value` when creating output. In
            check-only mode the return value is not used by callers.

        Raises:
            TypeError: on a value/type mismatch or an unsupported type.
        """
        # pylint: disable=too-many-return-statements

        # For simple flat types, look for exact matches:
        simpletype = SIMPLE_NAMES_TO_TYPES.get(typestr)
        if simpletype is not None:
            if type(value) is not simpletype:
                # Special case: if they want to coerce ints to floats, do so.
                if (self._coerce_to_float and simpletype is float
                        and type(value) is int):
                    return float(value) if self._create else None
                _raise_type_error(fieldpath, type(value), (simpletype, ))
            return value

        if typestr.startswith('Optional[') and typestr.endswith(']'):
            subtypestr = typestr[9:-1]
            # Handle the 'None' case special and do the default otherwise.
            if value is None:
                return None
            return self._value_to_output(fieldpath, subtypestr, value)

        if typestr.startswith('List[') and typestr.endswith(']'):
            return self._sequence_to_output(fieldpath, typestr, value, 'List',
                                            list)

        if typestr.startswith('Set[') and typestr.endswith(']'):
            return self._sequence_to_output(fieldpath, typestr, value, 'Set',
                                            set)

        # Note: _dataclass_to_output rejects dataclass *types* here, since
        # is_dataclass() is also True for those.
        if dataclasses.is_dataclass(value):
            return self._dataclass_to_output(value, fieldpath)

        if isinstance(value, Enum):
            enumvalue = value.value
            if type(enumvalue) not in SIMPLE_TYPES_TO_NAMES:
                raise TypeError(f'Invalid enum value type {type(enumvalue)}'
                                f' for "{fieldpath}".')
            return enumvalue

        raise TypeError(
            f"Field '{fieldpath}' of type '{typestr}' is unsupported here.")

    def _sequence_to_output(self, fieldpath: str, typestr: str, value: Any,
                            seqtypestr: str, seqtype: Type) -> Any:
        """Check (and optionally export) a 'List[...]' or 'Set[...]' value.

        `seqtypestr` is the annotation prefix ('List' or 'Set') and
        `seqtype` the matching container type. Output is always a list
        because that is what json supports. Returns None in check-only mode.
        """
        if not isinstance(value, seqtype):
            raise TypeError(f'Expected a {seqtype.__name__} for {fieldpath};'
                            f' found a {type(value)}')
        # Strip the 'List[' / 'Set[' prefix and the trailing ']'.
        subtypestr = typestr[len(seqtypestr) + 1:-1]
        if self._create:
            return [
                self._value_to_output(fieldpath, subtypestr, x) for x in value
            ]
        for x in value:
            self._value_to_output(fieldpath, subtypestr, x)
        return None

    def _dataclass_to_output(self, obj: Any, fieldpath: str) -> Any:
        """Check/export every field of dataclass instance `obj`.

        Returns a dict of field name to output value when creating,
        otherwise None.

        Raises:
            TypeError: if `obj` is not a dataclass instance.
        """
        if not dataclasses.is_dataclass(obj):
            raise TypeError(f'Passed obj {obj} is not a dataclass.')
        # is_dataclass() also accepts dataclass *types*; walking one would
        # read class attributes (or fail on fields without defaults).
        if isinstance(obj, type):
            raise TypeError(f'Passed obj {obj} is a dataclass type;'
                            f' expected an instance of one.')
        fields = dataclasses.fields(obj)
        out: Optional[Dict[str, Any]] = {} if self._create else None
        for field in fields:
            fieldname = field.name
            if fieldpath:
                subfieldpath = f'{fieldpath}.{fieldname}'
            else:
                subfieldpath = fieldname
            typestr = _field_type_str(type(obj), field)
            value = getattr(obj, fieldname)
            outvalue = self._value_to_output(subfieldpath, typestr, value)
            if self._create:
                assert out is not None
                out[fieldname] = outvalue
        return out
class _Inputter(Generic[T]):
    """Instantiates dataclasses from json-friendly input dicts.

    Consumes the format emitted by dataclass_to_dict:
    - lists are converted back to lists or sets;
    - enum values are converted back to enum members;
    - nested dicts become nested dataclass instances.
    """

    def __init__(self, cls: Type[T], coerce_to_float: bool):
        self._cls = cls
        self._coerce_to_float = coerce_to_float

    def run(self, values: dict) -> T:
        """Build and return an instance of the root dataclass from `values`."""
        return self._dataclass_from_input(  # type: ignore
            self._cls, '', values)

    def _value_from_input(self, cls: Type, fieldpath: str, typestr: str,
                          value: Any) -> Any:
        """Convert an assigned value to what a dataclass field expects.

        Args:
            cls: The dataclass owning the field. Non-builtin type names are
                resolved in the module where it is defined.
            fieldpath: Dotted path of the field, used in error messages.
            typestr: The field's annotation as a string.
            value: The json-friendly input value.

        Raises:
            TypeError: on a value/type mismatch or an unsupported type.
            RuntimeError: if a nested type name cannot be resolved.
        """
        # pylint: disable=too-many-return-statements
        simpletype = SIMPLE_NAMES_TO_TYPES.get(typestr)
        if simpletype is not None:
            if type(value) is not simpletype:
                # Special case: if they want to coerce ints to floats, do so.
                if (self._coerce_to_float and simpletype is float
                        and type(value) is int):
                    return float(value)
                _raise_type_error(fieldpath, type(value), (simpletype, ))
            return value
        if typestr.startswith('List[') and typestr.endswith(']'):
            return self._sequence_from_input(cls, fieldpath, typestr, value,
                                             'List', list)
        if typestr.startswith('Set[') and typestr.endswith(']'):
            return self._sequence_from_input(cls, fieldpath, typestr, value,
                                             'Set', set)
        if typestr.startswith('Optional[') and typestr.endswith(']'):
            subtypestr = typestr[9:-1]
            # Handle the 'None' case special and do the default
            # thing otherwise.
            if value is None:
                return None
            return self._value_from_input(cls, fieldpath, subtypestr, value)

        # Ok, it's not a builtin type. It might be an enum or nested dataclass.
        cls2 = getattr(inspect.getmodule(cls), typestr, None)
        if cls2 is None:
            raise RuntimeError(f"Unable to resolve '{typestr}'"
                               f" used by class '{cls.__name__}';"
                               f' make sure all nested types are declared'
                               f' in the global namespace of the module where'
                               f" '{cls.__name__}' is defined.")
        # Only real classes can be handled below. The name may resolve to a
        # non-class object (a function, a typing construct, ...), for which
        # issubclass() would raise an unhelpful TypeError; those fall
        # through to the 'unsupported' error instead.
        if isinstance(cls2, type):
            if dataclasses.is_dataclass(cls2):
                return self._dataclass_from_input(cls2, fieldpath, value)
            if issubclass(cls2, Enum):
                return enum_by_value(cls2, value)
        raise TypeError(
            f"Field '{fieldpath}' of type '{typestr}' is unsupported here.")

    def _dataclass_from_input(self, cls: Type, fieldpath: str,
                              values: dict) -> Any:
        """Given a dict, instantiates a dataclass of the given type.

        The dict must be in the json-friendly format as emitted from
        dataclass_to_dict. This means that sequence values such as tuples or
        sets should be passed as lists, enums should be passed as their
        associated values, and nested dataclasses should be passed as dicts.

        Fields absent from `values` are left for the dataclass constructor
        to fill from its defaults.

        Raises:
            TypeError: if `cls` is not a dataclass or `values` is not a dict.
            AttributeError: if `values` contains a key that is not a field.
        """
        if not dataclasses.is_dataclass(cls):
            raise TypeError(f'Passed class {cls} is not a dataclass.')
        if not isinstance(values, dict):
            if fieldpath:
                # Nested dataclass value: point at the offending field
                # rather than the top-level argument.
                raise TypeError(f'Invalid input value for "{fieldpath}";'
                                f' expected a dict, got a'
                                f' {type(values).__name__}')
            raise TypeError("Expected a dict for 'values' arg.")
        # noinspection PyDataclass
        fields = dataclasses.fields(cls)
        fields_by_name = {f.name: f for f in fields}
        args: Dict[str, Any] = {}
        for key, value in values.items():
            field = fields_by_name.get(key)
            if field is None:
                raise AttributeError(f"'{cls.__name__}' has no '{key}' field.")
            typestr = _field_type_str(cls, field)
            subfieldpath = (f'{fieldpath}.{field.name}'
                            if fieldpath else field.name)
            args[key] = self._value_from_input(cls, subfieldpath, typestr,
                                               value)
        return cls(**args)

    def _sequence_from_input(self, cls: Type, fieldpath: str, typestr: str,
                             value: Any, seqtypestr: str,
                             seqtype: Type) -> Any:
        """Convert a json list into a `seqtype` of converted elements.

        `seqtypestr` is the annotation prefix ('List' or 'Set') used to
        strip the element type out of `typestr`.
        """
        # Because we are json-centric, we expect a list for all sequences.
        if type(value) is not list:
            raise TypeError(f'Invalid input value for "{fieldpath}";'
                            f' expected a list, got a {type(value).__name__}')
        subtypestr = typestr[len(seqtypestr) + 1:-1]
        return seqtype(
            self._value_from_input(cls, fieldpath, subtypestr, i)
            for i in value)