vh-bombsquad-modded-server-.../dist/ba_data/python/efro/dataclassio/_outputter.py

# Released under the MIT License. See LICENSE for details.
#
"""Functionality for dataclassio related to exporting data from dataclasses."""
# Note: We do lots of comparing of exact types here which is normally
# frowned upon (stuff like isinstance() is usually encouraged).
# pylint: disable=unidiomatic-typecheck
from __future__ import annotations
from enum import Enum
import dataclasses
import typing
import types
import datetime
from typing import TYPE_CHECKING
from efro.util import check_utc
from efro.dataclassio._base import (
Codec,
_parse_annotated,
EXTRA_ATTRS_ATTR,
_is_valid_for_codec,
_get_origin,
SIMPLE_TYPES,
_raise_type_error,
IOExtendedData,
)
from efro.dataclassio._prep import PrepSession
if TYPE_CHECKING:
from typing import Any
from efro.dataclassio._base import IOAttrs


class _Outputter:
    """Validates or exports data contained in a dataclass instance."""

    def __init__(
        self, obj: Any, create: bool, codec: Codec, coerce_to_float: bool
    ) -> None:
        self._obj = obj
        self._create = create
        self._codec = codec
        self._coerce_to_float = coerce_to_float
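
    # Note: when 'create' is False we are only validating; the various
    # _process methods below then return None and no output is built.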

    def run(self) -> Any:
        """Run the validation/export pass and return the result."""
        assert dataclasses.is_dataclass(self._obj)

        # For special extended data types, call their 'will_output' callback.
        if isinstance(self._obj, IOExtendedData):
            self._obj.will_output()

        return self._process_dataclass(type(self._obj), self._obj, '')
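
    # A minimal usage sketch (the sample class is hypothetical; in
    # practice this class is driven by the package's public helpers such
    # as dataclass_to_dict()):
    #
    #   @dataclasses.dataclass
    #   class _Sample:
    #       count: int = 0
    #
    #   out = _Outputter(
    #       _Sample(count=2), create=True, codec=Codec.JSON,
    #       coerce_to_float=True,
    #   ).run()
    #   assert out == {'count': 2}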

    def soft_default_check(
        self, value: Any, anntype: Any, fieldpath: str
    ) -> None:
        """(internal)"""
        self._process_value(
            type(value),
            fieldpath=fieldpath,
            anntype=anntype,
            value=value,
            ioattrs=None,
        )

    def _process_dataclass(self, cls: type, obj: Any, fieldpath: str) -> Any:
        # pylint: disable=too-many-locals
        # pylint: disable=too-many-branches
        prep = PrepSession(explicit=False).prep_dataclass(
            type(obj), recursion_level=0
        )
        assert prep is not None
        fields = dataclasses.fields(obj)
        out: dict[str, Any] | None = {} if self._create else None
        for field in fields:
            fieldname = field.name
            if fieldpath:
                subfieldpath = f'{fieldpath}.{fieldname}'
            else:
                subfieldpath = fieldname
            anntype = prep.annotations[fieldname]
            value = getattr(obj, fieldname)
            anntype, ioattrs = _parse_annotated(anntype)

            # If we're not storing default values for this field, we can
            # skip all output processing when its value matches one of its
            # defaults.
            if ioattrs is not None and not ioattrs.store_default:
                # If both soft_defaults and regular field defaults
                # are present we want to go with soft_defaults since
                # those same values would be re-injected when reading
                # the same data back in if we've omitted the field.
                default_factory: Any = field.default_factory
                if ioattrs.soft_default is not ioattrs.MISSING:
                    if ioattrs.soft_default == value:
                        continue
                elif ioattrs.soft_default_factory is not ioattrs.MISSING:
                    assert callable(ioattrs.soft_default_factory)
                    if ioattrs.soft_default_factory() == value:
                        continue
                elif field.default is not dataclasses.MISSING:
                    if field.default == value:
                        continue
                elif default_factory is not dataclasses.MISSING:
                    if default_factory() == value:
                        continue
                else:
                    raise RuntimeError(
                        f'Field {fieldname} of {cls.__name__} has'
                        f' no source of default values; store_default=False'
                        f' cannot be set for it. (AND THIS SHOULD HAVE BEEN'
                        f' CAUGHT IN PREP!)'
                    )
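
            # (So e.g. a field declared with IOAttrs(store_default=False,
            # soft_default=0) whose current value is 0 is omitted from the
            # output entirely and re-injected when the data is read back.)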

            outvalue = self._process_value(
                cls, subfieldpath, anntype, value, ioattrs
            )
            if self._create:
                assert out is not None
                storagename = (
                    fieldname
                    if (ioattrs is None or ioattrs.storagename is None)
                    else ioattrs.storagename
                )
                out[storagename] = outvalue

        # If there's extra-attrs stored on us, check/include them.
        extra_attrs = getattr(obj, EXTRA_ATTRS_ATTR, None)
        if isinstance(extra_attrs, dict):
            if not _is_valid_for_codec(extra_attrs, self._codec):
                raise TypeError(
                    f'Extra attrs on \'{fieldpath}\' contains data type(s)'
                    f' not supported by \'{self._codec.value}\' codec:'
                    f' {extra_attrs}.'
                )
            if self._create:
                assert out is not None
                out.update(extra_attrs)
        return out

    def _process_value(
        self,
        cls: type,
        fieldpath: str,
        anntype: Any,
        value: Any,
        ioattrs: IOAttrs | None,
    ) -> Any:
        # pylint: disable=too-many-return-statements
        # pylint: disable=too-many-branches
        # pylint: disable=too-many-statements

        origin = _get_origin(anntype)

        if origin is typing.Any:
            if not _is_valid_for_codec(value, self._codec):
                raise TypeError(
                    f'Invalid value type for \'{fieldpath}\';'
                    f" 'Any' typed values must contain types directly"
                    f' supported by the specified codec ({self._codec.name});'
                    f' found \'{type(value).__name__}\' which is not.'
                )
            return value if self._create else None
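
        # (Note: an 'Any' field holding e.g. a set() fails the check above
        # under the JSON codec, since json has no set type.)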

        if origin is typing.Union or origin is types.UnionType:
            # Currently, the only unions we support are None/Value
            # (translated from Optional), which we verified at prep time.
            # So let's treat this as a simple optional case.
            if value is None:
                return None
            childanntypes_l = [
                c for c in typing.get_args(anntype) if c is not type(None)
            ]  # noqa (pycodestyle complains about *is* with type)
            assert len(childanntypes_l) == 1
            return self._process_value(
                cls, fieldpath, childanntypes_l[0], value, ioattrs
            )
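
        # (For an annotation such as 'int | None', childanntypes_l above
        # is simply [int].)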

        # Everything below this point assumes the annotation type resolves
        # to a concrete type. (This should have been verified at prep time.)
        assert isinstance(origin, type)

        # For simple flat types, look for exact matches:
        if origin in SIMPLE_TYPES:
            if type(value) is not origin:
                # Special case: if they want to coerce ints to floats, do so.
                if (
                    self._coerce_to_float
                    and origin is float
                    and type(value) is int
                ):
                    return float(value) if self._create else None
                _raise_type_error(fieldpath, type(value), (origin,))
            return value if self._create else None
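
        # (With coerce_to_float=True, an int value such as 2 stored in a
        # float field is exported as 2.0; without it, that's a type error.)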

        if origin is tuple:
            if not isinstance(value, tuple):
                raise TypeError(
                    f'Expected a tuple for {fieldpath};'
                    f' found a {type(value)}'
                )
            childanntypes = typing.get_args(anntype)

            # We should have verified this was non-zero at prep-time.
            assert childanntypes

            if len(value) != len(childanntypes):
                raise TypeError(
                    f'Tuple at {fieldpath} contains'
                    f' {len(value)} values; type specifies'
                    f' {len(childanntypes)}.'
                )
            if self._create:
                return [
                    self._process_value(
                        cls, fieldpath, childanntypes[i], x, ioattrs
                    )
                    for i, x in enumerate(value)
                ]
            for i, x in enumerate(value):
                self._process_value(
                    cls, fieldpath, childanntypes[i], x, ioattrs
                )
            return None
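
        # (Note that fixed-length tuples, like sets below, are exported
        # as json-friendly lists.)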

        if origin is list:
            if not isinstance(value, list):
                raise TypeError(
                    f'Expected a list for {fieldpath};'
                    f' found a {type(value)}'
                )
            childanntypes = typing.get_args(anntype)

            # 'Any' type children; make sure they are valid values for
            # the specified codec.
            if len(childanntypes) == 0 or childanntypes[0] is typing.Any:
                for i, child in enumerate(value):
                    if not _is_valid_for_codec(child, self._codec):
                        raise TypeError(
                            f'Item {i} of {fieldpath} contains'
                            f' data type(s) not supported by the specified'
                            f' codec ({self._codec.name}).'
                        )
                # Hmm; should we do a copy here?
                return value if self._create else None

            # We contain elements of some specified type.
            assert len(childanntypes) == 1
            if self._create:
                return [
                    self._process_value(
                        cls, fieldpath, childanntypes[0], x, ioattrs
                    )
                    for x in value
                ]
            for x in value:
                self._process_value(
                    cls, fieldpath, childanntypes[0], x, ioattrs
                )
            return None

        if origin is set:
            if not isinstance(value, set):
                raise TypeError(
                    f'Expected a set for {fieldpath};'
                    f' found a {type(value)}'
                )
            childanntypes = typing.get_args(anntype)

            # 'Any' type children; make sure they are valid Any values.
            if len(childanntypes) == 0 or childanntypes[0] is typing.Any:
                for child in value:
                    if not _is_valid_for_codec(child, self._codec):
                        raise TypeError(
                            f'Set at {fieldpath} contains'
                            f' data type(s) not supported by the'
                            f' specified codec ({self._codec.name}).'
                        )
                return list(value) if self._create else None

            # We contain elements of some specified type.
            assert len(childanntypes) == 1
            if self._create:
                # Note: we output json-friendly values so this becomes
                # a list.
                return [
                    self._process_value(
                        cls, fieldpath, childanntypes[0], x, ioattrs
                    )
                    for x in value
                ]
            for x in value:
                self._process_value(
                    cls, fieldpath, childanntypes[0], x, ioattrs
                )
            return None
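
        # (Since Python sets are unordered, the order of the resulting
        # list is unspecified.)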

        if origin is dict:
            return self._process_dict(cls, fieldpath, anntype, value, ioattrs)

        if dataclasses.is_dataclass(origin):
            if not isinstance(value, origin):
                raise TypeError(
                    f'Expected a {origin} for {fieldpath};'
                    f' found a {type(value)}.'
                )
            return self._process_dataclass(cls, value, fieldpath)

        if issubclass(origin, Enum):
            if not isinstance(value, origin):
                raise TypeError(
                    f'Expected a {origin} for {fieldpath};'
                    f' found a {type(value)}.'
                )
            # At prep-time we verified that these enums had valid value
            # types, so we can blindly return it here.
            return value.value if self._create else None
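
        # (So e.g. an enum member whose value is the string 'red' gets
        # stored simply as 'red'.)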

        if issubclass(origin, datetime.datetime):
            if not isinstance(value, origin):
                raise TypeError(
                    f'Expected a {origin} for {fieldpath};'
                    f' found a {type(value)}.'
                )
            check_utc(value)
            if ioattrs is not None:
                ioattrs.validate_datetime(value, fieldpath)
            if self._codec is Codec.FIRESTORE:
                return value
            assert self._codec is Codec.JSON
            return (
                [
                    value.year,
                    value.month,
                    value.day,
                    value.hour,
                    value.minute,
                    value.second,
                    value.microsecond,
                ]
                if self._create
                else None
            )
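
        # (Under the JSON codec a utc datetime is thus stored as a
        # 7-element list: [year, month, day, hour, minute, second,
        # microsecond].)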

        if origin is bytes:
            return self._process_bytes(cls, fieldpath, value)

        raise TypeError(
            f"Field '{fieldpath}' of type '{anntype}' is unsupported here."
        )

    def _process_bytes(self, cls: type, fieldpath: str, value: bytes) -> Any:
        import base64

        if not isinstance(value, bytes):
            raise TypeError(
                f'Expected bytes for {fieldpath} on {cls.__name__};'
                f' found a {type(value)}.'
            )
        if not self._create:
            return None

        # In JSON we convert to base64, but firestore directly supports
        # bytes.
        if self._codec is Codec.JSON:
            return base64.b64encode(value).decode()
        assert self._codec is Codec.FIRESTORE
        return value
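
    # (So e.g. b'hi' becomes the base64 string 'aGk=' under the JSON
    # codec, while firestore keeps the raw bytes.)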

    def _process_dict(
        self,
        cls: type,
        fieldpath: str,
        anntype: Any,
        value: dict,
        ioattrs: IOAttrs | None,
    ) -> Any:
        # pylint: disable=too-many-branches
        if not isinstance(value, dict):
            raise TypeError(
                f'Expected a dict for {fieldpath};'
                f' found a {type(value)}.'
            )
        childtypes = typing.get_args(anntype)
        assert len(childtypes) in (0, 2)

        # We treat 'Any' dicts simply as json; we don't do any translating.
        if not childtypes or childtypes[0] is typing.Any:
            if not isinstance(value, dict) or not _is_valid_for_codec(
                value, self._codec
            ):
                raise TypeError(
                    f'Invalid value for Dict[Any, Any]'
                    f' at \'{fieldpath}\' on {cls.__name__};'
                    f' all keys and values must be directly compatible'
                    f' with the specified codec ({self._codec.name})'
                    f' when dict type is Any.'
                )
            return value if self._create else None

        # Ok; we've got a definite key type (which we verified as valid
        # during prep). Make sure all keys match it.
        out: dict | None = {} if self._create else None
        keyanntype, valanntype = childtypes

        # str keys we just export directly since that's supported by json.
        if keyanntype is str:
            for key, val in value.items():
                if not isinstance(key, str):
                    raise TypeError(
                        f'Got invalid key type {type(key)} for'
                        f' dict key at \'{fieldpath}\' on {cls.__name__};'
                        f' expected {keyanntype}.'
                    )
                outval = self._process_value(
                    cls, fieldpath, valanntype, val, ioattrs
                )
                if self._create:
                    assert out is not None
                    out[key] = outval

        # int keys are stored as str versions of themselves.
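        # (So e.g. {1: 'a'} is exported as {'1': 'a'}; presumably the
        # input side converts such keys back to ints.)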
        elif keyanntype is int:
            for key, val in value.items():
                if not isinstance(key, int):
                    raise TypeError(
                        f'Got invalid key type {type(key)} for'
                        f' dict key at \'{fieldpath}\' on {cls.__name__};'
                        f' expected an int.'
                    )
                outval = self._process_value(
                    cls, fieldpath, valanntype, val, ioattrs
                )
                if self._create:
                    assert out is not None
                    out[str(key)] = outval
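
        # Enum keys are stored as the str() of the member's underlying
        # value.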
        elif issubclass(keyanntype, Enum):
            for key, val in value.items():
                if not isinstance(key, keyanntype):
                    raise TypeError(
                        f'Got invalid key type {type(key)} for'
                        f' dict key at \'{fieldpath}\' on {cls.__name__};'
                        f' expected a {keyanntype}.'
                    )
                outval = self._process_value(
                    cls, fieldpath, valanntype, val, ioattrs
                )
                if self._create:
                    assert out is not None
                    out[str(key.value)] = outval
        else:
            raise RuntimeError(f'Unhandled dict out-key-type {keyanntype}')

        return out