vh-bombsquad-modded-server-.../dist/ba_data/python/efro/dataclassio/_base.py
2024-02-26 00:17:10 +05:30

276 lines
10 KiB
Python

# Released under the MIT License. See LICENSE for details.
#
"""Core components of dataclassio."""
from __future__ import annotations
import dataclasses
import typing
import datetime
from enum import Enum
from typing import TYPE_CHECKING, get_args
# noinspection PyProtectedMember
from typing import _AnnotatedAlias # type: ignore
if TYPE_CHECKING:
from typing import Any, Callable
# Types whose values can be passed through as-is (no conversion needed).
SIMPLE_TYPES = {int, bool, str, float, type(None)}

# Name of the attribute used to hold a dict of extra attributes on
# dataclass instances. Per the original author, this attribute is only
# added when extra attributes are actually present.
EXTRA_ATTRS_ATTR = '_DCIOEXATTRS'
def _raise_type_error(
fieldpath: str, valuetype: type, expected: tuple[type, ...]
) -> None:
"""Raise an error when a field value's type does not match expected."""
assert isinstance(expected, tuple)
assert all(isinstance(e, type) for e in expected)
if len(expected) == 1:
expected_str = expected[0].__name__
else:
expected_str = ' | '.join(t.__name__ for t in expected)
raise TypeError(
f'Invalid value type for "{fieldpath}";'
f' expected "{expected_str}", got'
f' "{valuetype.__name__}".'
)
class Codec(Enum):
    """Data format that values are exported to or imported from."""

    # Restricted to types that map cleanly to and from json: lists,
    # dicts with str keys, bools, ints, floats, and None.
    JSON = 'json'

    # Largely the same as JSON, except that bytes and datetime objects
    # are passed through untouched rather than converted into
    # json-friendly forms.
    FIRESTORE = 'firestore'
class IOExtendedData:
    """Optional base class that gives data types extra io hooks.

    Both hooks do nothing by default; subclasses override the ones they
    need.
    """

    def will_output(self) -> None:
        """Hook called before this data is sent to an outputter.

        Override it to validate or filter data before it is sent on its
        way.
        """
        return None

    @classmethod
    def will_input(cls, data: dict) -> None:
        """Hook called on raw data before an instance is created from it.

        Override it to, for example, migrate old data formats to new ones.
        """
        return None
def _is_valid_for_codec(obj: Any, codec: Codec) -> bool:
"""Return whether a value consists solely of json-supported types.
Note that this does not include things like tuples which are
implicitly translated to lists by python's json module.
"""
if obj is None:
return True
objtype = type(obj)
if objtype in (int, float, str, bool):
return True
if objtype is dict:
# JSON 'objects' supports only string dict keys, but all value types.
return all(
isinstance(k, str) and _is_valid_for_codec(v, codec)
for k, v in obj.items()
)
if objtype is list:
return all(_is_valid_for_codec(elem, codec) for elem in obj)
# A few things are valid in firestore but not json.
if issubclass(objtype, datetime.datetime) or objtype is bytes:
return codec is Codec.FIRESTORE
return False
class IOAttrs:
    """For specifying io behavior in annotations.

    'storagename', if passed, is the name used when storing to json/etc.

    'store_default' can be set to False to avoid writing values when they
    equal the default value. This requires the dataclass field to define
    a default or default_factory, or for its IOAttrs to define a
    soft_default or soft_default_factory.

    'whole_days', if True, requires datetime values to lie exactly on day
    boundaries (see efro.util.utc_today()).

    'whole_hours', if True, requires datetime values to lie exactly on
    hour boundaries (see efro.util.utc_this_hour()).

    'whole_minutes', if True, requires datetime values to lie exactly on
    minute boundaries (see efro.util.utc_this_minute()).

    Only one of the whole_* checks is applied. If several are set, days
    take precedence over hours, and hours over minutes.

    'soft_default', if passed, injects a default value into dataclass
    instantiation when the field is not present in the input data. This
    lets dataclasses add new non-optional fields while gracefully
    'upgrading' old data. When a soft_default is present it takes
    precedence over field defaults when deciding whether to store a value
    for a field with store_default=False. This is because the
    soft_default value is what we get back when reading that same data
    with the field omitted.

    'soft_default_factory' is to 'soft_default' what 'default_factory' is
    to 'default' in dataclass fields. Use it instead of 'soft_default' for
    mutable types such as lists, so that a single shared default object
    cannot unintentionally change over time. Passing a list, dict or set
    as soft_default raises ValueError, and so does passing both
    soft_default and soft_default_factory.
    """

    # Sentinel meaning 'argument not supplied'. It is an instance of a
    # dedicated class rather than a plain object() so it is identifiable
    # by its class name when printed.
    class _MissingType:
        pass

    MISSING = _MissingType()

    # Class-level defaults. Instances only carry their own attributes for
    # values that differ from these.
    storagename: str | None = None
    store_default: bool = True
    whole_days: bool = False
    whole_hours: bool = False
    whole_minutes: bool = False
    soft_default: Any = MISSING
    soft_default_factory: Callable[[], Any] | _MissingType = MISSING

    def __init__(
        self,
        storagename: str | None = storagename,
        store_default: bool = store_default,
        whole_days: bool = whole_days,
        whole_hours: bool = whole_hours,
        whole_minutes: bool = whole_minutes,
        soft_default: Any = MISSING,
        soft_default_factory: Callable[[], Any] | _MissingType = MISSING,
    ):
        cls = type(self)

        # Plain options: record a per-instance value only when it differs
        # from the class-level default, which keeps instances lean.
        plain_options = (
            ('storagename', storagename),
            ('store_default', store_default),
            ('whole_days', whole_days),
            ('whole_hours', whole_hours),
            ('whole_minutes', whole_minutes),
        )
        for attrname, attrval in plain_options:
            if attrval != getattr(cls, attrname):
                setattr(self, attrname, attrval)

        if soft_default is not cls.soft_default:
            # Mirror what dataclasses does for its own defaults: refuse
            # shared mutable values and point at the factory instead.
            if isinstance(soft_default, (list, dict, set)):
                raise ValueError(
                    f'mutable {type(soft_default)} is not allowed'
                    f' for soft_default; use soft_default_factory.'
                )
            self.soft_default = soft_default

        if soft_default_factory is not cls.soft_default_factory:
            self.soft_default_factory = soft_default_factory
            # soft_default was already handled above, so a non-MISSING
            # value here means the caller passed both.
            if self.soft_default is not cls.soft_default:
                raise ValueError(
                    'Cannot set both soft_default and soft_default_factory'
                )

    def validate_for_field(self, cls: type, field: dataclasses.Field) -> None:
        """Ensure this IOAttrs instance is ok to use with the given field.

        Args:
            cls: The dataclass owning the field (used in the message).
            field: The dataclass field these attrs annotate.

        Raises:
            TypeError: If store_default is False but neither the field nor
                these attrs provide any kind of default.
        """
        if self.store_default:
            return

        # Skipping defaults on output requires some default to fall back
        # on when reading: a field default/default_factory or a soft one.
        field_default_factory: Any = field.default_factory
        has_any_default = (
            field_default_factory is not dataclasses.MISSING
            or field.default is not dataclasses.MISSING
            or self.soft_default is not self.MISSING
            or self.soft_default_factory is not self.MISSING
        )
        if not has_any_default:
            raise TypeError(
                f'Field {field.name} of {cls} has'
                f' neither a default nor a default_factory'
                f' and IOAttrs contains neither a soft_default'
                f' nor a soft_default_factory;'
                f' store_default=False cannot be set for it.'
            )

    def validate_datetime(
        self, value: datetime.datetime, fieldpath: str
    ) -> None:
        """Ensure a datetime value meets the whole_* requirements.

        Args:
            value: The datetime to check.
            fieldpath: Path of the field; included in the message.

        Raises:
            ValueError: If value is not on the required boundary.
        """
        # Pick the strictest requested granularity and the components that
        # must all be zero for it.
        if self.whole_days:
            unit = 'day'
            remainder = (
                value.hour,
                value.minute,
                value.second,
                value.microsecond,
            )
        elif self.whole_hours:
            unit = 'hour'
            remainder = (value.minute, value.second, value.microsecond)
        elif self.whole_minutes:
            unit = 'minute'
            remainder = (value.second, value.microsecond)
        else:
            return

        if any(part != 0 for part in remainder):
            raise ValueError(
                f'Value {value} at {fieldpath} is not a whole {unit}.'
            )
def _get_origin(anntype: Any) -> Any:
"""Given a type annotation, return its origin or itself if there is none.
This differs from typing.get_origin in that it will never return None.
This lets us use the same code path for handling typing.List
that we do for handling list, which is good since they can be used
interchangeably in annotations.
"""
origin = typing.get_origin(anntype)
return anntype if origin is None else origin
def _parse_annotated(anntype: Any) -> tuple[Any, IOAttrs | None]:
"""Parse Annotated() constructs, returning annotated type & IOAttrs."""
# If we get an Annotated[foo, bar, eep] we take
# foo as the actual type, and we look for IOAttrs instances in
# bar/eep to affect our behavior.
ioattrs: IOAttrs | None = None
if isinstance(anntype, _AnnotatedAlias):
annargs = get_args(anntype)
for annarg in annargs[1:]:
if isinstance(annarg, IOAttrs):
if ioattrs is not None:
raise RuntimeError(
'Multiple IOAttrs instances found for a'
' single annotation; this is not supported.'
)
ioattrs = annarg
# I occasionally just throw a 'x' down when I mean IOAttrs('x');
# catch these mistakes.
elif isinstance(annarg, (str, int, float, bool)):
raise RuntimeError(
f'Raw {type(annarg)} found in Annotated[] entry:'
f' {anntype}; this is probably not what you intended.'
)
anntype = annargs[0]
return anntype, ioattrs