mirror of
https://github.com/hypervortex/VH-Bombsquad-Modded-Server-Files
synced 2025-11-07 17:36:08 +00:00
Added new files
This commit is contained in:
parent
867634cc5c
commit
3a407868d4
1775 changed files with 550222 additions and 0 deletions
50
dist/ba_data/python/efro/dataclassio/__init__.py
vendored
Normal file
50
dist/ba_data/python/efro/dataclassio/__init__.py
vendored
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
# Released under the MIT License. See LICENSE for details.
|
||||
#
|
||||
"""Functionality for importing, exporting, and validating dataclasses.
|
||||
|
||||
This allows complex nested dataclasses to be flattened to json-compatible
|
||||
data and restored from said data. It also gracefully handles and preserves
|
||||
unrecognized attribute data, allowing older clients to interact with newer
|
||||
data formats in a nondestructive manner.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from efro.util import set_canonical_module
|
||||
from efro.dataclassio._base import Codec, IOAttrs, IOExtendedData
|
||||
from efro.dataclassio._prep import (
|
||||
ioprep,
|
||||
ioprepped,
|
||||
will_ioprep,
|
||||
is_ioprepped_dataclass,
|
||||
)
|
||||
from efro.dataclassio._pathcapture import DataclassFieldLookup
|
||||
from efro.dataclassio._api import (
|
||||
JsonStyle,
|
||||
dataclass_to_dict,
|
||||
dataclass_to_json,
|
||||
dataclass_from_dict,
|
||||
dataclass_from_json,
|
||||
dataclass_validate,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'JsonStyle',
|
||||
'Codec',
|
||||
'IOAttrs',
|
||||
'IOExtendedData',
|
||||
'ioprep',
|
||||
'ioprepped',
|
||||
'will_ioprep',
|
||||
'is_ioprepped_dataclass',
|
||||
'DataclassFieldLookup',
|
||||
'dataclass_to_dict',
|
||||
'dataclass_to_json',
|
||||
'dataclass_from_dict',
|
||||
'dataclass_from_json',
|
||||
'dataclass_validate',
|
||||
]
|
||||
|
||||
# Have these things present themselves cleanly as 'thismodule.SomeClass'
|
||||
# instead of 'thismodule._internalmodule.SomeClass'
|
||||
set_canonical_module(module_globals=globals(), names=__all__)
|
||||
BIN
dist/ba_data/python/efro/dataclassio/__pycache__/__init__.cpython-310.opt-1.pyc
vendored
Normal file
BIN
dist/ba_data/python/efro/dataclassio/__pycache__/__init__.cpython-310.opt-1.pyc
vendored
Normal file
Binary file not shown.
BIN
dist/ba_data/python/efro/dataclassio/__pycache__/_api.cpython-310.opt-1.pyc
vendored
Normal file
BIN
dist/ba_data/python/efro/dataclassio/__pycache__/_api.cpython-310.opt-1.pyc
vendored
Normal file
Binary file not shown.
BIN
dist/ba_data/python/efro/dataclassio/__pycache__/_base.cpython-310.opt-1.pyc
vendored
Normal file
BIN
dist/ba_data/python/efro/dataclassio/__pycache__/_base.cpython-310.opt-1.pyc
vendored
Normal file
Binary file not shown.
BIN
dist/ba_data/python/efro/dataclassio/__pycache__/_inputter.cpython-310.opt-1.pyc
vendored
Normal file
BIN
dist/ba_data/python/efro/dataclassio/__pycache__/_inputter.cpython-310.opt-1.pyc
vendored
Normal file
Binary file not shown.
BIN
dist/ba_data/python/efro/dataclassio/__pycache__/_outputter.cpython-310.opt-1.pyc
vendored
Normal file
BIN
dist/ba_data/python/efro/dataclassio/__pycache__/_outputter.cpython-310.opt-1.pyc
vendored
Normal file
Binary file not shown.
BIN
dist/ba_data/python/efro/dataclassio/__pycache__/_pathcapture.cpython-310.opt-1.pyc
vendored
Normal file
BIN
dist/ba_data/python/efro/dataclassio/__pycache__/_pathcapture.cpython-310.opt-1.pyc
vendored
Normal file
Binary file not shown.
BIN
dist/ba_data/python/efro/dataclassio/__pycache__/_prep.cpython-310.opt-1.pyc
vendored
Normal file
BIN
dist/ba_data/python/efro/dataclassio/__pycache__/_prep.cpython-310.opt-1.pyc
vendored
Normal file
Binary file not shown.
163
dist/ba_data/python/efro/dataclassio/_api.py
vendored
Normal file
163
dist/ba_data/python/efro/dataclassio/_api.py
vendored
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
# Released under the MIT License. See LICENSE for details.
|
||||
#
|
||||
"""Functionality for importing, exporting, and validating dataclasses.
|
||||
|
||||
This allows complex nested dataclasses to be flattened to json-compatible
|
||||
data and restored from said data. It also gracefully handles and preserves
|
||||
unrecognized attribute data, allowing older clients to interact with newer
|
||||
data formats in a nondestructive manner.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
from efro.dataclassio._outputter import _Outputter
|
||||
from efro.dataclassio._inputter import _Inputter
|
||||
from efro.dataclassio._base import Codec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class JsonStyle(Enum):
|
||||
"""Different style types for json."""
|
||||
|
||||
# Single line, no spaces, no sorting. Not deterministic.
|
||||
# Use this for most storage purposes.
|
||||
FAST = 'fast'
|
||||
|
||||
# Single line, no spaces, sorted keys. Deterministic.
|
||||
# Use this when output may be hashed or compared for equality.
|
||||
SORTED = 'sorted'
|
||||
|
||||
# Multiple lines, spaces, sorted keys. Deterministic.
|
||||
# Use this for pretty human readable output.
|
||||
PRETTY = 'pretty'
|
||||
|
||||
|
||||
def dataclass_to_dict(
|
||||
obj: Any, codec: Codec = Codec.JSON, coerce_to_float: bool = True
|
||||
) -> dict:
|
||||
"""Given a dataclass object, return a json-friendly dict.
|
||||
|
||||
All values will be checked to ensure they match the types specified
|
||||
on fields. Note that a limited set of types and data configurations is
|
||||
supported.
|
||||
|
||||
Values with type Any will be checked to ensure they match types supported
|
||||
directly by json. This does not include types such as tuples which are
|
||||
implicitly translated by Python's json module (as this would break
|
||||
the ability to do a lossless round-trip with data).
|
||||
|
||||
If coerce_to_float is True, integer values present on float typed fields
|
||||
will be converted to float in the dict output. If False, a TypeError
|
||||
will be triggered.
|
||||
"""
|
||||
|
||||
out = _Outputter(
|
||||
obj, create=True, codec=codec, coerce_to_float=coerce_to_float
|
||||
).run()
|
||||
assert isinstance(out, dict)
|
||||
return out
|
||||
|
||||
|
||||
def dataclass_to_json(
|
||||
obj: Any,
|
||||
coerce_to_float: bool = True,
|
||||
pretty: bool = False,
|
||||
sort_keys: bool | None = None,
|
||||
) -> str:
|
||||
"""Utility function; return a json string from a dataclass instance.
|
||||
|
||||
Basically json.dumps(dataclass_to_dict(...)).
|
||||
By default, keys are sorted for pretty output and not otherwise, but
|
||||
this can be overridden by supplying a value for the 'sort_keys' arg.
|
||||
"""
|
||||
import json
|
||||
|
||||
jdict = dataclass_to_dict(
|
||||
obj=obj, coerce_to_float=coerce_to_float, codec=Codec.JSON
|
||||
)
|
||||
if sort_keys is None:
|
||||
sort_keys = pretty
|
||||
if pretty:
|
||||
return json.dumps(jdict, indent=2, sort_keys=sort_keys)
|
||||
return json.dumps(jdict, separators=(',', ':'), sort_keys=sort_keys)
|
||||
|
||||
|
||||
def dataclass_from_dict(
|
||||
cls: type[T],
|
||||
values: dict,
|
||||
codec: Codec = Codec.JSON,
|
||||
coerce_to_float: bool = True,
|
||||
allow_unknown_attrs: bool = True,
|
||||
discard_unknown_attrs: bool = False,
|
||||
) -> T:
|
||||
"""Given a dict, return a dataclass of a given type.
|
||||
|
||||
The dict must be formatted to match the specified codec (generally
|
||||
json-friendly object types). This means that sequence values such as
|
||||
tuples or sets should be passed as lists, enums should be passed as their
|
||||
associated values, nested dataclasses should be passed as dicts, etc.
|
||||
|
||||
All values are checked to ensure their types/values are valid.
|
||||
|
||||
Data for attributes of type Any will be checked to ensure they match
|
||||
types supported directly by json. This does not include types such
|
||||
as tuples which are implicitly translated by Python's json module
|
||||
(as this would break the ability to do a lossless round-trip with data).
|
||||
|
||||
If coerce_to_float is True, int values passed for float typed fields
|
||||
will be converted to float values. Otherwise, a TypeError is raised.
|
||||
|
||||
If allow_unknown_attrs is False, AttributeErrors will be raised for
|
||||
attributes present in the dict but not on the data class. Otherwise, they
|
||||
will be preserved as part of the instance and included if it is
|
||||
exported back to a dict, unless discard_unknown_attrs is True, in which
|
||||
case they will simply be discarded.
|
||||
"""
|
||||
return _Inputter(
|
||||
cls,
|
||||
codec=codec,
|
||||
coerce_to_float=coerce_to_float,
|
||||
allow_unknown_attrs=allow_unknown_attrs,
|
||||
discard_unknown_attrs=discard_unknown_attrs,
|
||||
).run(values)
|
||||
|
||||
|
||||
def dataclass_from_json(
|
||||
cls: type[T],
|
||||
json_str: str,
|
||||
coerce_to_float: bool = True,
|
||||
allow_unknown_attrs: bool = True,
|
||||
discard_unknown_attrs: bool = False,
|
||||
) -> T:
|
||||
"""Utility function; return a dataclass instance given a json string.
|
||||
|
||||
Basically dataclass_from_dict(json.loads(...))
|
||||
"""
|
||||
import json
|
||||
|
||||
return dataclass_from_dict(
|
||||
cls=cls,
|
||||
values=json.loads(json_str),
|
||||
coerce_to_float=coerce_to_float,
|
||||
allow_unknown_attrs=allow_unknown_attrs,
|
||||
discard_unknown_attrs=discard_unknown_attrs,
|
||||
)
|
||||
|
||||
|
||||
def dataclass_validate(
|
||||
obj: Any, coerce_to_float: bool = True, codec: Codec = Codec.JSON
|
||||
) -> None:
|
||||
"""Ensure that values in a dataclass instance are the correct types."""
|
||||
|
||||
# Simply run an output pass but tell it not to generate data;
|
||||
# only run validation.
|
||||
_Outputter(
|
||||
obj, create=False, codec=codec, coerce_to_float=coerce_to_float
|
||||
).run()
|
||||
276
dist/ba_data/python/efro/dataclassio/_base.py
vendored
Normal file
276
dist/ba_data/python/efro/dataclassio/_base.py
vendored
Normal file
|
|
@ -0,0 +1,276 @@
|
|||
# Released under the MIT License. See LICENSE for details.
|
||||
#
|
||||
"""Core components of dataclassio."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import typing
|
||||
import datetime
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, get_args
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
from typing import _AnnotatedAlias # type: ignore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Callable
|
||||
|
||||
# Types which we can pass through as-is.
|
||||
SIMPLE_TYPES = {int, bool, str, float, type(None)}
|
||||
|
||||
# Attr name for dict of extra attributes included on dataclass instances.
|
||||
# Note that this is only added if extra attributes are present.
|
||||
EXTRA_ATTRS_ATTR = '_DCIOEXATTRS'
|
||||
|
||||
|
||||
def _raise_type_error(
|
||||
fieldpath: str, valuetype: type, expected: tuple[type, ...]
|
||||
) -> None:
|
||||
"""Raise an error when a field value's type does not match expected."""
|
||||
assert isinstance(expected, tuple)
|
||||
assert all(isinstance(e, type) for e in expected)
|
||||
if len(expected) == 1:
|
||||
expected_str = expected[0].__name__
|
||||
else:
|
||||
expected_str = ' | '.join(t.__name__ for t in expected)
|
||||
raise TypeError(
|
||||
f'Invalid value type for "{fieldpath}";'
|
||||
f' expected "{expected_str}", got'
|
||||
f' "{valuetype.__name__}".'
|
||||
)
|
||||
|
||||
|
||||
class Codec(Enum):
|
||||
"""Specifies expected data format exported to or imported from."""
|
||||
|
||||
# Use only types that will translate cleanly to/from json: lists,
|
||||
# dicts with str keys, bools, ints, floats, and None.
|
||||
JSON = 'json'
|
||||
|
||||
# Mostly like JSON but passes bytes and datetime objects through
|
||||
# as-is instead of converting them to json-friendly types.
|
||||
FIRESTORE = 'firestore'
|
||||
|
||||
|
||||
class IOExtendedData:
|
||||
"""A class that data types can inherit from for extra functionality."""
|
||||
|
||||
def will_output(self) -> None:
|
||||
"""Called before data is sent to an outputter.
|
||||
|
||||
Can be overridden to validate or filter data before
|
||||
sending it on its way.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def will_input(cls, data: dict) -> None:
|
||||
"""Called on raw data before a class instance is created from it.
|
||||
|
||||
Can be overridden to migrate old data formats to new, etc.
|
||||
"""
|
||||
|
||||
|
||||
def _is_valid_for_codec(obj: Any, codec: Codec) -> bool:
|
||||
"""Return whether a value consists solely of json-supported types.
|
||||
|
||||
Note that this does not include things like tuples which are
|
||||
implicitly translated to lists by python's json module.
|
||||
"""
|
||||
if obj is None:
|
||||
return True
|
||||
|
||||
objtype = type(obj)
|
||||
if objtype in (int, float, str, bool):
|
||||
return True
|
||||
if objtype is dict:
|
||||
# JSON 'objects' supports only string dict keys, but all value types.
|
||||
return all(
|
||||
isinstance(k, str) and _is_valid_for_codec(v, codec)
|
||||
for k, v in obj.items()
|
||||
)
|
||||
if objtype is list:
|
||||
return all(_is_valid_for_codec(elem, codec) for elem in obj)
|
||||
|
||||
# A few things are valid in firestore but not json.
|
||||
if issubclass(objtype, datetime.datetime) or objtype is bytes:
|
||||
return codec is Codec.FIRESTORE
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class IOAttrs:
|
||||
"""For specifying io behavior in annotations.
|
||||
|
||||
'storagename', if passed, is the name used when storing to json/etc.
|
||||
'store_default' can be set to False to avoid writing values when equal
|
||||
to the default value. Note that this requires the dataclass field
|
||||
to define a default or default_factory or for its IOAttrs to
|
||||
define a soft_default value.
|
||||
'whole_days', if True, requires datetime values to be exactly on day
|
||||
boundaries (see efro.util.utc_today()).
|
||||
'whole_hours', if True, requires datetime values to lie exactly on hour
|
||||
boundaries (see efro.util.utc_this_hour()).
|
||||
'whole_minutes', if True, requires datetime values to lie exactly on minute
|
||||
boundaries (see efro.util.utc_this_minute()).
|
||||
'soft_default', if passed, injects a default value into dataclass
|
||||
instantiation when the field is not present in the input data.
|
||||
This allows dataclasses to add new non-optional fields while
|
||||
gracefully 'upgrading' old data. Note that when a soft_default is
|
||||
present it will take precedence over field defaults when determining
|
||||
whether to store a value for a field with store_default=False
|
||||
(since the soft_default value is what we'll get when reading that
|
||||
same data back in when the field is omitted).
|
||||
'soft_default_factory' is similar to 'default_factory' in dataclass
|
||||
fields; it should be used instead of 'soft_default' for mutable types
|
||||
such as lists to prevent a single default object from unintentionally
|
||||
changing over time.
|
||||
"""
|
||||
|
||||
# A sentinel object to detect if a parameter is supplied or not. Use
|
||||
# a class to give it a better repr.
|
||||
class _MissingType:
|
||||
pass
|
||||
|
||||
MISSING = _MissingType()
|
||||
|
||||
storagename: str | None = None
|
||||
store_default: bool = True
|
||||
whole_days: bool = False
|
||||
whole_hours: bool = False
|
||||
whole_minutes: bool = False
|
||||
soft_default: Any = MISSING
|
||||
soft_default_factory: Callable[[], Any] | _MissingType = MISSING
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storagename: str | None = storagename,
|
||||
store_default: bool = store_default,
|
||||
whole_days: bool = whole_days,
|
||||
whole_hours: bool = whole_hours,
|
||||
whole_minutes: bool = whole_minutes,
|
||||
soft_default: Any = MISSING,
|
||||
soft_default_factory: Callable[[], Any] | _MissingType = MISSING,
|
||||
):
|
||||
|
||||
# Only store values that differ from class defaults to keep
|
||||
# our instances nice and lean.
|
||||
cls = type(self)
|
||||
if storagename != cls.storagename:
|
||||
self.storagename = storagename
|
||||
if store_default != cls.store_default:
|
||||
self.store_default = store_default
|
||||
if whole_days != cls.whole_days:
|
||||
self.whole_days = whole_days
|
||||
if whole_hours != cls.whole_hours:
|
||||
self.whole_hours = whole_hours
|
||||
if whole_minutes != cls.whole_minutes:
|
||||
self.whole_minutes = whole_minutes
|
||||
if soft_default is not cls.soft_default:
|
||||
|
||||
# Do what dataclasses does with its default types and
|
||||
# tell the user to use factory for mutable ones.
|
||||
if isinstance(soft_default, (list, dict, set)):
|
||||
raise ValueError(
|
||||
f'mutable {type(soft_default)} is not allowed'
|
||||
f' for soft_default; use soft_default_factory.'
|
||||
)
|
||||
self.soft_default = soft_default
|
||||
if soft_default_factory is not cls.soft_default_factory:
|
||||
self.soft_default_factory = soft_default_factory
|
||||
if self.soft_default is not cls.soft_default:
|
||||
raise ValueError(
|
||||
'Cannot set both soft_default and soft_default_factory'
|
||||
)
|
||||
|
||||
def validate_for_field(self, cls: type, field: dataclasses.Field) -> None:
|
||||
"""Ensure the IOAttrs instance is ok to use with the provided field."""
|
||||
|
||||
# Turning off store_default requires the field to have either
|
||||
# a default or a a default_factory or for us to have soft equivalents.
|
||||
|
||||
if not self.store_default:
|
||||
field_default_factory: Any = field.default_factory
|
||||
if (
|
||||
field_default_factory is dataclasses.MISSING
|
||||
and field.default is dataclasses.MISSING
|
||||
and self.soft_default is self.MISSING
|
||||
and self.soft_default_factory is self.MISSING
|
||||
):
|
||||
raise TypeError(
|
||||
f'Field {field.name} of {cls} has'
|
||||
f' neither a default nor a default_factory'
|
||||
f' and IOAttrs contains neither a soft_default'
|
||||
f' nor a soft_default_factory;'
|
||||
f' store_default=False cannot be set for it.'
|
||||
)
|
||||
|
||||
def validate_datetime(
|
||||
self, value: datetime.datetime, fieldpath: str
|
||||
) -> None:
|
||||
"""Ensure a datetime value meets our value requirements."""
|
||||
if self.whole_days:
|
||||
if any(
|
||||
x != 0
|
||||
for x in (
|
||||
value.hour,
|
||||
value.minute,
|
||||
value.second,
|
||||
value.microsecond,
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
f'Value {value} at {fieldpath} is not a whole day.'
|
||||
)
|
||||
elif self.whole_hours:
|
||||
if any(
|
||||
x != 0 for x in (value.minute, value.second, value.microsecond)
|
||||
):
|
||||
raise ValueError(
|
||||
f'Value {value} at {fieldpath}' f' is not a whole hour.'
|
||||
)
|
||||
elif self.whole_minutes:
|
||||
if any(x != 0 for x in (value.second, value.microsecond)):
|
||||
raise ValueError(
|
||||
f'Value {value} at {fieldpath}' f' is not a whole minute.'
|
||||
)
|
||||
|
||||
|
||||
def _get_origin(anntype: Any) -> Any:
|
||||
"""Given a type annotation, return its origin or itself if there is none.
|
||||
|
||||
This differs from typing.get_origin in that it will never return None.
|
||||
This lets us use the same code path for handling typing.List
|
||||
that we do for handling list, which is good since they can be used
|
||||
interchangeably in annotations.
|
||||
"""
|
||||
origin = typing.get_origin(anntype)
|
||||
return anntype if origin is None else origin
|
||||
|
||||
|
||||
def _parse_annotated(anntype: Any) -> tuple[Any, IOAttrs | None]:
|
||||
"""Parse Annotated() constructs, returning annotated type & IOAttrs."""
|
||||
# If we get an Annotated[foo, bar, eep] we take
|
||||
# foo as the actual type, and we look for IOAttrs instances in
|
||||
# bar/eep to affect our behavior.
|
||||
ioattrs: IOAttrs | None = None
|
||||
if isinstance(anntype, _AnnotatedAlias):
|
||||
annargs = get_args(anntype)
|
||||
for annarg in annargs[1:]:
|
||||
if isinstance(annarg, IOAttrs):
|
||||
if ioattrs is not None:
|
||||
raise RuntimeError(
|
||||
'Multiple IOAttrs instances found for a'
|
||||
' single annotation; this is not supported.'
|
||||
)
|
||||
ioattrs = annarg
|
||||
|
||||
# I occasionally just throw a 'x' down when I mean IOAttrs('x');
|
||||
# catch these mistakes.
|
||||
elif isinstance(annarg, (str, int, float, bool)):
|
||||
raise RuntimeError(
|
||||
f'Raw {type(annarg)} found in Annotated[] entry:'
|
||||
f' {anntype}; this is probably not what you intended.'
|
||||
)
|
||||
anntype = annargs[0]
|
||||
return anntype, ioattrs
|
||||
555
dist/ba_data/python/efro/dataclassio/_inputter.py
vendored
Normal file
555
dist/ba_data/python/efro/dataclassio/_inputter.py
vendored
Normal file
|
|
@ -0,0 +1,555 @@
|
|||
# Released under the MIT License. See LICENSE for details.
|
||||
#
|
||||
"""Functionality for dataclassio related to pulling data into dataclasses."""
|
||||
|
||||
# Note: We do lots of comparing of exact types here which is normally
|
||||
# frowned upon (stuff like isinstance() is usually encouraged).
|
||||
# pylint: disable=unidiomatic-typecheck
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
import dataclasses
|
||||
import typing
|
||||
import types
|
||||
import datetime
|
||||
from typing import TYPE_CHECKING, Generic, TypeVar
|
||||
|
||||
from efro.util import enum_by_value, check_utc
|
||||
from efro.dataclassio._base import (
|
||||
Codec,
|
||||
_parse_annotated,
|
||||
EXTRA_ATTRS_ATTR,
|
||||
_is_valid_for_codec,
|
||||
_get_origin,
|
||||
SIMPLE_TYPES,
|
||||
_raise_type_error,
|
||||
IOExtendedData,
|
||||
)
|
||||
from efro.dataclassio._prep import PrepSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
from efro.dataclassio._base import IOAttrs
|
||||
from efro.dataclassio._outputter import _Outputter
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class _Inputter(Generic[T]):
|
||||
def __init__(
|
||||
self,
|
||||
cls: type[T],
|
||||
codec: Codec,
|
||||
coerce_to_float: bool,
|
||||
allow_unknown_attrs: bool = True,
|
||||
discard_unknown_attrs: bool = False,
|
||||
):
|
||||
self._cls = cls
|
||||
self._codec = codec
|
||||
self._coerce_to_float = coerce_to_float
|
||||
self._allow_unknown_attrs = allow_unknown_attrs
|
||||
self._discard_unknown_attrs = discard_unknown_attrs
|
||||
self._soft_default_validator: _Outputter | None = None
|
||||
|
||||
if not allow_unknown_attrs and discard_unknown_attrs:
|
||||
raise ValueError(
|
||||
'discard_unknown_attrs cannot be True'
|
||||
' when allow_unknown_attrs is False.'
|
||||
)
|
||||
|
||||
def run(self, values: dict) -> T:
|
||||
"""Do the thing."""
|
||||
|
||||
# For special extended data types, call their 'will_output' callback.
|
||||
tcls = self._cls
|
||||
if issubclass(tcls, IOExtendedData):
|
||||
tcls.will_input(values)
|
||||
|
||||
out = self._dataclass_from_input(self._cls, '', values)
|
||||
assert isinstance(out, self._cls)
|
||||
return out
|
||||
|
||||
def _value_from_input(
|
||||
self,
|
||||
cls: type,
|
||||
fieldpath: str,
|
||||
anntype: Any,
|
||||
value: Any,
|
||||
ioattrs: IOAttrs | None,
|
||||
) -> Any:
|
||||
"""Convert an assigned value to what a dataclass field expects."""
|
||||
# pylint: disable=too-many-return-statements
|
||||
# pylint: disable=too-many-branches
|
||||
|
||||
origin = _get_origin(anntype)
|
||||
|
||||
if origin is typing.Any:
|
||||
if not _is_valid_for_codec(value, self._codec):
|
||||
raise TypeError(
|
||||
f'Invalid value type for \'{fieldpath}\';'
|
||||
f' \'Any\' typed values must contain only'
|
||||
f' types directly supported by the specified'
|
||||
f' codec ({self._codec.name}); found'
|
||||
f' \'{type(value).__name__}\' which is not.'
|
||||
)
|
||||
return value
|
||||
|
||||
if origin is typing.Union or origin is types.UnionType:
|
||||
# Currently, the only unions we support are None/Value
|
||||
# (translated from Optional), which we verified on prep.
|
||||
# So let's treat this as a simple optional case.
|
||||
if value is None:
|
||||
return None
|
||||
childanntypes_l = [
|
||||
c for c in typing.get_args(anntype) if c is not type(None)
|
||||
] # noqa (pycodestyle complains about *is* with type)
|
||||
assert len(childanntypes_l) == 1
|
||||
return self._value_from_input(
|
||||
cls, fieldpath, childanntypes_l[0], value, ioattrs
|
||||
)
|
||||
|
||||
# Everything below this point assumes the annotation type resolves
|
||||
# to a concrete type. (This should have been verified at prep time).
|
||||
assert isinstance(origin, type)
|
||||
|
||||
if origin in SIMPLE_TYPES:
|
||||
if type(value) is not origin:
|
||||
# Special case: if they want to coerce ints to floats, do so.
|
||||
if (
|
||||
self._coerce_to_float
|
||||
and origin is float
|
||||
and type(value) is int
|
||||
):
|
||||
return float(value)
|
||||
_raise_type_error(fieldpath, type(value), (origin,))
|
||||
return value
|
||||
|
||||
if origin in {list, set}:
|
||||
return self._sequence_from_input(
|
||||
cls, fieldpath, anntype, value, origin, ioattrs
|
||||
)
|
||||
|
||||
if origin is tuple:
|
||||
return self._tuple_from_input(
|
||||
cls, fieldpath, anntype, value, ioattrs
|
||||
)
|
||||
|
||||
if origin is dict:
|
||||
return self._dict_from_input(
|
||||
cls, fieldpath, anntype, value, ioattrs
|
||||
)
|
||||
|
||||
if dataclasses.is_dataclass(origin):
|
||||
return self._dataclass_from_input(origin, fieldpath, value)
|
||||
|
||||
if issubclass(origin, Enum):
|
||||
return enum_by_value(origin, value)
|
||||
|
||||
if issubclass(origin, datetime.datetime):
|
||||
return self._datetime_from_input(cls, fieldpath, value, ioattrs)
|
||||
|
||||
if origin is bytes:
|
||||
return self._bytes_from_input(origin, fieldpath, value)
|
||||
|
||||
raise TypeError(
|
||||
f"Field '{fieldpath}' of type '{anntype}' is unsupported here."
|
||||
)
|
||||
|
||||
def _bytes_from_input(self, cls: type, fieldpath: str, value: Any) -> bytes:
|
||||
"""Given input data, returns bytes."""
|
||||
import base64
|
||||
|
||||
# For firestore, bytes are passed as-is. Otherwise, they're encoded
|
||||
# as base64.
|
||||
if self._codec is Codec.FIRESTORE:
|
||||
if not isinstance(value, bytes):
|
||||
raise TypeError(
|
||||
f'Expected a bytes object for {fieldpath}'
|
||||
f' on {cls.__name__}; got a {type(value)}.'
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
assert self._codec is Codec.JSON
|
||||
if not isinstance(value, str):
|
||||
raise TypeError(
|
||||
f'Expected a string object for {fieldpath}'
|
||||
f' on {cls.__name__}; got a {type(value)}.'
|
||||
)
|
||||
return base64.b64decode(value)
|
||||
|
||||
def _dataclass_from_input(
|
||||
self, cls: type, fieldpath: str, values: dict
|
||||
) -> Any:
|
||||
"""Given a dict, instantiates a dataclass of the given type.
|
||||
|
||||
The dict must be in the json-friendly format as emitted from
|
||||
dataclass_to_dict. This means that sequence values such as tuples or
|
||||
sets should be passed as lists, enums should be passed as their
|
||||
associated values, and nested dataclasses should be passed as dicts.
|
||||
"""
|
||||
# pylint: disable=too-many-locals
|
||||
# pylint: disable=too-many-branches
|
||||
if not isinstance(values, dict):
|
||||
raise TypeError(
|
||||
f'Expected a dict for {fieldpath} on {cls.__name__};'
|
||||
f' got a {type(values)}.'
|
||||
)
|
||||
|
||||
prep = PrepSession(explicit=False).prep_dataclass(
|
||||
cls, recursion_level=0
|
||||
)
|
||||
assert prep is not None
|
||||
|
||||
extra_attrs = {}
|
||||
|
||||
# noinspection PyDataclass
|
||||
fields = dataclasses.fields(cls)
|
||||
fields_by_name = {f.name: f for f in fields}
|
||||
|
||||
# Preprocess all fields to convert Annotated[] to contained types
|
||||
# and IOAttrs.
|
||||
parsed_field_annotations = {
|
||||
f.name: _parse_annotated(prep.annotations[f.name]) for f in fields
|
||||
}
|
||||
|
||||
# Go through all data in the input, converting it to either dataclass
|
||||
# args or extra data.
|
||||
args: dict[str, Any] = {}
|
||||
for rawkey, value in values.items():
|
||||
key = prep.storage_names_to_attr_names.get(rawkey, rawkey)
|
||||
field = fields_by_name.get(key)
|
||||
|
||||
# Store unknown attrs off to the side (or error if desired).
|
||||
if field is None:
|
||||
if self._allow_unknown_attrs:
|
||||
if self._discard_unknown_attrs:
|
||||
continue
|
||||
|
||||
# Treat this like 'Any' data; ensure that it is valid
|
||||
# raw json.
|
||||
if not _is_valid_for_codec(value, self._codec):
|
||||
raise TypeError(
|
||||
f'Unknown attr \'{key}\''
|
||||
f' on {fieldpath} contains data type(s)'
|
||||
f' not supported by the specified codec'
|
||||
f' ({self._codec.name}).'
|
||||
)
|
||||
extra_attrs[key] = value
|
||||
else:
|
||||
raise AttributeError(
|
||||
f"'{cls.__name__}' has no '{key}' field."
|
||||
)
|
||||
else:
|
||||
fieldname = field.name
|
||||
anntype, ioattrs = parsed_field_annotations[fieldname]
|
||||
subfieldpath = (
|
||||
f'{fieldpath}.{fieldname}' if fieldpath else fieldname
|
||||
)
|
||||
args[key] = self._value_from_input(
|
||||
cls, subfieldpath, anntype, value, ioattrs
|
||||
)
|
||||
|
||||
# Go through all fields looking for any not yet present in our data.
|
||||
# If we find any such fields with a soft-default value or factory
|
||||
# defined, inject that soft value into our args.
|
||||
for key, aparsed in parsed_field_annotations.items():
|
||||
if key in args:
|
||||
continue
|
||||
ioattrs = aparsed[1]
|
||||
if ioattrs is not None and (
|
||||
ioattrs.soft_default is not ioattrs.MISSING
|
||||
or ioattrs.soft_default_factory is not ioattrs.MISSING
|
||||
):
|
||||
if ioattrs.soft_default is not ioattrs.MISSING:
|
||||
soft_default = ioattrs.soft_default
|
||||
else:
|
||||
assert callable(ioattrs.soft_default_factory)
|
||||
soft_default = ioattrs.soft_default_factory()
|
||||
args[key] = soft_default
|
||||
|
||||
# Make sure these values are valid since we didn't run
|
||||
# them through our normal input type checking.
|
||||
|
||||
self._type_check_soft_default(
|
||||
value=soft_default,
|
||||
anntype=aparsed[0],
|
||||
fieldpath=(f'{fieldpath}.{key}' if fieldpath else key),
|
||||
)
|
||||
|
||||
try:
|
||||
out = cls(**args)
|
||||
except Exception as exc:
|
||||
raise ValueError(
|
||||
f'Error instantiating class {cls.__name__}'
|
||||
f' at {fieldpath}: {exc}'
|
||||
) from exc
|
||||
if extra_attrs:
|
||||
setattr(out, EXTRA_ATTRS_ATTR, extra_attrs)
|
||||
return out
|
||||
|
||||
def _type_check_soft_default(
|
||||
self, value: Any, anntype: Any, fieldpath: str
|
||||
) -> None:
|
||||
from efro.dataclassio._outputter import _Outputter
|
||||
|
||||
# Counter-intuitively, we create an outputter as part of
|
||||
# our inputter. Soft-default values are already internal types;
|
||||
# we need to make sure they can go out from there.
|
||||
if self._soft_default_validator is None:
|
||||
self._soft_default_validator = _Outputter(
|
||||
obj=None,
|
||||
create=False,
|
||||
codec=self._codec,
|
||||
coerce_to_float=self._coerce_to_float,
|
||||
)
|
||||
self._soft_default_validator.soft_default_check(
|
||||
value=value, anntype=anntype, fieldpath=fieldpath
|
||||
)
|
||||
|
||||
def _dict_from_input(
|
||||
self,
|
||||
cls: type,
|
||||
fieldpath: str,
|
||||
anntype: Any,
|
||||
value: Any,
|
||||
ioattrs: IOAttrs | None,
|
||||
) -> Any:
|
||||
# pylint: disable=too-many-branches
|
||||
# pylint: disable=too-many-locals
|
||||
|
||||
if not isinstance(value, dict):
|
||||
raise TypeError(
|
||||
f'Expected a dict for \'{fieldpath}\' on {cls.__name__};'
|
||||
f' got a {type(value)}.'
|
||||
)
|
||||
|
||||
childtypes = typing.get_args(anntype)
|
||||
assert len(childtypes) in (0, 2)
|
||||
|
||||
out: dict
|
||||
|
||||
# We treat 'Any' dicts simply as json; we don't do any translating.
|
||||
if not childtypes or childtypes[0] is typing.Any:
|
||||
if not isinstance(value, dict) or not _is_valid_for_codec(
|
||||
value, self._codec
|
||||
):
|
||||
raise TypeError(
|
||||
f'Got invalid value for Dict[Any, Any]'
|
||||
f' at \'{fieldpath}\' on {cls.__name__};'
|
||||
f' all keys and values must be'
|
||||
f' compatible with the specified codec'
|
||||
f' ({self._codec.name}).'
|
||||
)
|
||||
out = value
|
||||
else:
|
||||
out = {}
|
||||
keyanntype, valanntype = childtypes
|
||||
|
||||
# Ok; we've got definite key/value types (which we verified as
|
||||
# valid during prep). Run all keys/values through it.
|
||||
|
||||
# str keys we just take directly since that's supported by json.
|
||||
if keyanntype is str:
|
||||
for key, val in value.items():
|
||||
if not isinstance(key, str):
|
||||
raise TypeError(
|
||||
f'Got invalid key type {type(key)} for'
|
||||
f' dict key at \'{fieldpath}\' on {cls.__name__};'
|
||||
f' expected a str.'
|
||||
)
|
||||
out[key] = self._value_from_input(
|
||||
cls, fieldpath, valanntype, val, ioattrs
|
||||
)
|
||||
|
||||
# int keys are stored in json as str versions of themselves.
|
||||
elif keyanntype is int:
|
||||
for key, val in value.items():
|
||||
if not isinstance(key, str):
|
||||
raise TypeError(
|
||||
f'Got invalid key type {type(key)} for'
|
||||
f' dict key at \'{fieldpath}\' on {cls.__name__};'
|
||||
f' expected a str.'
|
||||
)
|
||||
try:
|
||||
keyint = int(key)
|
||||
except ValueError as exc:
|
||||
raise TypeError(
|
||||
f'Got invalid key value {key} for'
|
||||
f' dict key at \'{fieldpath}\' on {cls.__name__};'
|
||||
f' expected an int in string form.'
|
||||
) from exc
|
||||
out[keyint] = self._value_from_input(
|
||||
cls, fieldpath, valanntype, val, ioattrs
|
||||
)
|
||||
|
||||
elif issubclass(keyanntype, Enum):
|
||||
# In prep, we verified that all these enums' values have
|
||||
# the same type, so we can just look at the first to see if
|
||||
# this is a string enum or an int enum.
|
||||
enumvaltype = type(next(iter(keyanntype)).value)
|
||||
assert enumvaltype in (int, str)
|
||||
if enumvaltype is str:
|
||||
for key, val in value.items():
|
||||
try:
|
||||
enumval = enum_by_value(keyanntype, key)
|
||||
except ValueError as exc:
|
||||
raise ValueError(
|
||||
f'Got invalid key value {key} for'
|
||||
f' dict key at \'{fieldpath}\''
|
||||
f' on {cls.__name__};'
|
||||
f' expected a value corresponding to'
|
||||
f' a {keyanntype}.'
|
||||
) from exc
|
||||
out[enumval] = self._value_from_input(
|
||||
cls, fieldpath, valanntype, val, ioattrs
|
||||
)
|
||||
else:
|
||||
for key, val in value.items():
|
||||
try:
|
||||
enumval = enum_by_value(keyanntype, int(key))
|
||||
except (ValueError, TypeError) as exc:
|
||||
raise ValueError(
|
||||
f'Got invalid key value {key} for'
|
||||
f' dict key at \'{fieldpath}\''
|
||||
f' on {cls.__name__};'
|
||||
f' expected {keyanntype} value (though'
|
||||
f' in string form).'
|
||||
) from exc
|
||||
out[enumval] = self._value_from_input(
|
||||
cls, fieldpath, valanntype, val, ioattrs
|
||||
)
|
||||
|
||||
else:
|
||||
raise RuntimeError(f'Unhandled dict in-key-type {keyanntype}')
|
||||
|
||||
return out
|
||||
|
||||
def _sequence_from_input(
|
||||
self,
|
||||
cls: type,
|
||||
fieldpath: str,
|
||||
anntype: Any,
|
||||
value: Any,
|
||||
seqtype: type,
|
||||
ioattrs: IOAttrs | None,
|
||||
) -> Any:
|
||||
|
||||
# Because we are json-centric, we expect a list for all sequences.
|
||||
if type(value) is not list:
|
||||
raise TypeError(
|
||||
f'Invalid input value for "{fieldpath}";'
|
||||
f' expected a list, got a {type(value).__name__}'
|
||||
)
|
||||
|
||||
childanntypes = typing.get_args(anntype)
|
||||
|
||||
# 'Any' type children; make sure they are valid json values
|
||||
# and then just grab them.
|
||||
if len(childanntypes) == 0 or childanntypes[0] is typing.Any:
|
||||
for i, child in enumerate(value):
|
||||
if not _is_valid_for_codec(child, self._codec):
|
||||
raise TypeError(
|
||||
f'Item {i} of {fieldpath} contains'
|
||||
f' data type(s) not supported by json.'
|
||||
)
|
||||
return value if type(value) is seqtype else seqtype(value)
|
||||
|
||||
# We contain elements of some specified type.
|
||||
assert len(childanntypes) == 1
|
||||
childanntype = childanntypes[0]
|
||||
return seqtype(
|
||||
self._value_from_input(cls, fieldpath, childanntype, i, ioattrs)
|
||||
for i in value
|
||||
)
|
||||
|
||||
def _datetime_from_input(
|
||||
self, cls: type, fieldpath: str, value: Any, ioattrs: IOAttrs | None
|
||||
) -> Any:
|
||||
|
||||
# For firestore we expect a datetime object.
|
||||
if self._codec is Codec.FIRESTORE:
|
||||
# Don't compare exact type here, as firestore can give us
|
||||
# a subclass with extended precision.
|
||||
if not isinstance(value, datetime.datetime):
|
||||
raise TypeError(
|
||||
f'Invalid input value for "{fieldpath}" on'
|
||||
f' "{cls.__name__}";'
|
||||
f' expected a datetime, got a {type(value).__name__}'
|
||||
)
|
||||
check_utc(value)
|
||||
return value
|
||||
|
||||
assert self._codec is Codec.JSON
|
||||
|
||||
# We expect a list of 7 ints.
|
||||
if type(value) is not list:
|
||||
raise TypeError(
|
||||
f'Invalid input value for "{fieldpath}" on "{cls.__name__}";'
|
||||
f' expected a list, got a {type(value).__name__}'
|
||||
)
|
||||
if len(value) != 7 or not all(isinstance(x, int) for x in value):
|
||||
raise ValueError(
|
||||
f'Invalid input value for "{fieldpath}" on "{cls.__name__}";'
|
||||
f' expected a list of 7 ints, got {[type(v) for v in value]}.'
|
||||
)
|
||||
out = datetime.datetime( # type: ignore
|
||||
*value, tzinfo=datetime.timezone.utc
|
||||
)
|
||||
if ioattrs is not None:
|
||||
ioattrs.validate_datetime(out, fieldpath)
|
||||
return out
|
||||
|
||||
def _tuple_from_input(
|
||||
self,
|
||||
cls: type,
|
||||
fieldpath: str,
|
||||
anntype: Any,
|
||||
value: Any,
|
||||
ioattrs: IOAttrs | None,
|
||||
) -> Any:
|
||||
|
||||
out: list = []
|
||||
|
||||
# Because we are json-centric, we expect a list for all sequences.
|
||||
if type(value) is not list:
|
||||
raise TypeError(
|
||||
f'Invalid input value for "{fieldpath}";'
|
||||
f' expected a list, got a {type(value).__name__}'
|
||||
)
|
||||
|
||||
childanntypes = typing.get_args(anntype)
|
||||
|
||||
# We should have verified this to be non-zero at prep-time.
|
||||
assert childanntypes
|
||||
|
||||
if len(value) != len(childanntypes):
|
||||
raise ValueError(
|
||||
f'Invalid tuple input for "{fieldpath}";'
|
||||
f' expected {len(childanntypes)} values,'
|
||||
f' found {len(value)}.'
|
||||
)
|
||||
|
||||
for i, childanntype in enumerate(childanntypes):
|
||||
childval = value[i]
|
||||
|
||||
# 'Any' type children; make sure they are valid json values
|
||||
# and then just grab them.
|
||||
if childanntype is typing.Any:
|
||||
if not _is_valid_for_codec(childval, self._codec):
|
||||
raise TypeError(
|
||||
f'Item {i} of {fieldpath} contains'
|
||||
f' data type(s) not supported by json.'
|
||||
)
|
||||
out.append(childval)
|
||||
else:
|
||||
out.append(
|
||||
self._value_from_input(
|
||||
cls, fieldpath, childanntype, childval, ioattrs
|
||||
)
|
||||
)
|
||||
|
||||
assert len(out) == len(childanntypes)
|
||||
return tuple(out)
|
||||
457
dist/ba_data/python/efro/dataclassio/_outputter.py
vendored
Normal file
457
dist/ba_data/python/efro/dataclassio/_outputter.py
vendored
Normal file
|
|
@ -0,0 +1,457 @@
|
|||
# Released under the MIT License. See LICENSE for details.
|
||||
#
|
||||
"""Functionality for dataclassio related to exporting data from dataclasses."""
|
||||
|
||||
# Note: We do lots of comparing of exact types here which is normally
|
||||
# frowned upon (stuff like isinstance() is usually encouraged).
|
||||
# pylint: disable=unidiomatic-typecheck
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
import dataclasses
|
||||
import typing
|
||||
import types
|
||||
import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from efro.util import check_utc
|
||||
from efro.dataclassio._base import (
|
||||
Codec,
|
||||
_parse_annotated,
|
||||
EXTRA_ATTRS_ATTR,
|
||||
_is_valid_for_codec,
|
||||
_get_origin,
|
||||
SIMPLE_TYPES,
|
||||
_raise_type_error,
|
||||
IOExtendedData,
|
||||
)
|
||||
from efro.dataclassio._prep import PrepSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
from efro.dataclassio._base import IOAttrs
|
||||
|
||||
|
||||
class _Outputter:
|
||||
"""Validates or exports data contained in a dataclass instance."""
|
||||
|
||||
def __init__(
|
||||
self, obj: Any, create: bool, codec: Codec, coerce_to_float: bool
|
||||
) -> None:
|
||||
self._obj = obj
|
||||
self._create = create
|
||||
self._codec = codec
|
||||
self._coerce_to_float = coerce_to_float
|
||||
|
||||
def run(self) -> Any:
|
||||
"""Do the thing."""
|
||||
|
||||
assert dataclasses.is_dataclass(self._obj)
|
||||
|
||||
# For special extended data types, call their 'will_output' callback.
|
||||
if isinstance(self._obj, IOExtendedData):
|
||||
self._obj.will_output()
|
||||
|
||||
return self._process_dataclass(type(self._obj), self._obj, '')
|
||||
|
||||
def soft_default_check(
|
||||
self, value: Any, anntype: Any, fieldpath: str
|
||||
) -> None:
|
||||
"""(internal)"""
|
||||
self._process_value(
|
||||
type(value),
|
||||
fieldpath=fieldpath,
|
||||
anntype=anntype,
|
||||
value=value,
|
||||
ioattrs=None,
|
||||
)
|
||||
|
||||
def _process_dataclass(self, cls: type, obj: Any, fieldpath: str) -> Any:
|
||||
# pylint: disable=too-many-locals
|
||||
# pylint: disable=too-many-branches
|
||||
prep = PrepSession(explicit=False).prep_dataclass(
|
||||
type(obj), recursion_level=0
|
||||
)
|
||||
assert prep is not None
|
||||
fields = dataclasses.fields(obj)
|
||||
out: dict[str, Any] | None = {} if self._create else None
|
||||
for field in fields:
|
||||
fieldname = field.name
|
||||
if fieldpath:
|
||||
subfieldpath = f'{fieldpath}.{fieldname}'
|
||||
else:
|
||||
subfieldpath = fieldname
|
||||
anntype = prep.annotations[fieldname]
|
||||
value = getattr(obj, fieldname)
|
||||
|
||||
anntype, ioattrs = _parse_annotated(anntype)
|
||||
|
||||
# If we're not storing default values for this fella,
|
||||
# we can skip all output processing if we've got a default value.
|
||||
if ioattrs is not None and not ioattrs.store_default:
|
||||
# If both soft_defaults and regular field defaults
|
||||
# are present we want to go with soft_defaults since
|
||||
# those same values would be re-injected when reading
|
||||
# the same data back in if we've omitted the field.
|
||||
default_factory: Any = field.default_factory
|
||||
if ioattrs.soft_default is not ioattrs.MISSING:
|
||||
if ioattrs.soft_default == value:
|
||||
continue
|
||||
elif ioattrs.soft_default_factory is not ioattrs.MISSING:
|
||||
assert callable(ioattrs.soft_default_factory)
|
||||
if ioattrs.soft_default_factory() == value:
|
||||
continue
|
||||
elif field.default is not dataclasses.MISSING:
|
||||
if field.default == value:
|
||||
continue
|
||||
elif default_factory is not dataclasses.MISSING:
|
||||
if default_factory() == value:
|
||||
continue
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f'Field {fieldname} of {cls.__name__} has'
|
||||
f' no source of default values; store_default=False'
|
||||
f' cannot be set for it. (AND THIS SHOULD HAVE BEEN'
|
||||
f' CAUGHT IN PREP!)'
|
||||
)
|
||||
|
||||
outvalue = self._process_value(
|
||||
cls, subfieldpath, anntype, value, ioattrs
|
||||
)
|
||||
if self._create:
|
||||
assert out is not None
|
||||
storagename = (
|
||||
fieldname
|
||||
if (ioattrs is None or ioattrs.storagename is None)
|
||||
else ioattrs.storagename
|
||||
)
|
||||
out[storagename] = outvalue
|
||||
|
||||
# If there's extra-attrs stored on us, check/include them.
|
||||
extra_attrs = getattr(obj, EXTRA_ATTRS_ATTR, None)
|
||||
if isinstance(extra_attrs, dict):
|
||||
if not _is_valid_for_codec(extra_attrs, self._codec):
|
||||
raise TypeError(
|
||||
f'Extra attrs on \'{fieldpath}\' contains data type(s)'
|
||||
f' not supported by \'{self._codec.value}\' codec:'
|
||||
f' {extra_attrs}.'
|
||||
)
|
||||
if self._create:
|
||||
assert out is not None
|
||||
out.update(extra_attrs)
|
||||
return out
|
||||
|
||||
def _process_value(
|
||||
self,
|
||||
cls: type,
|
||||
fieldpath: str,
|
||||
anntype: Any,
|
||||
value: Any,
|
||||
ioattrs: IOAttrs | None,
|
||||
) -> Any:
|
||||
# pylint: disable=too-many-return-statements
|
||||
# pylint: disable=too-many-branches
|
||||
# pylint: disable=too-many-statements
|
||||
|
||||
origin = _get_origin(anntype)
|
||||
|
||||
if origin is typing.Any:
|
||||
if not _is_valid_for_codec(value, self._codec):
|
||||
raise TypeError(
|
||||
f'Invalid value type for \'{fieldpath}\';'
|
||||
f" 'Any' typed values must contain types directly"
|
||||
f' supported by the specified codec ({self._codec.name});'
|
||||
f' found \'{type(value).__name__}\' which is not.'
|
||||
)
|
||||
return value if self._create else None
|
||||
|
||||
if origin is typing.Union or origin is types.UnionType:
|
||||
# Currently, the only unions we support are None/Value
|
||||
# (translated from Optional), which we verified on prep.
|
||||
# So let's treat this as a simple optional case.
|
||||
if value is None:
|
||||
return None
|
||||
childanntypes_l = [
|
||||
c for c in typing.get_args(anntype) if c is not type(None)
|
||||
] # noqa (pycodestyle complains about *is* with type)
|
||||
assert len(childanntypes_l) == 1
|
||||
return self._process_value(
|
||||
cls, fieldpath, childanntypes_l[0], value, ioattrs
|
||||
)
|
||||
|
||||
# Everything below this point assumes the annotation type resolves
|
||||
# to a concrete type. (This should have been verified at prep time).
|
||||
assert isinstance(origin, type)
|
||||
|
||||
# For simple flat types, look for exact matches:
|
||||
if origin in SIMPLE_TYPES:
|
||||
if type(value) is not origin:
|
||||
# Special case: if they want to coerce ints to floats, do so.
|
||||
if (
|
||||
self._coerce_to_float
|
||||
and origin is float
|
||||
and type(value) is int
|
||||
):
|
||||
return float(value) if self._create else None
|
||||
_raise_type_error(fieldpath, type(value), (origin,))
|
||||
return value if self._create else None
|
||||
|
||||
if origin is tuple:
|
||||
if not isinstance(value, tuple):
|
||||
raise TypeError(
|
||||
f'Expected a tuple for {fieldpath};'
|
||||
f' found a {type(value)}'
|
||||
)
|
||||
childanntypes = typing.get_args(anntype)
|
||||
|
||||
# We should have verified this was non-zero at prep-time
|
||||
assert childanntypes
|
||||
if len(value) != len(childanntypes):
|
||||
raise TypeError(
|
||||
f'Tuple at {fieldpath} contains'
|
||||
f' {len(value)} values; type specifies'
|
||||
f' {len(childanntypes)}.'
|
||||
)
|
||||
if self._create:
|
||||
return [
|
||||
self._process_value(
|
||||
cls, fieldpath, childanntypes[i], x, ioattrs
|
||||
)
|
||||
for i, x in enumerate(value)
|
||||
]
|
||||
for i, x in enumerate(value):
|
||||
self._process_value(
|
||||
cls, fieldpath, childanntypes[i], x, ioattrs
|
||||
)
|
||||
return None
|
||||
|
||||
if origin is list:
|
||||
if not isinstance(value, list):
|
||||
raise TypeError(
|
||||
f'Expected a list for {fieldpath};'
|
||||
f' found a {type(value)}'
|
||||
)
|
||||
childanntypes = typing.get_args(anntype)
|
||||
|
||||
# 'Any' type children; make sure they are valid values for
|
||||
# the specified codec.
|
||||
if len(childanntypes) == 0 or childanntypes[0] is typing.Any:
|
||||
for i, child in enumerate(value):
|
||||
if not _is_valid_for_codec(child, self._codec):
|
||||
raise TypeError(
|
||||
f'Item {i} of {fieldpath} contains'
|
||||
f' data type(s) not supported by the specified'
|
||||
f' codec ({self._codec.name}).'
|
||||
)
|
||||
# Hmm; should we do a copy here?
|
||||
return value if self._create else None
|
||||
|
||||
# We contain elements of some specified type.
|
||||
assert len(childanntypes) == 1
|
||||
if self._create:
|
||||
return [
|
||||
self._process_value(
|
||||
cls, fieldpath, childanntypes[0], x, ioattrs
|
||||
)
|
||||
for x in value
|
||||
]
|
||||
for x in value:
|
||||
self._process_value(
|
||||
cls, fieldpath, childanntypes[0], x, ioattrs
|
||||
)
|
||||
return None
|
||||
|
||||
if origin is set:
|
||||
if not isinstance(value, set):
|
||||
raise TypeError(
|
||||
f'Expected a set for {fieldpath};' f' found a {type(value)}'
|
||||
)
|
||||
childanntypes = typing.get_args(anntype)
|
||||
|
||||
# 'Any' type children; make sure they are valid Any values.
|
||||
if len(childanntypes) == 0 or childanntypes[0] is typing.Any:
|
||||
for child in value:
|
||||
if not _is_valid_for_codec(child, self._codec):
|
||||
raise TypeError(
|
||||
f'Set at {fieldpath} contains'
|
||||
f' data type(s) not supported by the'
|
||||
f' specified codec ({self._codec.name}).'
|
||||
)
|
||||
return list(value) if self._create else None
|
||||
|
||||
# We contain elements of some specified type.
|
||||
assert len(childanntypes) == 1
|
||||
if self._create:
|
||||
# Note: we output json-friendly values so this becomes
|
||||
# a list.
|
||||
return [
|
||||
self._process_value(
|
||||
cls, fieldpath, childanntypes[0], x, ioattrs
|
||||
)
|
||||
for x in value
|
||||
]
|
||||
for x in value:
|
||||
self._process_value(
|
||||
cls, fieldpath, childanntypes[0], x, ioattrs
|
||||
)
|
||||
return None
|
||||
|
||||
if origin is dict:
|
||||
return self._process_dict(cls, fieldpath, anntype, value, ioattrs)
|
||||
|
||||
if dataclasses.is_dataclass(origin):
|
||||
if not isinstance(value, origin):
|
||||
raise TypeError(
|
||||
f'Expected a {origin} for {fieldpath};'
|
||||
f' found a {type(value)}.'
|
||||
)
|
||||
return self._process_dataclass(cls, value, fieldpath)
|
||||
|
||||
if issubclass(origin, Enum):
|
||||
if not isinstance(value, origin):
|
||||
raise TypeError(
|
||||
f'Expected a {origin} for {fieldpath};'
|
||||
f' found a {type(value)}.'
|
||||
)
|
||||
# At prep-time we verified that these enums had valid value
|
||||
# types, so we can blindly return it here.
|
||||
return value.value if self._create else None
|
||||
|
||||
if issubclass(origin, datetime.datetime):
|
||||
if not isinstance(value, origin):
|
||||
raise TypeError(
|
||||
f'Expected a {origin} for {fieldpath};'
|
||||
f' found a {type(value)}.'
|
||||
)
|
||||
check_utc(value)
|
||||
if ioattrs is not None:
|
||||
ioattrs.validate_datetime(value, fieldpath)
|
||||
if self._codec is Codec.FIRESTORE:
|
||||
return value
|
||||
assert self._codec is Codec.JSON
|
||||
return (
|
||||
[
|
||||
value.year,
|
||||
value.month,
|
||||
value.day,
|
||||
value.hour,
|
||||
value.minute,
|
||||
value.second,
|
||||
value.microsecond,
|
||||
]
|
||||
if self._create
|
||||
else None
|
||||
)
|
||||
|
||||
if origin is bytes:
|
||||
return self._process_bytes(cls, fieldpath, value)
|
||||
|
||||
raise TypeError(
|
||||
f"Field '{fieldpath}' of type '{anntype}' is unsupported here."
|
||||
)
|
||||
|
||||
def _process_bytes(self, cls: type, fieldpath: str, value: bytes) -> Any:
|
||||
import base64
|
||||
|
||||
if not isinstance(value, bytes):
|
||||
raise TypeError(
|
||||
f'Expected bytes for {fieldpath} on {cls.__name__};'
|
||||
f' found a {type(value)}.'
|
||||
)
|
||||
|
||||
if not self._create:
|
||||
return None
|
||||
|
||||
# In JSON we convert to base64, but firestore directly supports bytes.
|
||||
if self._codec is Codec.JSON:
|
||||
return base64.b64encode(value).decode()
|
||||
|
||||
assert self._codec is Codec.FIRESTORE
|
||||
return value
|
||||
|
||||
def _process_dict(
|
||||
self,
|
||||
cls: type,
|
||||
fieldpath: str,
|
||||
anntype: Any,
|
||||
value: dict,
|
||||
ioattrs: IOAttrs | None,
|
||||
) -> Any:
|
||||
# pylint: disable=too-many-branches
|
||||
if not isinstance(value, dict):
|
||||
raise TypeError(
|
||||
f'Expected a dict for {fieldpath};' f' found a {type(value)}.'
|
||||
)
|
||||
childtypes = typing.get_args(anntype)
|
||||
assert len(childtypes) in (0, 2)
|
||||
|
||||
# We treat 'Any' dicts simply as json; we don't do any translating.
|
||||
if not childtypes or childtypes[0] is typing.Any:
|
||||
if not isinstance(value, dict) or not _is_valid_for_codec(
|
||||
value, self._codec
|
||||
):
|
||||
raise TypeError(
|
||||
f'Invalid value for Dict[Any, Any]'
|
||||
f' at \'{fieldpath}\' on {cls.__name__};'
|
||||
f' all keys and values must be directly compatible'
|
||||
f' with the specified codec ({self._codec.name})'
|
||||
f' when dict type is Any.'
|
||||
)
|
||||
return value if self._create else None
|
||||
|
||||
# Ok; we've got a definite key type (which we verified as valid
|
||||
# during prep). Make sure all keys match it.
|
||||
out: dict | None = {} if self._create else None
|
||||
keyanntype, valanntype = childtypes
|
||||
|
||||
# str keys we just export directly since that's supported by json.
|
||||
if keyanntype is str:
|
||||
for key, val in value.items():
|
||||
if not isinstance(key, str):
|
||||
raise TypeError(
|
||||
f'Got invalid key type {type(key)} for'
|
||||
f' dict key at \'{fieldpath}\' on {cls.__name__};'
|
||||
f' expected {keyanntype}.'
|
||||
)
|
||||
outval = self._process_value(
|
||||
cls, fieldpath, valanntype, val, ioattrs
|
||||
)
|
||||
if self._create:
|
||||
assert out is not None
|
||||
out[key] = outval
|
||||
|
||||
# int keys are stored as str versions of themselves.
|
||||
elif keyanntype is int:
|
||||
for key, val in value.items():
|
||||
if not isinstance(key, int):
|
||||
raise TypeError(
|
||||
f'Got invalid key type {type(key)} for'
|
||||
f' dict key at \'{fieldpath}\' on {cls.__name__};'
|
||||
f' expected an int.'
|
||||
)
|
||||
outval = self._process_value(
|
||||
cls, fieldpath, valanntype, val, ioattrs
|
||||
)
|
||||
if self._create:
|
||||
assert out is not None
|
||||
out[str(key)] = outval
|
||||
|
||||
elif issubclass(keyanntype, Enum):
|
||||
for key, val in value.items():
|
||||
if not isinstance(key, keyanntype):
|
||||
raise TypeError(
|
||||
f'Got invalid key type {type(key)} for'
|
||||
f' dict key at \'{fieldpath}\' on {cls.__name__};'
|
||||
f' expected a {keyanntype}.'
|
||||
)
|
||||
outval = self._process_value(
|
||||
cls, fieldpath, valanntype, val, ioattrs
|
||||
)
|
||||
if self._create:
|
||||
assert out is not None
|
||||
out[str(key.value)] = outval
|
||||
else:
|
||||
raise RuntimeError(f'Unhandled dict out-key-type {keyanntype}')
|
||||
|
||||
return out
|
||||
115
dist/ba_data/python/efro/dataclassio/_pathcapture.py
vendored
Normal file
115
dist/ba_data/python/efro/dataclassio/_pathcapture.py
vendored
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
# Released under the MIT License. See LICENSE for details.
|
||||
#
|
||||
"""Functionality related to capturing nested dataclass paths."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from typing import TYPE_CHECKING, TypeVar, Generic
|
||||
|
||||
from efro.dataclassio._base import _parse_annotated, _get_origin
|
||||
from efro.dataclassio._prep import PrepSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Callable
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class _PathCapture:
|
||||
"""Utility for obtaining dataclass storage paths in a type safe way."""
|
||||
|
||||
def __init__(self, obj: Any, pathparts: list[str] | None = None):
|
||||
self._is_dataclass = dataclasses.is_dataclass(obj)
|
||||
if pathparts is None:
|
||||
pathparts = []
|
||||
self._cls = obj if isinstance(obj, type) else type(obj)
|
||||
self._pathparts = pathparts
|
||||
|
||||
def __getattr__(self, name: str) -> _PathCapture:
|
||||
|
||||
# We only allow diving into sub-objects if we are a dataclass.
|
||||
if not self._is_dataclass:
|
||||
raise TypeError(
|
||||
f"Field path cannot include attribute '{name}' "
|
||||
f'under parent {self._cls}; parent types must be dataclasses.'
|
||||
)
|
||||
|
||||
prep = PrepSession(explicit=False).prep_dataclass(
|
||||
self._cls, recursion_level=0
|
||||
)
|
||||
assert prep is not None
|
||||
try:
|
||||
anntype = prep.annotations[name]
|
||||
except KeyError as exc:
|
||||
raise AttributeError(f'{type(self)} has no {name} field.') from exc
|
||||
anntype, ioattrs = _parse_annotated(anntype)
|
||||
storagename = (
|
||||
name
|
||||
if (ioattrs is None or ioattrs.storagename is None)
|
||||
else ioattrs.storagename
|
||||
)
|
||||
origin = _get_origin(anntype)
|
||||
return _PathCapture(origin, pathparts=self._pathparts + [storagename])
|
||||
|
||||
@property
|
||||
def path(self) -> str:
|
||||
"""The final output path."""
|
||||
return '.'.join(self._pathparts)
|
||||
|
||||
|
||||
class DataclassFieldLookup(Generic[T]):
    """Get info about nested dataclass fields in a type-safe way."""

    def __init__(self, cls: type[T]) -> None:
        self.cls = cls

    def path(self, callback: Callable[[T], Any]) -> str:
        """Look up a path on child dataclass fields.

        example:
          DataclassFieldLookup(MyType).path(lambda obj: obj.foo.bar)

        The above example will return the string 'foo.bar' or something
        like 'f.b' if the dataclasses have custom storage names set.
        It will also be static-type-checked, triggering an error if
        MyType.foo.bar is not a valid path. Note, however, that the
        callback technically allows any return value, but only nested
        dataclasses and their fields will succeed.
        """

        # We tell the type system that we are returning an instance
        # of our class, which allows it to perform type checking on
        # member lookups. In reality, however, we are providing a
        # special object which captures path lookups, so we can build
        # a string from them.
        if not TYPE_CHECKING:
            out = callback(_PathCapture(self.cls))
            if not isinstance(out, _PathCapture):
                raise TypeError(
                    f'Expected a valid path under'
                    f' the provided object; got a {type(out)}.'
                )
            return out.path
        return ''

    def paths(self, callback: Callable[[T], list[Any]]) -> list[str]:
        """Look up multiple paths on child dataclass fields.

        Functionality is identical to path() but for multiple paths at once.

        example:
          DataclassFieldLookup(MyType).paths(lambda obj: [obj.foo, obj.bar])
        """
        outvals: list[str] = []
        if not TYPE_CHECKING:
            outs = callback(_PathCapture(self.cls))
            assert isinstance(outs, list)
            for out in outs:
                if not isinstance(out, _PathCapture):
                    raise TypeError(
                        f'Expected a valid path under'
                        f' the provided object; got a {type(out)}.'
                    )
                outvals.append(out.path)
        return outvals
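As a rough usage sketch of the lookup above (the Address/Profile classes and the 'cty' storage name are invented for illustration; only the imported names come from this package):

from dataclasses import dataclass, field
from typing import Annotated

from efro.dataclassio import ioprepped, IOAttrs, DataclassFieldLookup

@ioprepped
@dataclass
class Address:
    city: Annotated[str, IOAttrs('cty')] = ''

@ioprepped
@dataclass
class Profile:
    address: Address = field(default_factory=Address)

# Attr names are used by default; an IOAttrs storagename overrides them,
# so this prints 'address.cty' rather than 'address.city'.
print(DataclassFieldLookup(Profile).path(lambda obj: obj.address.city))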
459
dist/ba_data/python/efro/dataclassio/_prep.py
vendored
Normal file
@ -0,0 +1,459 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for prepping types for use with dataclassio."""

# Note: We do lots of comparing of exact types here which is normally
# frowned upon (stuff like isinstance() is usually encouraged).
# pylint: disable=unidiomatic-typecheck

from __future__ import annotations

import logging
from enum import Enum
import dataclasses
import typing
import types
import datetime
from typing import TYPE_CHECKING, TypeVar, get_type_hints

# noinspection PyProtectedMember
from efro.dataclassio._base import _parse_annotated, _get_origin, SIMPLE_TYPES

if TYPE_CHECKING:
    from typing import Any
    from efro.dataclassio._base import IOAttrs

T = TypeVar('T')

# How deep we go when prepping nested types
# (basically for detecting recursive types).
MAX_RECURSION = 10

# Attr name for data we store on dataclass types that have been prepped.
PREP_ATTR = '_DCIOPREP'

# We also store the prep-session while the prep is in progress
# (necessary to support recursive types).
PREP_SESSION_ATTR = '_DCIOPREPSESSION'

def ioprep(cls: type, globalns: dict | None = None) -> None:
    """Prep a dataclass type for use with this module's functionality.

    Prepping ensures that all types contained in a data class as well as
    the usage of said types are supported by this module and pre-builds
    necessary constructs needed for encoding/decoding/etc.

    Prepping will happen on-the-fly as needed, but a warning will be
    emitted in such cases, as it is better to explicitly prep all used types
    early in a process to ensure any invalid types or configuration are
    caught immediately.

    Prepping a dataclass involves evaluating its type annotations, which,
    as of PEP 563, are stored simply as strings. This evaluation is done
    with localns set to the class dict (so that types defined in the class
    can be used) and globalns set to the containing module's dict.
    It is possible to override globalns for special cases such as when
    prepping happens as part of an execed string instead of within a
    module.
    """
    PrepSession(explicit=True, globalns=globalns).prep_dataclass(
        cls, recursion_level=0
    )

def ioprepped(cls: type[T]) -> type[T]:
    """Class decorator for easily prepping a dataclass at definition time.

    Note that in some cases it may not be possible to prep a dataclass
    immediately (such as when its type annotations refer to forward-declared
    types). In these cases, ioprep() should be explicitly called for
    the class as soon as possible; ideally at module import time to expose
    any errors as early as possible in execution.
    """
    ioprep(cls)
    return cls


def will_ioprep(cls: type[T]) -> type[T]:
    """Class decorator hinting that we will prep a class later.

    In some cases (such as recursive types) we cannot use the @ioprepped
    decorator and must instead call ioprep() explicitly later. However,
    some of our custom pylint checking behaves differently when the
    @ioprepped decorator is present, in that case requiring type annotations
    to be present and not simply forward-declared under an "if TYPE_CHECKING"
    block (since they are used at runtime).

    The @will_ioprep decorator triggers the same pylint behavior
    differences as @ioprepped (which are necessary for the later ioprep()
    call to work correctly) but without actually running any prep itself.
    """
    return cls

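A minimal sketch of the deferred-prep pattern these docstrings describe (the Node class is invented for illustration; only the imported names come from this package):

from __future__ import annotations

from dataclasses import dataclass, field

from efro.dataclassio import ioprep, will_ioprep, is_ioprepped_dataclass

@will_ioprep
@dataclass
class Node:
    children: list[Node] = field(default_factory=list)

# Node refers to itself, so the forward reference cannot be resolved at
# definition time; prep explicitly once the name exists at module level.
ioprep(Node)
assert is_ioprepped_dataclass(Node)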
def is_ioprepped_dataclass(obj: Any) -> bool:
    """Return whether the obj is an ioprepped dataclass type or instance."""
    cls = obj if isinstance(obj, type) else type(obj)
    return dataclasses.is_dataclass(cls) and hasattr(cls, PREP_ATTR)


@dataclasses.dataclass
class PrepData:
    """Data we prepare and cache for a class during prep.

    This data is used as part of the encoding/decoding/validating process.
    """

    # Resolved annotation data with 'live' classes.
    annotations: dict[str, Any]

    # Map of storage names to attr names.
    storage_names_to_attr_names: dict[str, str]


class PrepSession:
    """Context for a prep."""

    def __init__(self, explicit: bool, globalns: dict | None = None):
        self.explicit = explicit
        self.globalns = globalns

    def prep_dataclass(
        self, cls: type, recursion_level: int
    ) -> PrepData | None:
        """Run prep on a dataclass if necessary and return its prep data.

        The only case where this will return None is for recursive types
        if the type is already being prepped higher in the call order.
        """
        # pylint: disable=too-many-locals
        # pylint: disable=too-many-branches

        # We should only need to do this once per dataclass.
        existing_data = getattr(cls, PREP_ATTR, None)
        if existing_data is not None:
            assert isinstance(existing_data, PrepData)
            return existing_data

        # Sanity check.
        # Note that we now support recursive types via the PREP_SESSION_ATTR,
        # so we theoretically shouldn't run into this.
        if recursion_level > MAX_RECURSION:
            raise RuntimeError('Max recursion exceeded.')

        # We should only be passed classes which are dataclasses.
        if not isinstance(cls, type) or not dataclasses.is_dataclass(cls):
            raise TypeError(f'Passed arg {cls} is not a dataclass type.')

        # Add a pointer to the prep-session while doing the prep.
        # This way we can ignore types that we're already in the process
        # of prepping and can support recursive types.
        existing_prep = getattr(cls, PREP_SESSION_ATTR, None)
        if existing_prep is not None:
            if existing_prep is self:
                return None
            # We shouldn't need to support failed preps
            # or preps from multiple threads at once.
            raise RuntimeError('Found existing in-progress prep.')
        setattr(cls, PREP_SESSION_ATTR, self)

        # Generate a warning on non-explicit preps; we prefer prep to
        # happen explicitly at runtime so errors can be detected early on.
        if not self.explicit:
            logging.warning(
                'efro.dataclassio: implicitly prepping dataclass: %s.'
                ' It is highly recommended to explicitly prep dataclasses'
                ' as soon as possible after definition (via'
                ' efro.dataclassio.ioprep() or the'
                ' @efro.dataclassio.ioprepped decorator).',
                cls,
            )

        try:
            # NOTE: Now passing the class' __dict__ (vars()) as locals
            # which allows us to pick up nested classes, etc.
            resolved_annotations = get_type_hints(
                cls,
                localns=vars(cls),
                globalns=self.globalns,
                include_extras=True,
            )
            # pylint: enable=unexpected-keyword-arg
        except Exception as exc:
            raise TypeError(
                f'dataclassio prep for {cls} failed with error: {exc}.'
                f' Make sure all types used in annotations are defined'
                f' at the module or class level or add them as part of an'
                f' explicit prep call.'
            ) from exc

        # noinspection PyDataclass
        fields = dataclasses.fields(cls)
        fields_by_name = {f.name: f for f in fields}

        all_storage_names: set[str] = set()
        storage_names_to_attr_names: dict[str, str] = {}

        # Ok; we've resolved actual types for this dataclass.
        # Now recurse through them, verifying that we support all contained
        # types and prepping any contained dataclass types.
        for attrname, anntype in resolved_annotations.items():

            anntype, ioattrs = _parse_annotated(anntype)

            # If we found attached IOAttrs data, make sure it contains
            # valid values for the field it is attached to.
            if ioattrs is not None:
                ioattrs.validate_for_field(cls, fields_by_name[attrname])
                if ioattrs.storagename is not None:
                    storagename = ioattrs.storagename
                    storage_names_to_attr_names[ioattrs.storagename] = attrname
                else:
                    storagename = attrname
            else:
                storagename = attrname

            # Make sure we don't have any clashes in our storage names.
            if storagename in all_storage_names:
                raise TypeError(
                    f'Multiple attrs on {cls} are using'
                    f' storage-name \'{storagename}\''
                )
            all_storage_names.add(storagename)

            self.prep_type(
                cls,
                attrname,
                anntype,
                ioattrs=ioattrs,
                recursion_level=recursion_level + 1,
            )

        # Success! Store our resolved stuff with the class and we're done.
        prepdata = PrepData(
            annotations=resolved_annotations,
            storage_names_to_attr_names=storage_names_to_attr_names,
        )
        setattr(cls, PREP_ATTR, prepdata)

        # Clear our prep-session tag.
        assert getattr(cls, PREP_SESSION_ATTR, None) is self
        delattr(cls, PREP_SESSION_ATTR)
        return prepdata

    def prep_type(
        self,
        cls: type,
        attrname: str,
        anntype: Any,
        ioattrs: IOAttrs | None,
        recursion_level: int,
    ) -> None:
        """Run prep on a dataclass."""
        # pylint: disable=too-many-return-statements
        # pylint: disable=too-many-branches
        # pylint: disable=too-many-statements

        if recursion_level > MAX_RECURSION:
            raise RuntimeError('Max recursion exceeded.')

        origin = _get_origin(anntype)

        if origin is typing.Union or origin is types.UnionType:
            self.prep_union(
                cls, attrname, anntype, recursion_level=recursion_level + 1
            )
            return

        if anntype is typing.Any:
            return

        # Everything below this point assumes the annotation type resolves
        # to a concrete type.
        if not isinstance(origin, type):
            raise TypeError(
                f'Unsupported type found for \'{attrname}\' on {cls}:'
                f' {anntype}'
            )

        # If a soft_default value/factory was passed, we do some basic
        # type checking on the top-level value here. We also run full
        # recursive validation on values later during inputting, but this
        # should catch at least some errors early on, which can be
        # useful since soft_defaults are not static type checked.
        if ioattrs is not None:
            have_soft_default = False
            soft_default: Any = None
            if ioattrs.soft_default is not ioattrs.MISSING:
                have_soft_default = True
                soft_default = ioattrs.soft_default
            elif ioattrs.soft_default_factory is not ioattrs.MISSING:
                assert callable(ioattrs.soft_default_factory)
                have_soft_default = True
                soft_default = ioattrs.soft_default_factory()

            # Do a simple type check for the top level to catch basic
            # soft_default mismatches early; full check will happen at
            # input time.
            if have_soft_default:
                if not isinstance(soft_default, origin):
                    raise TypeError(
                        f'{cls} attr {attrname} has type {origin}'
                        f' but soft_default value is type {type(soft_default)}'
                    )

        if origin in SIMPLE_TYPES:
            return

        # For sets and lists, check out their single contained type (if any).
        if origin in (list, set):
            childtypes = typing.get_args(anntype)
            if len(childtypes) == 0:
                # This is equivalent to Any; nothing else needs checking.
                return
            if len(childtypes) > 1:
                raise TypeError(
                    f'Unrecognized typing arg count {len(childtypes)}'
                    f" for {anntype} attr '{attrname}' on {cls}"
                )
            self.prep_type(
                cls,
                attrname,
                childtypes[0],
                ioattrs=None,
                recursion_level=recursion_level + 1,
            )
            return

        if origin is dict:
            childtypes = typing.get_args(anntype)
            assert len(childtypes) in (0, 2)

            # For key types we support Any, str, int,
            # and Enums with uniform str/int values.
            if not childtypes or childtypes[0] is typing.Any:
                # 'Any' needs no further checks (just checked per-instance).
                pass
            elif childtypes[0] in (str, int):
                # str and int are all good as keys.
                pass
            elif issubclass(childtypes[0], Enum):
                # Allow our usual str or int enum types as keys.
                self.prep_enum(childtypes[0])
            else:
                raise TypeError(
                    f'Dict key type {childtypes[0]} for \'{attrname}\''
                    f' on {cls.__name__} is not supported by dataclassio.'
                )

            # For value types we support any of our normal types.
            if not childtypes or _get_origin(childtypes[1]) is typing.Any:
                # 'Any' needs no further checks (just checked per-instance).
                pass
            else:
                self.prep_type(
                    cls,
                    attrname,
                    childtypes[1],
                    ioattrs=None,
                    recursion_level=recursion_level + 1,
                )
            return

        # For Tuples, simply check individual member types.
        # (and, for now, explicitly disallow zero member types or usage
        # of ellipsis)
        if origin is tuple:
            childtypes = typing.get_args(anntype)
            if not childtypes:
                raise TypeError(
                    f'Tuple at \'{attrname}\''
                    f' has no type args; dataclassio requires type args.'
                )
            if childtypes[-1] is ...:
                raise TypeError(
                    f'Found ellipsis as part of type for'
                    f' \'{attrname}\' on {cls.__name__};'
                    f' these are not'
                    f' supported by dataclassio.'
                )
            for childtype in childtypes:
                self.prep_type(
                    cls,
                    attrname,
                    childtype,
                    ioattrs=None,
                    recursion_level=recursion_level + 1,
                )
            return

        if issubclass(origin, Enum):
            self.prep_enum(origin)
            return

        # We allow datetime objects (and google's extended subclass of them
        # used in firestore, which is why we don't look for exact type here).
        if issubclass(origin, datetime.datetime):
            return

        if dataclasses.is_dataclass(origin):
            self.prep_dataclass(origin, recursion_level=recursion_level + 1)
            return

        if origin is bytes:
            return

        raise TypeError(
            f"Attr '{attrname}' on {cls.__name__} contains"
            f" type '{anntype}'"
            f' which is not supported by dataclassio.'
        )

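A brief sketch of the top-level soft_default pre-check performed above (the Stats class is invented for illustration; IOAttrs and its soft_default option come from this package):

from dataclasses import dataclass
from typing import Annotated

from efro.dataclassio import ioprepped, IOAttrs

@ioprepped
@dataclass
class Stats:
    # Accepted: the soft_default's type (int) matches the annotation.
    score: Annotated[int, IOAttrs(soft_default=0)]

# Using IOAttrs(soft_default='zero') with an int annotation would instead
# raise a TypeError during the @ioprepped prep, before any data is read.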
    def prep_union(
        self, cls: type, attrname: str, anntype: Any, recursion_level: int
    ) -> None:
        """Run prep on a Union type."""
        typeargs = typing.get_args(anntype)
        if (
            len(typeargs) != 2
            or len([c for c in typeargs if c is type(None)]) != 1
        ):  # noqa
            raise TypeError(
                f'Union {anntype} for attr \'{attrname}\' on'
                f' {cls.__name__} is not supported by dataclassio;'
                f' only 2 member Unions with one type being None'
                f' are supported.'
            )
        for childtype in typeargs:
            self.prep_type(
                cls,
                attrname,
                childtype,
                None,
                recursion_level=recursion_level + 1,
            )

    def prep_enum(self, enumtype: type[Enum]) -> None:
        """Run prep on an enum type."""

        valtype: Any = None

        # We currently support enums with str or int values; fail if we
        # find any others.
        for enumval in enumtype:
            if not isinstance(enumval.value, (str, int)):
                raise TypeError(
                    f'Enum value {enumval} has value type'
                    f' {type(enumval.value)}; only str and int are'
                    f' supported by dataclassio.'
                )
            if valtype is None:
                valtype = type(enumval.value)
            else:
                if type(enumval.value) is not valtype:
                    raise TypeError(
                        f'Enum type {enumtype} has multiple'
                        f' value types; dataclassio requires'
                        f' them to be uniform.'
                    )
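A short sketch of the Optional-only union rule enforced by prep_union() above (the Config class is invented for illustration):

from dataclasses import dataclass

from efro.dataclassio import ioprepped

@ioprepped
@dataclass
class Config:
    # Accepted: a two-member union where one side is None.
    port: int | None = None

# An annotation like 'mode: int | str' would instead raise a TypeError
# from prep_union() during the @ioprepped prep.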
71
dist/ba_data/python/efro/dataclassio/extras.py
vendored
Normal file
@ -0,0 +1,71 @@
# Released under the MIT License. See LICENSE for details.
#
"""Extra rarely-needed functionality related to dataclasses."""

from __future__ import annotations

import dataclasses
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing import Any


def dataclass_diff(obj1: Any, obj2: Any) -> str:
    """Generate a string showing differences between two dataclass instances.

    Both must be of the exact same type.
    """
    diff = _diff(obj1, obj2, 2)
    return ' <no differences>' if diff == '' else diff


class DataclassDiff:
    """Wraps dataclass_diff() in an object for efficiency.

    It is preferable to pass this to logging calls instead of the
    final diff string since the diff will never be generated if
    the associated logging level is not being emitted.
    """

    def __init__(self, obj1: Any, obj2: Any):
        self._obj1 = obj1
        self._obj2 = obj2

    def __repr__(self) -> str:
        return dataclass_diff(self._obj1, self._obj2)


def _diff(obj1: Any, obj2: Any, indent: int) -> str:
    assert dataclasses.is_dataclass(obj1)
    assert dataclasses.is_dataclass(obj2)
    if type(obj1) is not type(obj2):
        raise TypeError(
            f'Passed objects are not of the same'
            f' type ({type(obj1)} and {type(obj2)}).'
        )
    bits: list[str] = []
    indentstr = ' ' * indent
    fields = dataclasses.fields(obj1)
    for field in fields:
        fieldname = field.name
        val1 = getattr(obj1, fieldname)
        val2 = getattr(obj2, fieldname)

        # For nested dataclasses, dive in and do nice piecewise compares.
        if (
            dataclasses.is_dataclass(val1)
            and dataclasses.is_dataclass(val2)
            and type(val1) is type(val2)
        ):
            diff = _diff(val1, val2, indent + 2)
            if diff != '':
                bits.append(f'{indentstr}{fieldname}:')
                bits.append(diff)

        # For all else just do a single line
        # (perhaps we could improve on this for other complex types)
        else:
            if val1 != val2:
                bits.append(f'{indentstr}{fieldname}: {val1} -> {val2}')
    return '\n'.join(bits)
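A short usage sketch of the lazy-diff idea described in the DataclassDiff docstring (the Settings dataclass is invented for the example):

import logging
from dataclasses import dataclass

from efro.dataclassio.extras import DataclassDiff

@dataclass
class Settings:
    volume: float = 1.0
    name: str = 'server'

old = Settings()
new = Settings(volume=0.5)

# The diff string is only built if DEBUG output is actually emitted,
# since DataclassDiff defers the work to its __repr__.
logging.debug('settings changed:\n%s', DataclassDiff(old, new))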