Delete dist directory

Mikahael 2024-02-20 22:53:45 +05:30 committed by GitHub
parent 2e2c838750
commit 867634cc5c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
1779 changed files with 0 additions and 565850 deletions


@@ -1,50 +0,0 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for importing, exporting, and validating dataclasses.
This allows complex nested dataclasses to be flattened to json-compatible
data and restored from said data. It also gracefully handles and preserves
unrecognized attribute data, allowing older clients to interact with newer
data formats in a nondestructive manner.
"""
from __future__ import annotations
from efro.util import set_canonical_module
from efro.dataclassio._base import Codec, IOAttrs, IOExtendedData
from efro.dataclassio._prep import (
ioprep,
ioprepped,
will_ioprep,
is_ioprepped_dataclass,
)
from efro.dataclassio._pathcapture import DataclassFieldLookup
from efro.dataclassio._api import (
JsonStyle,
dataclass_to_dict,
dataclass_to_json,
dataclass_from_dict,
dataclass_from_json,
dataclass_validate,
)
__all__ = [
'JsonStyle',
'Codec',
'IOAttrs',
'IOExtendedData',
'ioprep',
'ioprepped',
'will_ioprep',
'is_ioprepped_dataclass',
'DataclassFieldLookup',
'dataclass_to_dict',
'dataclass_to_json',
'dataclass_from_dict',
'dataclass_from_json',
'dataclass_validate',
]
# Have these things present themselves cleanly as 'thismodule.SomeClass'
# instead of 'thismodule._internalmodule.SomeClass'
set_canonical_module(module_globals=globals(), names=__all__)
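# A hedged usage sketch of the public API exported above (the 'Config' class
# is hypothetical; assumes this package is importable as efro.dataclassio):
from dataclasses import dataclass, field
from efro.dataclassio import ioprepped, dataclass_to_json, dataclass_from_json

@ioprepped
@dataclass
class Config:
    name: str = 'default'
    port: int = 8080
    tags: list[str] = field(default_factory=list)

# Flatten to a json string and restore it; attrs unknown to this class would
# be preserved through the round trip rather than discarded.
json_str = dataclass_to_json(Config(name='server', tags=['a', 'b']))
restored = dataclass_from_json(Config, json_str)
assert restored == Config(name='server', tags=['a', 'b'])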


@@ -1,163 +0,0 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for importing, exporting, and validating dataclasses.
This allows complex nested dataclasses to be flattened to json-compatible
data and restored from said data. It also gracefully handles and preserves
unrecognized attribute data, allowing older clients to interact with newer
data formats in a nondestructive manner.
"""
from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING, TypeVar
from efro.dataclassio._outputter import _Outputter
from efro.dataclassio._inputter import _Inputter
from efro.dataclassio._base import Codec
if TYPE_CHECKING:
from typing import Any
T = TypeVar('T')
class JsonStyle(Enum):
"""Different style types for json."""
# Single line, no spaces, no sorting. Not deterministic.
# Use this for most storage purposes.
FAST = 'fast'
# Single line, no spaces, sorted keys. Deterministic.
# Use this when output may be hashed or compared for equality.
SORTED = 'sorted'
# Multiple lines, spaces, sorted keys. Deterministic.
# Use this for pretty human readable output.
PRETTY = 'pretty'
def dataclass_to_dict(
obj: Any, codec: Codec = Codec.JSON, coerce_to_float: bool = True
) -> dict:
"""Given a dataclass object, return a json-friendly dict.
All values will be checked to ensure they match the types specified
on fields. Note that a limited set of types and data configurations is
supported.
Values with type Any will be checked to ensure they match types supported
directly by json. This does not include types such as tuples which are
implicitly translated by Python's json module (as this would break
the ability to do a lossless round-trip with data).
If coerce_to_float is True, integer values present on float typed fields
will be converted to float in the dict output. If False, a TypeError
will be triggered.
"""
out = _Outputter(
obj, create=True, codec=codec, coerce_to_float=coerce_to_float
).run()
assert isinstance(out, dict)
return out
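# A hedged sketch of the coerce_to_float behavior described above (the
# 'Point' class is hypothetical):
from dataclasses import dataclass
from efro.dataclassio import ioprepped, dataclass_to_dict

@ioprepped
@dataclass
class Point:
    x: float
    y: float

# Ints passed for float-typed fields are coerced by default...
assert dataclass_to_dict(Point(x=1, y=2.5)) == {'x': 1.0, 'y': 2.5}
# ...and trigger a TypeError when coerce_to_float is False:
# dataclass_to_dict(Point(x=1, y=2.5), coerce_to_float=False)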
def dataclass_to_json(
obj: Any,
coerce_to_float: bool = True,
pretty: bool = False,
sort_keys: bool | None = None,
) -> str:
"""Utility function; return a json string from a dataclass instance.
Basically json.dumps(dataclass_to_dict(...)).
By default, keys are sorted for pretty output and not otherwise, but
this can be overridden by supplying a value for the 'sort_keys' arg.
"""
import json
jdict = dataclass_to_dict(
obj=obj, coerce_to_float=coerce_to_float, codec=Codec.JSON
)
if sort_keys is None:
sort_keys = pretty
if pretty:
return json.dumps(jdict, indent=2, sort_keys=sort_keys)
return json.dumps(jdict, separators=(',', ':'), sort_keys=sort_keys)
def dataclass_from_dict(
cls: type[T],
values: dict,
codec: Codec = Codec.JSON,
coerce_to_float: bool = True,
allow_unknown_attrs: bool = True,
discard_unknown_attrs: bool = False,
) -> T:
"""Given a dict, return a dataclass of a given type.
The dict must be formatted to match the specified codec (generally
json-friendly object types). This means that sequence values such as
tuples or sets should be passed as lists, enums should be passed as their
associated values, nested dataclasses should be passed as dicts, etc.
All values are checked to ensure their types/values are valid.
Data for attributes of type Any will be checked to ensure they match
types supported directly by json. This does not include types such
as tuples which are implicitly translated by Python's json module
(as this would break the ability to do a lossless round-trip with data).
If coerce_to_float is True, int values passed for float typed fields
will be converted to float values. Otherwise, a TypeError is raised.
If allow_unknown_attrs is False, AttributeErrors will be raised for
attributes present in the dict but not on the data class. Otherwise, they
will be preserved as part of the instance and included if it is
exported back to a dict, unless discard_unknown_attrs is True, in which
case they will simply be discarded.
"""
return _Inputter(
cls,
codec=codec,
coerce_to_float=coerce_to_float,
allow_unknown_attrs=allow_unknown_attrs,
discard_unknown_attrs=discard_unknown_attrs,
).run(values)
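# A hedged sketch of the unknown-attr handling described above (the 'Profile'
# class is hypothetical):
from dataclasses import dataclass
from efro.dataclassio import ioprepped, dataclass_from_dict, dataclass_to_dict

@ioprepped
@dataclass
class Profile:
    name: str = ''

# 'color' is not a field on Profile but is preserved by default and
# re-emitted on export...
obj = dataclass_from_dict(Profile, {'name': 'flo', 'color': 'blue'})
assert dataclass_to_dict(obj) == {'name': 'flo', 'color': 'blue'}
# ...it is silently dropped with discard_unknown_attrs=True and rejected
# with an AttributeError when allow_unknown_attrs=False.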
def dataclass_from_json(
cls: type[T],
json_str: str,
coerce_to_float: bool = True,
allow_unknown_attrs: bool = True,
discard_unknown_attrs: bool = False,
) -> T:
"""Utility function; return a dataclass instance given a json string.
Basically dataclass_from_dict(json.loads(...))
"""
import json
return dataclass_from_dict(
cls=cls,
values=json.loads(json_str),
coerce_to_float=coerce_to_float,
allow_unknown_attrs=allow_unknown_attrs,
discard_unknown_attrs=discard_unknown_attrs,
)
def dataclass_validate(
obj: Any, coerce_to_float: bool = True, codec: Codec = Codec.JSON
) -> None:
"""Ensure that values in a dataclass instance are the correct types."""
# Simply run an output pass but tell it not to generate data;
# only run validation.
_Outputter(
obj, create=False, codec=codec, coerce_to_float=coerce_to_float
).run()
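# A hedged sketch of dataclass_validate (the 'Account' class is hypothetical);
# values assigned after construction are only checked when validation runs.
from dataclasses import dataclass
from efro.dataclassio import ioprepped, dataclass_validate

@ioprepped
@dataclass
class Account:
    user_id: int = 0

acct = Account()
dataclass_validate(acct)   # Fine; values match their annotations.
acct.user_id = 'nope'      # type: ignore  # Wrong type; nothing complains yet.
# dataclass_validate(acct)  # Would raise TypeError for field 'user_id'.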


@@ -1,276 +0,0 @@
# Released under the MIT License. See LICENSE for details.
#
"""Core components of dataclassio."""
from __future__ import annotations
import dataclasses
import typing
import datetime
from enum import Enum
from typing import TYPE_CHECKING, get_args
# noinspection PyProtectedMember
from typing import _AnnotatedAlias # type: ignore
if TYPE_CHECKING:
from typing import Any, Callable
# Types which we can pass through as-is.
SIMPLE_TYPES = {int, bool, str, float, type(None)}
# Attr name for dict of extra attributes included on dataclass instances.
# Note that this is only added if extra attributes are present.
EXTRA_ATTRS_ATTR = '_DCIOEXATTRS'
def _raise_type_error(
fieldpath: str, valuetype: type, expected: tuple[type, ...]
) -> None:
"""Raise an error when a field value's type does not match expected."""
assert isinstance(expected, tuple)
assert all(isinstance(e, type) for e in expected)
if len(expected) == 1:
expected_str = expected[0].__name__
else:
expected_str = ' | '.join(t.__name__ for t in expected)
raise TypeError(
f'Invalid value type for "{fieldpath}";'
f' expected "{expected_str}", got'
f' "{valuetype.__name__}".'
)
class Codec(Enum):
"""Specifies expected data format exported to or imported from."""
# Use only types that will translate cleanly to/from json: lists,
# dicts with str keys, bools, ints, floats, and None.
JSON = 'json'
# Mostly like JSON but passes bytes and datetime objects through
# as-is instead of converting them to json-friendly types.
FIRESTORE = 'firestore'
class IOExtendedData:
"""A class that data types can inherit from for extra functionality."""
def will_output(self) -> None:
"""Called before data is sent to an outputter.
Can be overridden to validate or filter data before
sending it on its way.
"""
@classmethod
def will_input(cls, data: dict) -> None:
"""Called on raw data before a class instance is created from it.
Can be overridden to migrate old data formats to new, etc.
"""
def _is_valid_for_codec(obj: Any, codec: Codec) -> bool:
"""Return whether a value consists solely of json-supported types.
Note that this does not include things like tuples which are
implicitly translated to lists by python's json module.
"""
if obj is None:
return True
objtype = type(obj)
if objtype in (int, float, str, bool):
return True
if objtype is dict:
# JSON 'objects' support only string dict keys, but any value type.
return all(
isinstance(k, str) and _is_valid_for_codec(v, codec)
for k, v in obj.items()
)
if objtype is list:
return all(_is_valid_for_codec(elem, codec) for elem in obj)
# A few things are valid in firestore but not json.
if issubclass(objtype, datetime.datetime) or objtype is bytes:
return codec is Codec.FIRESTORE
return False
class IOAttrs:
"""For specifying io behavior in annotations.
'storagename', if passed, is the name used when storing to json/etc.
'store_default' can be set to False to avoid writing values when equal
to the default value. Note that this requires the dataclass field
to define a default or default_factory or for its IOAttrs to
define a soft_default value.
'whole_days', if True, requires datetime values to be exactly on day
boundaries (see efro.util.utc_today()).
'whole_hours', if True, requires datetime values to lie exactly on hour
boundaries (see efro.util.utc_this_hour()).
'whole_minutes', if True, requires datetime values to lie exactly on minute
boundaries (see efro.util.utc_this_minute()).
'soft_default', if passed, injects a default value into dataclass
instantiation when the field is not present in the input data.
This allows dataclasses to add new non-optional fields while
gracefully 'upgrading' old data. Note that when a soft_default is
present it will take precedence over field defaults when determining
whether to store a value for a field with store_default=False
(since the soft_default value is what we'll get when reading that
same data back in when the field is omitted).
'soft_default_factory' is similar to 'default_factory' in dataclass
fields; it should be used instead of 'soft_default' for mutable types
such as lists to prevent a single default object from unintentionally
changing over time.
"""
# A sentinel object to detect if a parameter is supplied or not. Use
# a class to give it a better repr.
class _MissingType:
pass
MISSING = _MissingType()
storagename: str | None = None
store_default: bool = True
whole_days: bool = False
whole_hours: bool = False
whole_minutes: bool = False
soft_default: Any = MISSING
soft_default_factory: Callable[[], Any] | _MissingType = MISSING
def __init__(
self,
storagename: str | None = storagename,
store_default: bool = store_default,
whole_days: bool = whole_days,
whole_hours: bool = whole_hours,
whole_minutes: bool = whole_minutes,
soft_default: Any = MISSING,
soft_default_factory: Callable[[], Any] | _MissingType = MISSING,
):
# Only store values that differ from class defaults to keep
# our instances nice and lean.
cls = type(self)
if storagename != cls.storagename:
self.storagename = storagename
if store_default != cls.store_default:
self.store_default = store_default
if whole_days != cls.whole_days:
self.whole_days = whole_days
if whole_hours != cls.whole_hours:
self.whole_hours = whole_hours
if whole_minutes != cls.whole_minutes:
self.whole_minutes = whole_minutes
if soft_default is not cls.soft_default:
# Do what dataclasses does with its default types and
# tell the user to use factory for mutable ones.
if isinstance(soft_default, (list, dict, set)):
raise ValueError(
f'mutable {type(soft_default)} is not allowed'
f' for soft_default; use soft_default_factory.'
)
self.soft_default = soft_default
if soft_default_factory is not cls.soft_default_factory:
self.soft_default_factory = soft_default_factory
if self.soft_default is not cls.soft_default:
raise ValueError(
'Cannot set both soft_default and soft_default_factory'
)
def validate_for_field(self, cls: type, field: dataclasses.Field) -> None:
"""Ensure the IOAttrs instance is ok to use with the provided field."""
# Turning off store_default requires the field to have either
# a default or a default_factory or for us to have soft equivalents.
if not self.store_default:
field_default_factory: Any = field.default_factory
if (
field_default_factory is dataclasses.MISSING
and field.default is dataclasses.MISSING
and self.soft_default is self.MISSING
and self.soft_default_factory is self.MISSING
):
raise TypeError(
f'Field {field.name} of {cls} has'
f' neither a default nor a default_factory'
f' and IOAttrs contains neither a soft_default'
f' nor a soft_default_factory;'
f' store_default=False cannot be set for it.'
)
def validate_datetime(
self, value: datetime.datetime, fieldpath: str
) -> None:
"""Ensure a datetime value meets our value requirements."""
if self.whole_days:
if any(
x != 0
for x in (
value.hour,
value.minute,
value.second,
value.microsecond,
)
):
raise ValueError(
f'Value {value} at {fieldpath} is not a whole day.'
)
elif self.whole_hours:
if any(
x != 0 for x in (value.minute, value.second, value.microsecond)
):
raise ValueError(
f'Value {value} at {fieldpath} is not a whole hour.'
)
elif self.whole_minutes:
if any(x != 0 for x in (value.second, value.microsecond)):
raise ValueError(
f'Value {value} at {fieldpath} is not a whole minute.'
)
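# A hedged sketch of IOAttrs in annotations (the 'Player' class is
# hypothetical): a short storage name, a value omitted when equal to its
# default, and a soft_default letting older stored data omit a newer field.
from dataclasses import dataclass
from typing import Annotated
from efro.dataclassio import (
    IOAttrs,
    ioprepped,
    dataclass_from_dict,
    dataclass_to_dict,
)

@ioprepped
@dataclass
class Player:
    name: Annotated[str, IOAttrs('n')] = ''
    score: Annotated[int, IOAttrs('s', store_default=False)] = 0
    level: Annotated[int, IOAttrs('l', soft_default=1)] = 1

# Storage names are used as keys; the default score is omitted entirely.
assert dataclass_to_dict(Player(name='flo')) == {'n': 'flo', 'l': 1}
# Older data lacking 'l' still loads thanks to the soft_default.
assert dataclass_from_dict(Player, {'n': 'flo'}) == Player(name='flo')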
def _get_origin(anntype: Any) -> Any:
"""Given a type annotation, return its origin or itself if there is none.
This differs from typing.get_origin in that it will never return None.
This lets us use the same code path for handling typing.List
that we do for handling list, which is good since they can be used
interchangeably in annotations.
"""
origin = typing.get_origin(anntype)
return anntype if origin is None else origin
def _parse_annotated(anntype: Any) -> tuple[Any, IOAttrs | None]:
"""Parse Annotated() constructs, returning annotated type & IOAttrs."""
# If we get an Annotated[foo, bar, eep] we take
# foo as the actual type, and we look for IOAttrs instances in
# bar/eep to affect our behavior.
ioattrs: IOAttrs | None = None
if isinstance(anntype, _AnnotatedAlias):
annargs = get_args(anntype)
for annarg in annargs[1:]:
if isinstance(annarg, IOAttrs):
if ioattrs is not None:
raise RuntimeError(
'Multiple IOAttrs instances found for a'
' single annotation; this is not supported.'
)
ioattrs = annarg
# I occasionally just throw a 'x' down when I mean IOAttrs('x');
# catch these mistakes.
elif isinstance(annarg, (str, int, float, bool)):
raise RuntimeError(
f'Raw {type(annarg)} found in Annotated[] entry:'
f' {anntype}; this is probably not what you intended.'
)
anntype = annargs[0]
return anntype, ioattrs


@@ -1,555 +0,0 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for dataclassio related to pulling data into dataclasses."""
# Note: We do lots of comparing of exact types here which is normally
# frowned upon (stuff like isinstance() is usually encouraged).
# pylint: disable=unidiomatic-typecheck
from __future__ import annotations
from enum import Enum
import dataclasses
import typing
import types
import datetime
from typing import TYPE_CHECKING, Generic, TypeVar
from efro.util import enum_by_value, check_utc
from efro.dataclassio._base import (
Codec,
_parse_annotated,
EXTRA_ATTRS_ATTR,
_is_valid_for_codec,
_get_origin,
SIMPLE_TYPES,
_raise_type_error,
IOExtendedData,
)
from efro.dataclassio._prep import PrepSession
if TYPE_CHECKING:
from typing import Any
from efro.dataclassio._base import IOAttrs
from efro.dataclassio._outputter import _Outputter
T = TypeVar('T')
class _Inputter(Generic[T]):
def __init__(
self,
cls: type[T],
codec: Codec,
coerce_to_float: bool,
allow_unknown_attrs: bool = True,
discard_unknown_attrs: bool = False,
):
self._cls = cls
self._codec = codec
self._coerce_to_float = coerce_to_float
self._allow_unknown_attrs = allow_unknown_attrs
self._discard_unknown_attrs = discard_unknown_attrs
self._soft_default_validator: _Outputter | None = None
if not allow_unknown_attrs and discard_unknown_attrs:
raise ValueError(
'discard_unknown_attrs cannot be True'
' when allow_unknown_attrs is False.'
)
def run(self, values: dict) -> T:
"""Do the thing."""
# For special extended data types, call their 'will_input' callback.
tcls = self._cls
if issubclass(tcls, IOExtendedData):
tcls.will_input(values)
out = self._dataclass_from_input(self._cls, '', values)
assert isinstance(out, self._cls)
return out
def _value_from_input(
self,
cls: type,
fieldpath: str,
anntype: Any,
value: Any,
ioattrs: IOAttrs | None,
) -> Any:
"""Convert an assigned value to what a dataclass field expects."""
# pylint: disable=too-many-return-statements
# pylint: disable=too-many-branches
origin = _get_origin(anntype)
if origin is typing.Any:
if not _is_valid_for_codec(value, self._codec):
raise TypeError(
f'Invalid value type for \'{fieldpath}\';'
f' \'Any\' typed values must contain only'
f' types directly supported by the specified'
f' codec ({self._codec.name}); found'
f' \'{type(value).__name__}\' which is not.'
)
return value
if origin is typing.Union or origin is types.UnionType:
# Currently, the only unions we support are None/Value
# (translated from Optional), which we verified on prep.
# So let's treat this as a simple optional case.
if value is None:
return None
childanntypes_l = [
c for c in typing.get_args(anntype) if c is not type(None)
] # noqa (pycodestyle complains about *is* with type)
assert len(childanntypes_l) == 1
return self._value_from_input(
cls, fieldpath, childanntypes_l[0], value, ioattrs
)
# Everything below this point assumes the annotation type resolves
# to a concrete type. (This should have been verified at prep time).
assert isinstance(origin, type)
if origin in SIMPLE_TYPES:
if type(value) is not origin:
# Special case: if they want to coerce ints to floats, do so.
if (
self._coerce_to_float
and origin is float
and type(value) is int
):
return float(value)
_raise_type_error(fieldpath, type(value), (origin,))
return value
if origin in {list, set}:
return self._sequence_from_input(
cls, fieldpath, anntype, value, origin, ioattrs
)
if origin is tuple:
return self._tuple_from_input(
cls, fieldpath, anntype, value, ioattrs
)
if origin is dict:
return self._dict_from_input(
cls, fieldpath, anntype, value, ioattrs
)
if dataclasses.is_dataclass(origin):
return self._dataclass_from_input(origin, fieldpath, value)
if issubclass(origin, Enum):
return enum_by_value(origin, value)
if issubclass(origin, datetime.datetime):
return self._datetime_from_input(cls, fieldpath, value, ioattrs)
if origin is bytes:
return self._bytes_from_input(origin, fieldpath, value)
raise TypeError(
f"Field '{fieldpath}' of type '{anntype}' is unsupported here."
)
def _bytes_from_input(self, cls: type, fieldpath: str, value: Any) -> bytes:
"""Given input data, returns bytes."""
import base64
# For firestore, bytes are passed as-is. Otherwise, they're encoded
# as base64.
if self._codec is Codec.FIRESTORE:
if not isinstance(value, bytes):
raise TypeError(
f'Expected a bytes object for {fieldpath}'
f' on {cls.__name__}; got a {type(value)}.'
)
return value
assert self._codec is Codec.JSON
if not isinstance(value, str):
raise TypeError(
f'Expected a string object for {fieldpath}'
f' on {cls.__name__}; got a {type(value)}.'
)
return base64.b64decode(value)
def _dataclass_from_input(
self, cls: type, fieldpath: str, values: dict
) -> Any:
"""Given a dict, instantiates a dataclass of the given type.
The dict must be in the json-friendly format as emitted from
dataclass_to_dict. This means that sequence values such as tuples or
sets should be passed as lists, enums should be passed as their
associated values, and nested dataclasses should be passed as dicts.
"""
# pylint: disable=too-many-locals
# pylint: disable=too-many-branches
if not isinstance(values, dict):
raise TypeError(
f'Expected a dict for {fieldpath} on {cls.__name__};'
f' got a {type(values)}.'
)
prep = PrepSession(explicit=False).prep_dataclass(
cls, recursion_level=0
)
assert prep is not None
extra_attrs = {}
# noinspection PyDataclass
fields = dataclasses.fields(cls)
fields_by_name = {f.name: f for f in fields}
# Preprocess all fields to convert Annotated[] to contained types
# and IOAttrs.
parsed_field_annotations = {
f.name: _parse_annotated(prep.annotations[f.name]) for f in fields
}
# Go through all data in the input, converting it to either dataclass
# args or extra data.
args: dict[str, Any] = {}
for rawkey, value in values.items():
key = prep.storage_names_to_attr_names.get(rawkey, rawkey)
field = fields_by_name.get(key)
# Store unknown attrs off to the side (or error if desired).
if field is None:
if self._allow_unknown_attrs:
if self._discard_unknown_attrs:
continue
# Treat this like 'Any' data; ensure that it is valid
# raw json.
if not _is_valid_for_codec(value, self._codec):
raise TypeError(
f'Unknown attr \'{key}\''
f' on {fieldpath} contains data type(s)'
f' not supported by the specified codec'
f' ({self._codec.name}).'
)
extra_attrs[key] = value
else:
raise AttributeError(
f"'{cls.__name__}' has no '{key}' field."
)
else:
fieldname = field.name
anntype, ioattrs = parsed_field_annotations[fieldname]
subfieldpath = (
f'{fieldpath}.{fieldname}' if fieldpath else fieldname
)
args[key] = self._value_from_input(
cls, subfieldpath, anntype, value, ioattrs
)
# Go through all fields looking for any not yet present in our data.
# If we find any such fields with a soft-default value or factory
# defined, inject that soft value into our args.
for key, aparsed in parsed_field_annotations.items():
if key in args:
continue
ioattrs = aparsed[1]
if ioattrs is not None and (
ioattrs.soft_default is not ioattrs.MISSING
or ioattrs.soft_default_factory is not ioattrs.MISSING
):
if ioattrs.soft_default is not ioattrs.MISSING:
soft_default = ioattrs.soft_default
else:
assert callable(ioattrs.soft_default_factory)
soft_default = ioattrs.soft_default_factory()
args[key] = soft_default
# Make sure these values are valid since we didn't run
# them through our normal input type checking.
self._type_check_soft_default(
value=soft_default,
anntype=aparsed[0],
fieldpath=(f'{fieldpath}.{key}' if fieldpath else key),
)
try:
out = cls(**args)
except Exception as exc:
raise ValueError(
f'Error instantiating class {cls.__name__}'
f' at {fieldpath}: {exc}'
) from exc
if extra_attrs:
setattr(out, EXTRA_ATTRS_ATTR, extra_attrs)
return out
def _type_check_soft_default(
self, value: Any, anntype: Any, fieldpath: str
) -> None:
from efro.dataclassio._outputter import _Outputter
# Counter-intuitively, we create an outputter as part of
# our inputter. Soft-default values are already internal types;
# we need to make sure they can go out from there.
if self._soft_default_validator is None:
self._soft_default_validator = _Outputter(
obj=None,
create=False,
codec=self._codec,
coerce_to_float=self._coerce_to_float,
)
self._soft_default_validator.soft_default_check(
value=value, anntype=anntype, fieldpath=fieldpath
)
def _dict_from_input(
self,
cls: type,
fieldpath: str,
anntype: Any,
value: Any,
ioattrs: IOAttrs | None,
) -> Any:
# pylint: disable=too-many-branches
# pylint: disable=too-many-locals
if not isinstance(value, dict):
raise TypeError(
f'Expected a dict for \'{fieldpath}\' on {cls.__name__};'
f' got a {type(value)}.'
)
childtypes = typing.get_args(anntype)
assert len(childtypes) in (0, 2)
out: dict
# We treat 'Any' dicts simply as json; we don't do any translating.
if not childtypes or childtypes[0] is typing.Any:
if not isinstance(value, dict) or not _is_valid_for_codec(
value, self._codec
):
raise TypeError(
f'Got invalid value for Dict[Any, Any]'
f' at \'{fieldpath}\' on {cls.__name__};'
f' all keys and values must be'
f' compatible with the specified codec'
f' ({self._codec.name}).'
)
out = value
else:
out = {}
keyanntype, valanntype = childtypes
# Ok; we've got definite key/value types (which we verified as
# valid during prep). Run all keys/values through it.
# str keys we just take directly since that's supported by json.
if keyanntype is str:
for key, val in value.items():
if not isinstance(key, str):
raise TypeError(
f'Got invalid key type {type(key)} for'
f' dict key at \'{fieldpath}\' on {cls.__name__};'
f' expected a str.'
)
out[key] = self._value_from_input(
cls, fieldpath, valanntype, val, ioattrs
)
# int keys are stored in json as str versions of themselves.
elif keyanntype is int:
for key, val in value.items():
if not isinstance(key, str):
raise TypeError(
f'Got invalid key type {type(key)} for'
f' dict key at \'{fieldpath}\' on {cls.__name__};'
f' expected a str.'
)
try:
keyint = int(key)
except ValueError as exc:
raise TypeError(
f'Got invalid key value {key} for'
f' dict key at \'{fieldpath}\' on {cls.__name__};'
f' expected an int in string form.'
) from exc
out[keyint] = self._value_from_input(
cls, fieldpath, valanntype, val, ioattrs
)
elif issubclass(keyanntype, Enum):
# In prep, we verified that all these enums' values have
# the same type, so we can just look at the first to see if
# this is a string enum or an int enum.
enumvaltype = type(next(iter(keyanntype)).value)
assert enumvaltype in (int, str)
if enumvaltype is str:
for key, val in value.items():
try:
enumval = enum_by_value(keyanntype, key)
except ValueError as exc:
raise ValueError(
f'Got invalid key value {key} for'
f' dict key at \'{fieldpath}\''
f' on {cls.__name__};'
f' expected a value corresponding to'
f' a {keyanntype}.'
) from exc
out[enumval] = self._value_from_input(
cls, fieldpath, valanntype, val, ioattrs
)
else:
for key, val in value.items():
try:
enumval = enum_by_value(keyanntype, int(key))
except (ValueError, TypeError) as exc:
raise ValueError(
f'Got invalid key value {key} for'
f' dict key at \'{fieldpath}\''
f' on {cls.__name__};'
f' expected {keyanntype} value (though'
f' in string form).'
) from exc
out[enumval] = self._value_from_input(
cls, fieldpath, valanntype, val, ioattrs
)
else:
raise RuntimeError(f'Unhandled dict in-key-type {keyanntype}')
return out
def _sequence_from_input(
self,
cls: type,
fieldpath: str,
anntype: Any,
value: Any,
seqtype: type,
ioattrs: IOAttrs | None,
) -> Any:
# Because we are json-centric, we expect a list for all sequences.
if type(value) is not list:
raise TypeError(
f'Invalid input value for "{fieldpath}";'
f' expected a list, got a {type(value).__name__}'
)
childanntypes = typing.get_args(anntype)
# 'Any' type children; make sure they are valid json values
# and then just grab them.
if len(childanntypes) == 0 or childanntypes[0] is typing.Any:
for i, child in enumerate(value):
if not _is_valid_for_codec(child, self._codec):
raise TypeError(
f'Item {i} of {fieldpath} contains'
f' data type(s) not supported by json.'
)
return value if type(value) is seqtype else seqtype(value)
# We contain elements of some specified type.
assert len(childanntypes) == 1
childanntype = childanntypes[0]
return seqtype(
self._value_from_input(cls, fieldpath, childanntype, i, ioattrs)
for i in value
)
def _datetime_from_input(
self, cls: type, fieldpath: str, value: Any, ioattrs: IOAttrs | None
) -> Any:
# For firestore we expect a datetime object.
if self._codec is Codec.FIRESTORE:
# Don't compare exact type here, as firestore can give us
# a subclass with extended precision.
if not isinstance(value, datetime.datetime):
raise TypeError(
f'Invalid input value for "{fieldpath}" on'
f' "{cls.__name__}";'
f' expected a datetime, got a {type(value).__name__}'
)
check_utc(value)
return value
assert self._codec is Codec.JSON
# We expect a list of 7 ints.
if type(value) is not list:
raise TypeError(
f'Invalid input value for "{fieldpath}" on "{cls.__name__}";'
f' expected a list, got a {type(value).__name__}'
)
if len(value) != 7 or not all(isinstance(x, int) for x in value):
raise ValueError(
f'Invalid input value for "{fieldpath}" on "{cls.__name__}";'
f' expected a list of 7 ints, got {[type(v) for v in value]}.'
)
out = datetime.datetime( # type: ignore
*value, tzinfo=datetime.timezone.utc
)
if ioattrs is not None:
ioattrs.validate_datetime(out, fieldpath)
return out
def _tuple_from_input(
self,
cls: type,
fieldpath: str,
anntype: Any,
value: Any,
ioattrs: IOAttrs | None,
) -> Any:
out: list = []
# Because we are json-centric, we expect a list for all sequences.
if type(value) is not list:
raise TypeError(
f'Invalid input value for "{fieldpath}";'
f' expected a list, got a {type(value).__name__}'
)
childanntypes = typing.get_args(anntype)
# We should have verified this to be non-zero at prep-time.
assert childanntypes
if len(value) != len(childanntypes):
raise ValueError(
f'Invalid tuple input for "{fieldpath}";'
f' expected {len(childanntypes)} values,'
f' found {len(value)}.'
)
for i, childanntype in enumerate(childanntypes):
childval = value[i]
# 'Any' type children; make sure they are valid json values
# and then just grab them.
if childanntype is typing.Any:
if not _is_valid_for_codec(childval, self._codec):
raise TypeError(
f'Item {i} of {fieldpath} contains'
f' data type(s) not supported by json.'
)
out.append(childval)
else:
out.append(
self._value_from_input(
cls, fieldpath, childanntype, childval, ioattrs
)
)
assert len(out) == len(childanntypes)
return tuple(out)
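# A hedged sketch of the JSON-codec wire forms this inputter accepts, per the
# code above (the 'Record' class is hypothetical): int dict keys arrive as
# strings, datetimes as 7-int lists (UTC), and bytes as base64 strings.
import datetime
from dataclasses import dataclass, field
from efro.dataclassio import ioprepped, dataclass_from_dict

@ioprepped
@dataclass
class Record:
    counts: dict[int, int] = field(default_factory=dict)
    when: datetime.datetime | None = None
    blob: bytes = b''

rec = dataclass_from_dict(
    Record,
    {
        'counts': {'3': 7},                    # str key -> int key 3
        'when': [2024, 2, 20, 22, 53, 45, 0],  # -> tz-aware UTC datetime
        'blob': 'aGVsbG8=',                    # base64 -> b'hello'
    },
)
assert rec.counts == {3: 7} and rec.blob == b'hello'
assert rec.when == datetime.datetime(
    2024, 2, 20, 22, 53, 45, tzinfo=datetime.timezone.utc
)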


@@ -1,457 +0,0 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for dataclassio related to exporting data from dataclasses."""
# Note: We do lots of comparing of exact types here which is normally
# frowned upon (stuff like isinstance() is usually encouraged).
# pylint: disable=unidiomatic-typecheck
from __future__ import annotations
from enum import Enum
import dataclasses
import typing
import types
import datetime
from typing import TYPE_CHECKING
from efro.util import check_utc
from efro.dataclassio._base import (
Codec,
_parse_annotated,
EXTRA_ATTRS_ATTR,
_is_valid_for_codec,
_get_origin,
SIMPLE_TYPES,
_raise_type_error,
IOExtendedData,
)
from efro.dataclassio._prep import PrepSession
if TYPE_CHECKING:
from typing import Any
from efro.dataclassio._base import IOAttrs
class _Outputter:
"""Validates or exports data contained in a dataclass instance."""
def __init__(
self, obj: Any, create: bool, codec: Codec, coerce_to_float: bool
) -> None:
self._obj = obj
self._create = create
self._codec = codec
self._coerce_to_float = coerce_to_float
def run(self) -> Any:
"""Do the thing."""
assert dataclasses.is_dataclass(self._obj)
# For special extended data types, call their 'will_output' callback.
if isinstance(self._obj, IOExtendedData):
self._obj.will_output()
return self._process_dataclass(type(self._obj), self._obj, '')
def soft_default_check(
self, value: Any, anntype: Any, fieldpath: str
) -> None:
"""(internal)"""
self._process_value(
type(value),
fieldpath=fieldpath,
anntype=anntype,
value=value,
ioattrs=None,
)
def _process_dataclass(self, cls: type, obj: Any, fieldpath: str) -> Any:
# pylint: disable=too-many-locals
# pylint: disable=too-many-branches
prep = PrepSession(explicit=False).prep_dataclass(
type(obj), recursion_level=0
)
assert prep is not None
fields = dataclasses.fields(obj)
out: dict[str, Any] | None = {} if self._create else None
for field in fields:
fieldname = field.name
if fieldpath:
subfieldpath = f'{fieldpath}.{fieldname}'
else:
subfieldpath = fieldname
anntype = prep.annotations[fieldname]
value = getattr(obj, fieldname)
anntype, ioattrs = _parse_annotated(anntype)
# If we're not storing default values for this fella,
# we can skip all output processing if we've got a default value.
if ioattrs is not None and not ioattrs.store_default:
# If both soft_defaults and regular field defaults
# are present we want to go with soft_defaults since
# those same values would be re-injected when reading
# the same data back in if we've omitted the field.
default_factory: Any = field.default_factory
if ioattrs.soft_default is not ioattrs.MISSING:
if ioattrs.soft_default == value:
continue
elif ioattrs.soft_default_factory is not ioattrs.MISSING:
assert callable(ioattrs.soft_default_factory)
if ioattrs.soft_default_factory() == value:
continue
elif field.default is not dataclasses.MISSING:
if field.default == value:
continue
elif default_factory is not dataclasses.MISSING:
if default_factory() == value:
continue
else:
raise RuntimeError(
f'Field {fieldname} of {cls.__name__} has'
f' no source of default values; store_default=False'
f' cannot be set for it. (AND THIS SHOULD HAVE BEEN'
f' CAUGHT IN PREP!)'
)
outvalue = self._process_value(
cls, subfieldpath, anntype, value, ioattrs
)
if self._create:
assert out is not None
storagename = (
fieldname
if (ioattrs is None or ioattrs.storagename is None)
else ioattrs.storagename
)
out[storagename] = outvalue
# If there's extra-attrs stored on us, check/include them.
extra_attrs = getattr(obj, EXTRA_ATTRS_ATTR, None)
if isinstance(extra_attrs, dict):
if not _is_valid_for_codec(extra_attrs, self._codec):
raise TypeError(
f'Extra attrs on \'{fieldpath}\' contains data type(s)'
f' not supported by \'{self._codec.value}\' codec:'
f' {extra_attrs}.'
)
if self._create:
assert out is not None
out.update(extra_attrs)
return out
def _process_value(
self,
cls: type,
fieldpath: str,
anntype: Any,
value: Any,
ioattrs: IOAttrs | None,
) -> Any:
# pylint: disable=too-many-return-statements
# pylint: disable=too-many-branches
# pylint: disable=too-many-statements
origin = _get_origin(anntype)
if origin is typing.Any:
if not _is_valid_for_codec(value, self._codec):
raise TypeError(
f'Invalid value type for \'{fieldpath}\';'
f" 'Any' typed values must contain types directly"
f' supported by the specified codec ({self._codec.name});'
f' found \'{type(value).__name__}\' which is not.'
)
return value if self._create else None
if origin is typing.Union or origin is types.UnionType:
# Currently, the only unions we support are None/Value
# (translated from Optional), which we verified on prep.
# So let's treat this as a simple optional case.
if value is None:
return None
childanntypes_l = [
c for c in typing.get_args(anntype) if c is not type(None)
] # noqa (pycodestyle complains about *is* with type)
assert len(childanntypes_l) == 1
return self._process_value(
cls, fieldpath, childanntypes_l[0], value, ioattrs
)
# Everything below this point assumes the annotation type resolves
# to a concrete type. (This should have been verified at prep time).
assert isinstance(origin, type)
# For simple flat types, look for exact matches:
if origin in SIMPLE_TYPES:
if type(value) is not origin:
# Special case: if they want to coerce ints to floats, do so.
if (
self._coerce_to_float
and origin is float
and type(value) is int
):
return float(value) if self._create else None
_raise_type_error(fieldpath, type(value), (origin,))
return value if self._create else None
if origin is tuple:
if not isinstance(value, tuple):
raise TypeError(
f'Expected a tuple for {fieldpath};'
f' found a {type(value)}'
)
childanntypes = typing.get_args(anntype)
# We should have verified this was non-zero at prep-time
assert childanntypes
if len(value) != len(childanntypes):
raise TypeError(
f'Tuple at {fieldpath} contains'
f' {len(value)} values; type specifies'
f' {len(childanntypes)}.'
)
if self._create:
return [
self._process_value(
cls, fieldpath, childanntypes[i], x, ioattrs
)
for i, x in enumerate(value)
]
for i, x in enumerate(value):
self._process_value(
cls, fieldpath, childanntypes[i], x, ioattrs
)
return None
if origin is list:
if not isinstance(value, list):
raise TypeError(
f'Expected a list for {fieldpath};'
f' found a {type(value)}'
)
childanntypes = typing.get_args(anntype)
# 'Any' type children; make sure they are valid values for
# the specified codec.
if len(childanntypes) == 0 or childanntypes[0] is typing.Any:
for i, child in enumerate(value):
if not _is_valid_for_codec(child, self._codec):
raise TypeError(
f'Item {i} of {fieldpath} contains'
f' data type(s) not supported by the specified'
f' codec ({self._codec.name}).'
)
# Hmm; should we do a copy here?
return value if self._create else None
# We contain elements of some specified type.
assert len(childanntypes) == 1
if self._create:
return [
self._process_value(
cls, fieldpath, childanntypes[0], x, ioattrs
)
for x in value
]
for x in value:
self._process_value(
cls, fieldpath, childanntypes[0], x, ioattrs
)
return None
if origin is set:
if not isinstance(value, set):
raise TypeError(
f'Expected a set for {fieldpath}; found a {type(value)}'
)
childanntypes = typing.get_args(anntype)
# 'Any' type children; make sure they are valid Any values.
if len(childanntypes) == 0 or childanntypes[0] is typing.Any:
for child in value:
if not _is_valid_for_codec(child, self._codec):
raise TypeError(
f'Set at {fieldpath} contains'
f' data type(s) not supported by the'
f' specified codec ({self._codec.name}).'
)
return list(value) if self._create else None
# We contain elements of some specified type.
assert len(childanntypes) == 1
if self._create:
# Note: we output json-friendly values so this becomes
# a list.
return [
self._process_value(
cls, fieldpath, childanntypes[0], x, ioattrs
)
for x in value
]
for x in value:
self._process_value(
cls, fieldpath, childanntypes[0], x, ioattrs
)
return None
if origin is dict:
return self._process_dict(cls, fieldpath, anntype, value, ioattrs)
if dataclasses.is_dataclass(origin):
if not isinstance(value, origin):
raise TypeError(
f'Expected a {origin} for {fieldpath};'
f' found a {type(value)}.'
)
return self._process_dataclass(cls, value, fieldpath)
if issubclass(origin, Enum):
if not isinstance(value, origin):
raise TypeError(
f'Expected a {origin} for {fieldpath};'
f' found a {type(value)}.'
)
# At prep-time we verified that these enums had valid value
# types, so we can blindly return it here.
return value.value if self._create else None
if issubclass(origin, datetime.datetime):
if not isinstance(value, origin):
raise TypeError(
f'Expected a {origin} for {fieldpath};'
f' found a {type(value)}.'
)
check_utc(value)
if ioattrs is not None:
ioattrs.validate_datetime(value, fieldpath)
if self._codec is Codec.FIRESTORE:
return value
assert self._codec is Codec.JSON
return (
[
value.year,
value.month,
value.day,
value.hour,
value.minute,
value.second,
value.microsecond,
]
if self._create
else None
)
if origin is bytes:
return self._process_bytes(cls, fieldpath, value)
raise TypeError(
f"Field '{fieldpath}' of type '{anntype}' is unsupported here."
)
def _process_bytes(self, cls: type, fieldpath: str, value: bytes) -> Any:
import base64
if not isinstance(value, bytes):
raise TypeError(
f'Expected bytes for {fieldpath} on {cls.__name__};'
f' found a {type(value)}.'
)
if not self._create:
return None
# In JSON we convert to base64, but firestore directly supports bytes.
if self._codec is Codec.JSON:
return base64.b64encode(value).decode()
assert self._codec is Codec.FIRESTORE
return value
def _process_dict(
self,
cls: type,
fieldpath: str,
anntype: Any,
value: dict,
ioattrs: IOAttrs | None,
) -> Any:
# pylint: disable=too-many-branches
if not isinstance(value, dict):
raise TypeError(
f'Expected a dict for {fieldpath}; found a {type(value)}.'
)
childtypes = typing.get_args(anntype)
assert len(childtypes) in (0, 2)
# We treat 'Any' dicts simply as json; we don't do any translating.
if not childtypes or childtypes[0] is typing.Any:
if not isinstance(value, dict) or not _is_valid_for_codec(
value, self._codec
):
raise TypeError(
f'Invalid value for Dict[Any, Any]'
f' at \'{fieldpath}\' on {cls.__name__};'
f' all keys and values must be directly compatible'
f' with the specified codec ({self._codec.name})'
f' when dict type is Any.'
)
return value if self._create else None
# Ok; we've got a definite key type (which we verified as valid
# during prep). Make sure all keys match it.
out: dict | None = {} if self._create else None
keyanntype, valanntype = childtypes
# str keys we just export directly since that's supported by json.
if keyanntype is str:
for key, val in value.items():
if not isinstance(key, str):
raise TypeError(
f'Got invalid key type {type(key)} for'
f' dict key at \'{fieldpath}\' on {cls.__name__};'
f' expected {keyanntype}.'
)
outval = self._process_value(
cls, fieldpath, valanntype, val, ioattrs
)
if self._create:
assert out is not None
out[key] = outval
# int keys are stored as str versions of themselves.
elif keyanntype is int:
for key, val in value.items():
if not isinstance(key, int):
raise TypeError(
f'Got invalid key type {type(key)} for'
f' dict key at \'{fieldpath}\' on {cls.__name__};'
f' expected an int.'
)
outval = self._process_value(
cls, fieldpath, valanntype, val, ioattrs
)
if self._create:
assert out is not None
out[str(key)] = outval
elif issubclass(keyanntype, Enum):
for key, val in value.items():
if not isinstance(key, keyanntype):
raise TypeError(
f'Got invalid key type {type(key)} for'
f' dict key at \'{fieldpath}\' on {cls.__name__};'
f' expected a {keyanntype}.'
)
outval = self._process_value(
cls, fieldpath, valanntype, val, ioattrs
)
if self._create:
assert out is not None
out[str(key.value)] = outval
else:
raise RuntimeError(f'Unhandled dict out-key-type {keyanntype}')
return out
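# A hedged sketch of how this outputter flattens non-json-native containers
# (the 'Slot' enum and 'Inventory' class are hypothetical): sets become lists
# and enum dict keys become the str form of their values.
from dataclasses import dataclass, field
from enum import Enum
from efro.dataclassio import ioprepped, dataclass_to_dict

class Slot(Enum):
    HEAD = 'head'
    HAND = 'hand'

@ioprepped
@dataclass
class Inventory:
    tags: set[str] = field(default_factory=set)
    items: dict[Slot, int] = field(default_factory=dict)

out = dataclass_to_dict(Inventory(tags={'rare'}, items={Slot.HAND: 2}))
assert out == {'tags': ['rare'], 'items': {'hand': 2}}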


@@ -1,115 +0,0 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality related to capturing nested dataclass paths."""
from __future__ import annotations
import dataclasses
from typing import TYPE_CHECKING, TypeVar, Generic
from efro.dataclassio._base import _parse_annotated, _get_origin
from efro.dataclassio._prep import PrepSession
if TYPE_CHECKING:
from typing import Any, Callable
T = TypeVar('T')
class _PathCapture:
"""Utility for obtaining dataclass storage paths in a type safe way."""
def __init__(self, obj: Any, pathparts: list[str] | None = None):
self._is_dataclass = dataclasses.is_dataclass(obj)
if pathparts is None:
pathparts = []
self._cls = obj if isinstance(obj, type) else type(obj)
self._pathparts = pathparts
def __getattr__(self, name: str) -> _PathCapture:
# We only allow diving into sub-objects if we are a dataclass.
if not self._is_dataclass:
raise TypeError(
f"Field path cannot include attribute '{name}' "
f'under parent {self._cls}; parent types must be dataclasses.'
)
prep = PrepSession(explicit=False).prep_dataclass(
self._cls, recursion_level=0
)
assert prep is not None
try:
anntype = prep.annotations[name]
except KeyError as exc:
raise AttributeError(f'{type(self)} has no {name} field.') from exc
anntype, ioattrs = _parse_annotated(anntype)
storagename = (
name
if (ioattrs is None or ioattrs.storagename is None)
else ioattrs.storagename
)
origin = _get_origin(anntype)
return _PathCapture(origin, pathparts=self._pathparts + [storagename])
@property
def path(self) -> str:
"""The final output path."""
return '.'.join(self._pathparts)
class DataclassFieldLookup(Generic[T]):
"""Get info about nested dataclass fields in type-safe way."""
def __init__(self, cls: type[T]) -> None:
self.cls = cls
def path(self, callback: Callable[[T], Any]) -> str:
"""Look up a path on child dataclass fields.
example:
DataclassFieldLookup(MyType).path(lambda obj: obj.foo.bar)
The above example will return the string 'foo.bar' or something
like 'f.b' if the dataclasses have custom storage names set.
It will also be static-type-checked, triggering an error if
MyType.foo.bar is not a valid path. Note, however, that the
callback technically allows any return value but only nested
dataclasses and their fields will succeed.
"""
# We tell the type system that we are returning an instance
# of our class, which allows it to perform type checking on
# member lookups. In reality, however, we are providing a
# special object which captures path lookups, so we can build
# a string from them.
if not TYPE_CHECKING:
out = callback(_PathCapture(self.cls))
if not isinstance(out, _PathCapture):
raise TypeError(
f'Expected a valid path under'
f' the provided object; got a {type(out)}.'
)
return out.path
return ''
def paths(self, callback: Callable[[T], list[Any]]) -> list[str]:
"""Look up multiple paths on child dataclass fields.
Functionality is identical to path() but for multiple paths at once.
example:
DataclassFieldLookup(MyType).paths(lambda obj: [obj.foo, obj.bar])
"""
outvals: list[str] = []
if not TYPE_CHECKING:
outs = callback(_PathCapture(self.cls))
assert isinstance(outs, list)
for out in outs:
if not isinstance(out, _PathCapture):
raise TypeError(
f'Expected a valid path under'
f' the provided object; got a {type(out)}.'
)
outvals.append(out.path)
return outvals
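# A hedged sketch of DataclassFieldLookup with custom storage names (the
# 'Inner'/'Outer' classes are hypothetical): the returned path reflects the
# names actually used in storage, not the attribute names.
from dataclasses import dataclass, field
from typing import Annotated
from efro.dataclassio import DataclassFieldLookup, IOAttrs, ioprepped

@ioprepped
@dataclass
class Inner:
    bar: Annotated[int, IOAttrs('b')] = 0

@ioprepped
@dataclass
class Outer:
    foo: Annotated[Inner, IOAttrs('f')] = field(default_factory=Inner)

assert DataclassFieldLookup(Outer).path(lambda obj: obj.foo.bar) == 'f.b'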


@@ -1,459 +0,0 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality for prepping types for use with dataclassio."""
# Note: We do lots of comparing of exact types here which is normally
# frowned upon (stuff like isinstance() is usually encouraged).
# pylint: disable=unidiomatic-typecheck
from __future__ import annotations
import logging
from enum import Enum
import dataclasses
import typing
import types
import datetime
from typing import TYPE_CHECKING, TypeVar, get_type_hints
# noinspection PyProtectedMember
from efro.dataclassio._base import _parse_annotated, _get_origin, SIMPLE_TYPES
if TYPE_CHECKING:
from typing import Any
from efro.dataclassio._base import IOAttrs
T = TypeVar('T')
# How deep we go when prepping nested types
# (basically for detecting recursive types)
MAX_RECURSION = 10
# Attr name for data we store on dataclass types that have been prepped.
PREP_ATTR = '_DCIOPREP'
# We also store the prep-session while the prep is in progress.
# (necessary to support recursive types).
PREP_SESSION_ATTR = '_DCIOPREPSESSION'
def ioprep(cls: type, globalns: dict | None = None) -> None:
"""Prep a dataclass type for use with this module's functionality.
Prepping ensures that all types contained in a data class as well as
the usage of said types are supported by this module and pre-builds
necessary constructs needed for encoding/decoding/etc.
Prepping will happen on-the-fly as needed, but a warning will be
emitted in such cases, as it is better to explicitly prep all used types
early in a process to ensure any invalid types or configuration are caught
immediately.
Prepping a dataclass involves evaluating its type annotations, which,
as of PEP 563, are stored simply as strings. This evaluation is done
with localns set to the class dict (so that types defined in the class
can be used) and globalns set to the containing module's dict.
It is possible to override globalns for special cases such as when
prepping happens as part of an execed string instead of within a
module.
"""
PrepSession(explicit=True, globalns=globalns).prep_dataclass(
cls, recursion_level=0
)
def ioprepped(cls: type[T]) -> type[T]:
"""Class decorator for easily prepping a dataclass at definition time.
Note that in some cases it may not be possible to prep a dataclass
immediately (such as when its type annotations refer to forward-declared
types). In these cases, ioprep() should be explicitly called for
the class as soon as possible; ideally at module import time to expose any
errors as early as possible in execution.
"""
ioprep(cls)
return cls
def will_ioprep(cls: type[T]) -> type[T]:
"""Class decorator hinting that we will prep a class later.
In some cases (such as recursive types) we cannot use the @ioprepped
decorator and must instead call ioprep() explicitly later. However,
some of our custom pylint checking behaves differently when the
@ioprepped decorator is present, in that case requiring type annotations
to be present and not simply forward declared under an "if TYPE_CHECKING"
block (since they are used at runtime).
The @will_ioprep decorator triggers the same pylint behavior
differences as @ioprepped (which are necessary for the later ioprep() call
to work correctly) but without actually running any prep itself.
"""
return cls
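# A hedged sketch of when explicit ioprep() is needed (the 'Node' class is
# hypothetical): a self-referential type cannot be prepped at definition time
# with @ioprepped, so it is tagged with @will_ioprep and prepped afterwards.
from __future__ import annotations
from dataclasses import dataclass, field
from efro.dataclassio import ioprep, will_ioprep

@will_ioprep
@dataclass
class Node:
    children: list[Node] = field(default_factory=list)

ioprep(Node)  # Safe now that the 'Node' name is fully defined.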
def is_ioprepped_dataclass(obj: Any) -> bool:
"""Return whether the obj is an ioprepped dataclass type or instance."""
cls = obj if isinstance(obj, type) else type(obj)
return dataclasses.is_dataclass(cls) and hasattr(cls, PREP_ATTR)
@dataclasses.dataclass
class PrepData:
"""Data we prepare and cache for a class during prep.
This data is used as part of the encoding/decoding/validating process.
"""
# Resolved annotation data with 'live' classes.
annotations: dict[str, Any]
# Map of storage names to attr names.
storage_names_to_attr_names: dict[str, str]
class PrepSession:
"""Context for a prep."""
def __init__(self, explicit: bool, globalns: dict | None = None):
self.explicit = explicit
self.globalns = globalns
def prep_dataclass(
self, cls: type, recursion_level: int
) -> PrepData | None:
"""Run prep on a dataclass if necessary and return its prep data.
The only case where this will return None is for recursive types
if the type is already being prepped higher in the call order.
"""
# pylint: disable=too-many-locals
# pylint: disable=too-many-branches
# We should only need to do this once per dataclass.
existing_data = getattr(cls, PREP_ATTR, None)
if existing_data is not None:
assert isinstance(existing_data, PrepData)
return existing_data
# Sanity check.
# Note that we now support recursive types via the PREP_SESSION_ATTR,
# so we theoretically shouldn't run into this.
if recursion_level > MAX_RECURSION:
raise RuntimeError('Max recursion exceeded.')
# We should only be passed classes which are dataclasses.
if not isinstance(cls, type) or not dataclasses.is_dataclass(cls):
raise TypeError(f'Passed arg {cls} is not a dataclass type.')
# Add a pointer to the prep-session while doing the prep.
# This way we can ignore types that we're already in the process
# of prepping and can support recursive types.
existing_prep = getattr(cls, PREP_SESSION_ATTR, None)
if existing_prep is not None:
if existing_prep is self:
return None
# We shouldn't need to support failed preps
# or preps from multiple threads at once.
raise RuntimeError('Found existing in-progress prep.')
setattr(cls, PREP_SESSION_ATTR, self)
# Generate a warning on non-explicit preps; we prefer prep to
# happen explicitly at runtime so errors can be detected early on.
if not self.explicit:
logging.warning(
'efro.dataclassio: implicitly prepping dataclass: %s.'
' It is highly recommended to explicitly prep dataclasses'
' as soon as possible after definition (via'
' efro.dataclassio.ioprep() or the'
' @efro.dataclassio.ioprepped decorator).',
cls,
)
try:
# NOTE: Now passing the class' __dict__ (vars()) as locals
# which allows us to pick up nested classes, etc.
resolved_annotations = get_type_hints(
cls,
localns=vars(cls),
globalns=self.globalns,
include_extras=True,
)
# pylint: enable=unexpected-keyword-arg
except Exception as exc:
raise TypeError(
f'dataclassio prep for {cls} failed with error: {exc}.'
f' Make sure all types used in annotations are defined'
f' at the module or class level or add them as part of an'
f' explicit prep call.'
) from exc
# noinspection PyDataclass
fields = dataclasses.fields(cls)
fields_by_name = {f.name: f for f in fields}
all_storage_names: set[str] = set()
storage_names_to_attr_names: dict[str, str] = {}
# Ok; we've resolved actual types for this dataclass.
# now recurse through them, verifying that we support all contained
# types and prepping any contained dataclass types.
for attrname, anntype in resolved_annotations.items():
anntype, ioattrs = _parse_annotated(anntype)
# If we found attached IOAttrs data, make sure it contains
# valid values for the field it is attached to.
if ioattrs is not None:
ioattrs.validate_for_field(cls, fields_by_name[attrname])
if ioattrs.storagename is not None:
storagename = ioattrs.storagename
storage_names_to_attr_names[ioattrs.storagename] = attrname
else:
storagename = attrname
else:
storagename = attrname
# Make sure we don't have any clashes in our storage names.
if storagename in all_storage_names:
raise TypeError(
f'Multiple attrs on {cls} are using'
f' storage-name \'{storagename}\''
)
all_storage_names.add(storagename)
self.prep_type(
cls,
attrname,
anntype,
ioattrs=ioattrs,
recursion_level=recursion_level + 1,
)
# Success! Store our resolved stuff with the class and we're done.
prepdata = PrepData(
annotations=resolved_annotations,
storage_names_to_attr_names=storage_names_to_attr_names,
)
setattr(cls, PREP_ATTR, prepdata)
# Clear our prep-session tag.
assert getattr(cls, PREP_SESSION_ATTR, None) is self
delattr(cls, PREP_SESSION_ATTR)
return prepdata
def prep_type(
self,
cls: type,
attrname: str,
anntype: Any,
ioattrs: IOAttrs | None,
recursion_level: int,
) -> None:
"""Run prep on a dataclass."""
# pylint: disable=too-many-return-statements
# pylint: disable=too-many-branches
# pylint: disable=too-many-statements
if recursion_level > MAX_RECURSION:
raise RuntimeError('Max recursion exceeded.')
origin = _get_origin(anntype)
if origin is typing.Union or origin is types.UnionType:
self.prep_union(
cls, attrname, anntype, recursion_level=recursion_level + 1
)
return
if anntype is typing.Any:
return
# Everything below this point assumes the annotation type resolves
# to a concrete type.
if not isinstance(origin, type):
raise TypeError(
f'Unsupported type found for \'{attrname}\' on {cls}:'
f' {anntype}'
)
# If a soft_default value/factory was passed, we do some basic
# type checking on the top-level value here. We also run full
# recursive validation on values later during inputting, but this
# should catch at least some errors early on, which can be
# useful since soft_defaults are not static type checked.
if ioattrs is not None:
have_soft_default = False
soft_default: Any = None
if ioattrs.soft_default is not ioattrs.MISSING:
have_soft_default = True
soft_default = ioattrs.soft_default
elif ioattrs.soft_default_factory is not ioattrs.MISSING:
assert callable(ioattrs.soft_default_factory)
have_soft_default = True
soft_default = ioattrs.soft_default_factory()
# Do a simple type check for the top level to catch basic
# soft_default mismatches early; full check will happen at
# input time.
if have_soft_default:
if not isinstance(soft_default, origin):
raise TypeError(
f'{cls} attr {attrname} has type {origin}'
f' but soft_default value is type {type(soft_default)}'
)
if origin in SIMPLE_TYPES:
return
# For sets and lists, check out their single contained type (if any).
if origin in (list, set):
childtypes = typing.get_args(anntype)
if len(childtypes) == 0:
# This is equivalent to Any; nothing else needs checking.
return
if len(childtypes) > 1:
raise TypeError(
f'Unrecognized typing arg count {len(childtypes)}'
f" for {anntype} attr '{attrname}' on {cls}"
)
self.prep_type(
cls,
attrname,
childtypes[0],
ioattrs=None,
recursion_level=recursion_level + 1,
)
return
if origin is dict:
childtypes = typing.get_args(anntype)
assert len(childtypes) in (0, 2)
# For key types we support Any, str, int,
# and Enums with uniform str/int values.
if not childtypes or childtypes[0] is typing.Any:
# 'Any' needs no further checks (just checked per-instance).
pass
elif childtypes[0] in (str, int):
# str and int are all good as keys.
pass
elif issubclass(childtypes[0], Enum):
# Allow our usual str or int enum types as keys.
self.prep_enum(childtypes[0])
else:
raise TypeError(
f'Dict key type {childtypes[0]} for \'{attrname}\''
f' on {cls.__name__} is not supported by dataclassio.'
)
# For value types we support any of our normal types.
if not childtypes or _get_origin(childtypes[1]) is typing.Any:
# 'Any' needs no further checks (just checked per-instance).
pass
else:
self.prep_type(
cls,
attrname,
childtypes[1],
ioattrs=None,
recursion_level=recursion_level + 1,
)
return
# For Tuples, simply check individual member types.
# (and, for now, explicitly disallow zero member types or usage
# of ellipsis)
if origin is tuple:
childtypes = typing.get_args(anntype)
if not childtypes:
raise TypeError(
f'Tuple at \'{attrname}\''
f' has no type args; dataclassio requires type args.'
)
if childtypes[-1] is ...:
raise TypeError(
f'Found ellipsis as part of type for'
f' \'{attrname}\' on {cls.__name__};'
f' these are not'
f' supported by dataclassio.'
)
for childtype in childtypes:
self.prep_type(
cls,
attrname,
childtype,
ioattrs=None,
recursion_level=recursion_level + 1,
)
return
if issubclass(origin, Enum):
self.prep_enum(origin)
return
# We allow datetime objects (and google's extended subclass of them
# used in firestore, which is why we don't look for exact type here).
if issubclass(origin, datetime.datetime):
return
if dataclasses.is_dataclass(origin):
self.prep_dataclass(origin, recursion_level=recursion_level + 1)
return
if origin is bytes:
return
raise TypeError(
f"Attr '{attrname}' on {cls.__name__} contains"
f" type '{anntype}'"
f' which is not supported by dataclassio.'
)
def prep_union(
self, cls: type, attrname: str, anntype: Any, recursion_level: int
) -> None:
"""Run prep on a Union type."""
typeargs = typing.get_args(anntype)
if (
len(typeargs) != 2
or len([c for c in typeargs if c is type(None)]) != 1
): # noqa
raise TypeError(
f'Union {anntype} for attr \'{attrname}\' on'
f' {cls.__name__} is not supported by dataclassio;'
f' only 2 member Unions with one type being None'
f' are supported.'
)
for childtype in typeargs:
self.prep_type(
cls,
attrname,
childtype,
None,
recursion_level=recursion_level + 1,
)
def prep_enum(self, enumtype: type[Enum]) -> None:
"""Run prep on an enum type."""
valtype: Any = None
# We currently support enums with str or int values; fail if we
# find any others.
for enumval in enumtype:
if not isinstance(enumval.value, (str, int)):
raise TypeError(
f'Enum value {enumval} has value type'
f' {type(enumval.value)}; only str and int are'
f' supported by dataclassio.'
)
if valtype is None:
valtype = type(enumval.value)
else:
if type(enumval.value) is not valtype:
raise TypeError(
f'Enum type {enumtype} has multiple'
f' value types; dataclassio requires'
f' them to be uniform.'
)


@@ -1,71 +0,0 @@
# Released under the MIT License. See LICENSE for details.
#
"""Extra rarely-needed functionality related to dataclasses."""
from __future__ import annotations
import dataclasses
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Any
def dataclass_diff(obj1: Any, obj2: Any) -> str:
"""Generate a string showing differences between two dataclass instances.
Both must be of the exact same type.
"""
diff = _diff(obj1, obj2, 2)
return ' <no differences>' if diff == '' else diff
class DataclassDiff:
"""Wraps dataclass_diff() in an object for efficiency.
It is preferable to pass this to logging calls instead of the
final diff string since the diff will never be generated if
the associated logging level is not being emitted.
"""
def __init__(self, obj1: Any, obj2: Any):
self._obj1 = obj1
self._obj2 = obj2
def __repr__(self) -> str:
return dataclass_diff(self._obj1, self._obj2)
def _diff(obj1: Any, obj2: Any, indent: int) -> str:
assert dataclasses.is_dataclass(obj1)
assert dataclasses.is_dataclass(obj2)
if type(obj1) is not type(obj2):
raise TypeError(
f'Passed objects are not of the same'
f' type ({type(obj1)} and {type(obj2)}).'
)
bits: list[str] = []
indentstr = ' ' * indent
fields = dataclasses.fields(obj1)
for field in fields:
fieldname = field.name
val1 = getattr(obj1, fieldname)
val2 = getattr(obj2, fieldname)
# For nested dataclasses, dive in and do nice piecewise compares.
if (
dataclasses.is_dataclass(val1)
and dataclasses.is_dataclass(val2)
and type(val1) is type(val2)
):
diff = _diff(val1, val2, indent + 2)
if diff != '':
bits.append(f'{indentstr}{fieldname}:')
bits.append(diff)
# For all else just do a single line
# (perhaps we could improve on this for other complex types)
else:
if val1 != val2:
bits.append(f'{indentstr}{fieldname}: {val1} -> {val2}')
return '\n'.join(bits)
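# A hedged usage sketch of dataclass_diff / DataclassDiff (the 'Prefs' class
# is hypothetical, and this assumes the two names defined above are in scope
# or imported from wherever this module lives in the package):
import logging
from dataclasses import dataclass

@dataclass
class Prefs:
    volume: float = 1.0
    muted: bool = False

old, new = Prefs(), Prefs(volume=0.5)
print(dataclass_diff(old, new))  # -> '  volume: 1.0 -> 0.5'
# Passing the wrapper defers building the diff string until (and unless)
# the logging call actually emits the record.
logging.debug('prefs changed:\n%s', DataclassDiff(old, new))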