1.6.10 update

This commit is contained in:
Ayush Saini 2022-03-13 17:33:09 +05:30
parent 81485da646
commit 240155bce3
37 changed files with 809 additions and 350 deletions

View file

@ -3,8 +3,9 @@
"""Functionality related to the high level state of the app.""" """Functionality related to the high level state of the app."""
from __future__ import annotations from __future__ import annotations
from enum import Enum
import random import random
import logging
from enum import Enum
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import _ba import _ba
@ -184,6 +185,9 @@ class App:
self.state = self.State.LAUNCHING self.state = self.State.LAUNCHING
self._app_launched = False
self._app_paused = False
# Config. # Config.
self.config_file_healthy = False self.config_file_healthy = False
@ -348,27 +352,6 @@ class App:
for key in ('lc14173', 'lc14292'): for key in ('lc14173', 'lc14292'):
cfg.setdefault(key, launch_count) cfg.setdefault(key, launch_count)
# Debugging - make note if we're using the local test server so we
# don't accidentally leave it on in a release.
# FIXME - should move this to the native layer.
server_addr = _ba.get_master_server_address()
if 'localhost' in server_addr:
_ba.timer(2.0,
lambda: _ba.screenmessage(
'Note: using local server',
(1, 1, 0),
log=True,
),
timetype=TimeType.REAL)
elif 'test' in server_addr:
_ba.timer(2.0,
lambda: _ba.screenmessage(
'Note: using test server-module',
(1, 1, 0),
log=True,
),
timetype=TimeType.REAL)
cfg['launchCount'] = launch_count cfg['launchCount'] = launch_count
cfg.commit() cfg.commit()
@ -389,20 +372,40 @@ class App:
self.accounts.on_app_launch() self.accounts.on_app_launch()
self.plugins.on_app_launch() self.plugins.on_app_launch()
self.state = self.State.RUNNING # See note below in on_app_pause.
if self.state != self.State.LAUNCHING:
logging.error('on_app_launch found state %s; expected LAUNCHING.',
self.state)
self._app_launched = True
self._update_state()
# from ba._dependency import test_depset # from ba._dependency import test_depset
# test_depset() # test_depset()
if bool(False):
self._test_https()
def _update_state(self) -> None:
if self._app_paused:
self.state = self.State.PAUSED
else:
if self._app_launched:
self.state = self.State.RUNNING
else:
self.state = self.State.LAUNCHING
def on_app_pause(self) -> None: def on_app_pause(self) -> None:
"""Called when the app goes to a suspended state.""" """Called when the app goes to a suspended state."""
self.state = self.State.PAUSED
self._app_paused = True
self._update_state()
self.plugins.on_app_pause() self.plugins.on_app_pause()
def on_app_resume(self) -> None: def on_app_resume(self) -> None:
"""Run when the app resumes from a suspended state.""" """Run when the app resumes from a suspended state."""
self.state = self.State.RUNNING self._app_paused = False
self._update_state()
self.fg_state += 1 self.fg_state += 1
self.accounts.on_app_resume() self.accounts.on_app_resume()
self.music.on_app_resume() self.music.on_app_resume()
@ -586,6 +589,8 @@ class App:
try: try:
with urllib.request.urlopen('https://example.com') as url: with urllib.request.urlopen('https://example.com') as url:
val = url.read() val = url.read()
_ba.screenmessage('HTTPS SUCCESS!')
print('HTTPS TEST SUCCESS', len(val)) print('HTTPS TEST SUCCESS', len(val))
except Exception as exc: except Exception as exc:
_ba.screenmessage('HTTPS FAIL.')
print('HTTPS TEST FAIL:', exc) print('HTTPS TEST FAIL:', exc)

View file

@ -60,6 +60,8 @@ def setup_asyncio() -> None:
timetype=TimeType.REAL, timetype=TimeType.REAL,
repeat=True) repeat=True)
if bool(False):
async def aio_test() -> None: async def aio_test() -> None:
print('TEST AIO TASK STARTING') print('TEST AIO TASK STARTING')
assert _asyncio_event_loop is not None assert _asyncio_event_loop is not None
@ -67,5 +69,4 @@ def setup_asyncio() -> None:
await asyncio.sleep(2.0) await asyncio.sleep(2.0)
print('TEST AIO TASK ENDING') print('TEST AIO TASK ENDING')
if bool(False):
_asyncio_event_loop.create_task(aio_test()) _asyncio_event_loop.create_task(aio_test())

View file

@ -148,7 +148,7 @@ class DependencyEntry:
instance = self.cls.__new__(self.cls) instance = self.cls.__new__(self.cls)
# pylint: disable=protected-access # pylint: disable=protected-access
instance._dep_entry = weakref.ref(self) instance._dep_entry = weakref.ref(self)
instance.__init__() instance.__init__() # type: ignore
assert self.depset assert self.depset
depset = self.depset() depset = self.depset()

View file

@ -196,3 +196,4 @@ class SpecialChar(Enum):
FLAG_PHILIPPINES = 87 FLAG_PHILIPPINES = 87
FLAG_CHILE = 88 FLAG_CHILE = 88
MIKIROG = 89 MIKIROG = 89
V2_LOGO = 90

View file

@ -196,3 +196,4 @@ class SpecialChar(Enum):
FLAG_PHILIPPINES = 87 FLAG_PHILIPPINES = 87
FLAG_CHILE = 88 FLAG_CHILE = 88
MIKIROG = 89 MIKIROG = 89
V2_LOGO = 90

View file

@ -119,6 +119,11 @@ def gear_vr_controller_warning() -> None:
color=(1, 0, 0)) color=(1, 0, 0))
def uuid_str() -> str:
import uuid
return str(uuid.uuid4())
def orientation_reset_cb_message() -> None: def orientation_reset_cb_message() -> None:
from ba._language import Lstr from ba._language import Lstr
_ba.screenmessage( _ba.screenmessage(
@ -370,3 +375,13 @@ def get_player_icon(sessionplayer: ba.SessionPlayer) -> dict[str, Any]:
'tint_color': info['tint_color'], 'tint_color': info['tint_color'],
'tint2_color': info['tint2_color'] 'tint2_color': info['tint2_color']
} }
def hash_strings(inputs: list[str]) -> str:
"""Hash provided strings into a short output string."""
import hashlib
sha = hashlib.sha1()
for inp in inputs:
sha.update(inp.encode())
return sha.hexdigest()

View file

@ -87,6 +87,7 @@ class LanguageSubsystem:
'vec': 'Venetian', 'vec': 'Venetian',
'hi': 'Hindi', 'hi': 'Hindi',
'ta': 'Tamil', 'ta': 'Tamil',
'fil': 'Filipino',
} }
# Special case for Chinese: map specific variations to traditional. # Special case for Chinese: map specific variations to traditional.
@ -373,7 +374,7 @@ class Lstr:
currently-active language. currently-active language.
To see available resource keys, look at any of the bs_language_*.py files To see available resource keys, look at any of the bs_language_*.py files
in the game or the translations pages at bombsquadgame.com/translate. in the game or the translations pages at legacy.ballistica.net/translate.
# EXAMPLE 1: specify a string from a resource path # EXAMPLE 1: specify a string from a resource path
mynode.text = ba.Lstr(resource='audioSettingsWindow.titleText') mynode.text = ba.Lstr(resource='audioSettingsWindow.titleText')

View file

@ -25,8 +25,15 @@ class NetworkSubsystem:
"""Network related app subsystem.""" """Network related app subsystem."""
def __init__(self) -> None: def __init__(self) -> None:
# Anyone accessing/modifying region_pings should hold this lock.
self.region_pings_lock = threading.Lock()
self.region_pings: dict[str, float] = {} self.region_pings: dict[str, float] = {}
# For debugging.
self.v1_test_log: str = ''
self.v1_ctest_results: dict[int, str] = {}
def get_ip_address_type(addr: str) -> socket.AddressFamily: def get_ip_address_type(addr: str) -> socket.AddressFamily:
"""Return socket.AF_INET6 or socket.AF_INET4 for the provided address.""" """Return socket.AF_INET6 or socket.AF_INET4 for the provided address."""

View file

@ -38,8 +38,8 @@ class AssetType(Enum):
@dataclass @dataclass
class AssetPackageFlavorManifest: class AssetPackageFlavorManifest:
"""A manifest of asset info for a specific flavor of an asset package.""" """A manifest of asset info for a specific flavor of an asset package."""
assetfiles: Annotated[dict[str, str], cloudfiles: Annotated[dict[str, str],
IOAttrs('assetfiles')] = field(default_factory=dict) IOAttrs('cloudfiles')] = field(default_factory=dict)
@ioprepped @ioprepped

33
dist/ba_data/python/bacommon/build.py vendored Normal file
View file

@ -0,0 +1,33 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality related to game builds."""
from __future__ import annotations
import datetime
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Annotated
from efro.dataclassio import ioprepped, IOAttrs
if TYPE_CHECKING:
pass
@ioprepped
@dataclass
class BuildInfoSet:
"""Set of build infos."""
@dataclass
class Entry:
"""Info about a particular build."""
filename: Annotated[str, IOAttrs('fname')]
size: Annotated[int, IOAttrs('size')]
version: Annotated[str, IOAttrs('version')]
build_number: Annotated[int, IOAttrs('build')]
checksum: Annotated[str, IOAttrs('checksum')]
createtime: Annotated[datetime.datetime, IOAttrs('createtime')]
builds: Annotated[list[Entry],
IOAttrs('builds')] = field(default_factory=list)

View file

@ -104,7 +104,7 @@ class ServerConfig:
# if ${ACCOUNT} is present in the string, it will be replaced by the # if ${ACCOUNT} is present in the string, it will be replaced by the
# currently-signed-in account's id. To fetch info about an account, # currently-signed-in account's id. To fetch info about an account,
# your back-end server can use the following url: # your back-end server can use the following url:
# http://bombsquadgame.com/accountquery?id=ACCOUNT_ID_HERE # https://legacy.ballistica.net/accountquery?id=ACCOUNT_ID_HERE
stats_url: Optional[str] = None stats_url: Optional[str] = None
# If present, the server subprocess will attempt to gracefully exit after # If present, the server subprocess will attempt to gracefully exit after

View file

@ -39,15 +39,21 @@ class Spawner:
The spawn position. The spawn position.
""" """
def __init__(self, spawner: Spawner, data: Any, pt: Sequence[float]): def __init__(
self,
spawner: Spawner,
data: Any,
pt: Sequence[float], # pylint: disable=invalid-name
):
"""Instantiate with the given values.""" """Instantiate with the given values."""
self.spawner = spawner self.spawner = spawner
self.data = data self.data = data
self.pt = pt # pylint: disable=invalid-name self.pt = pt # pylint: disable=invalid-name
def __init__(self, def __init__(
self,
data: Any = None, data: Any = None,
pt: Sequence[float] = (0, 0, 0), pt: Sequence[float] = (0, 0, 0), # pylint: disable=invalid-name
spawn_time: float = 1.0, spawn_time: float = 1.0,
send_spawn_message: bool = True, send_spawn_message: bool = True,
spawn_callback: Callable[[], Any] = None): spawn_callback: Callable[[], Any] = None):

View file

@ -199,6 +199,7 @@ class Spaz(ba.Actor):
self.equip_boxing_gloves() self.equip_boxing_gloves()
self.last_punch_time_ms = -9999 self.last_punch_time_ms = -9999
self.last_pickup_time_ms = -9999 self.last_pickup_time_ms = -9999
self.last_jump_time_ms = -9999
self.last_run_time_ms = -9999 self.last_run_time_ms = -9999
self._last_run_value = 0.0 self._last_run_value = 0.0
self.last_bomb_time_ms = -9999 self.last_bomb_time_ms = -9999
@ -363,7 +364,11 @@ class Spaz(ba.Actor):
""" """
if not self.node: if not self.node:
return return
t_ms = ba.time(timeformat=ba.TimeFormat.MILLISECONDS)
assert isinstance(t_ms, int)
if t_ms - self.last_jump_time_ms >= self._jump_cooldown:
self.node.jump_pressed = True self.node.jump_pressed = True
self.last_jump_time_ms = t_ms
self._turbo_filter_add_press('jump') self._turbo_filter_add_press('jump')
def on_jump_release(self) -> None: def on_jump_release(self) -> None:

View file

@ -26,7 +26,7 @@ class SharedObjects:
def __init__(self) -> None: def __init__(self) -> None:
activity = ba.getactivity() activity = ba.getactivity()
if hasattr(activity, self._STORENAME): if self._STORENAME in activity.customdata:
raise RuntimeError('Use SharedObjects.get() to fetch the' raise RuntimeError('Use SharedObjects.get() to fetch the'
' shared instance for this activity.') ' shared instance for this activity.')
self._object_material: Optional[ba.Material] = None self._object_material: Optional[ba.Material] = None

View file

@ -60,7 +60,7 @@ class MainMenuActivity(ba.Activity[ba.Player, ba.Team]):
'scale': scale, 'scale': scale,
'position': (0, 10), 'position': (0, 10),
'vr_depth': -10, 'vr_depth': -10,
'text': '\xa9 2011-2021 Eric Froemling' 'text': '\xa9 2011-2022 Eric Froemling'
})) }))
# Throw up some text that only clients can see so they know that the # Throw up some text that only clients can see so they know that the

View file

@ -10,17 +10,17 @@ import ba
def show_sign_in_prompt(account_type: str = None) -> None: def show_sign_in_prompt(account_type: str = None) -> None:
"""Bring up a prompt telling the user they must sign in.""" """Bring up a prompt telling the user they must sign in."""
from bastd.ui import confirm from bastd.ui.confirm import ConfirmWindow
from bastd.ui.account import settings from bastd.ui.account import settings
if account_type == 'Google Play': if account_type == 'Google Play':
confirm.ConfirmWindow( ConfirmWindow(
ba.Lstr(resource='notSignedInGooglePlayErrorText'), ba.Lstr(resource='notSignedInGooglePlayErrorText'),
lambda: _ba.sign_in('Google Play'), lambda: _ba.sign_in('Google Play'),
ok_text=ba.Lstr(resource='accountSettingsWindow.signInText'), ok_text=ba.Lstr(resource='accountSettingsWindow.signInText'),
width=460, width=460,
height=130) height=130)
else: else:
confirm.ConfirmWindow( ConfirmWindow(
ba.Lstr(resource='notSignedInErrorText'), ba.Lstr(resource='notSignedInErrorText'),
lambda: settings.AccountSettingsWindow(modal=True, lambda: settings.AccountSettingsWindow(modal=True,
close_once_signed_in=True), close_once_signed_in=True),

View file

@ -25,6 +25,10 @@ class AccountSettingsWindow(ba.Window):
close_once_signed_in: bool = False): close_once_signed_in: bool = False):
# pylint: disable=too-many-statements # pylint: disable=too-many-statements
self._sign_in_game_circle_button: Optional[ba.Widget] = None
self._sign_in_v2_button: Optional[ba.Widget] = None
self._sign_in_device_button: Optional[ba.Widget] = None
self._close_once_signed_in = close_once_signed_in self._close_once_signed_in = close_once_signed_in
ba.set_analytics_screen('Account Window') ba.set_analytics_screen('Account Window')
@ -86,6 +90,10 @@ class AccountSettingsWindow(ba.Window):
# exceptions. # exceptions.
self._show_sign_in_buttons.append('Local') self._show_sign_in_buttons.append('Local')
# Ditto with shiny new V2 ones.
if bool(False):
self._show_sign_in_buttons.append('V2')
top_extra = 15 if uiscale is ba.UIScale.SMALL else 0 top_extra = 15 if uiscale is ba.UIScale.SMALL else 0
super().__init__(root_widget=ba.containerwidget( super().__init__(root_widget=ba.containerwidget(
size=(self._width, self._height + top_extra), size=(self._width, self._height + top_extra),
@ -206,12 +214,10 @@ class AccountSettingsWindow(ba.Window):
show_game_circle_sign_in_button = (account_state == 'signed_out' show_game_circle_sign_in_button = (account_state == 'signed_out'
and 'Game Circle' and 'Game Circle'
in self._show_sign_in_buttons) in self._show_sign_in_buttons)
show_ali_sign_in_button = (account_state == 'signed_out'
and 'Ali' in self._show_sign_in_buttons)
show_test_sign_in_button = (account_state == 'signed_out'
and 'Test' in self._show_sign_in_buttons)
show_device_sign_in_button = (account_state == 'signed_out' and 'Local' show_device_sign_in_button = (account_state == 'signed_out' and 'Local'
in self._show_sign_in_buttons) in self._show_sign_in_buttons)
show_v2_sign_in_button = (account_state == 'signed_out'
and 'V2' in self._show_sign_in_buttons)
sign_in_button_space = 70.0 sign_in_button_space = 70.0
show_game_service_button = (self._signed_in and account_type show_game_service_button = (self._signed_in and account_type
@ -223,9 +229,9 @@ class AccountSettingsWindow(ba.Window):
'allowAccountLinking2', False)) 'allowAccountLinking2', False))
linked_accounts_text_space = 60.0 linked_accounts_text_space = 60.0
show_achievements_button = (self._signed_in and account_type show_achievements_button = (
in ('Google Play', 'Alibaba', 'Local', self._signed_in
'OUYA', 'Test')) and account_type in ('Google Play', 'Alibaba', 'Local', 'OUYA'))
achievements_button_space = 60.0 achievements_button_space = 60.0
show_achievements_text = (self._signed_in show_achievements_text = (self._signed_in
@ -255,8 +261,8 @@ class AccountSettingsWindow(ba.Window):
show_unlink_accounts_button = show_link_accounts_button show_unlink_accounts_button = show_link_accounts_button
unlink_accounts_button_space = 90.0 unlink_accounts_button_space = 90.0
show_sign_out_button = (self._signed_in and account_type show_sign_out_button = (self._signed_in
in ['Test', 'Local', 'Google Play']) and account_type in ['Local', 'Google Play'])
sign_out_button_space = 70.0 sign_out_button_space = 70.0
if self._subcontainer is not None: if self._subcontainer is not None:
@ -272,12 +278,10 @@ class AccountSettingsWindow(ba.Window):
self._sub_height += sign_in_button_space self._sub_height += sign_in_button_space
if show_game_circle_sign_in_button: if show_game_circle_sign_in_button:
self._sub_height += sign_in_button_space self._sub_height += sign_in_button_space
if show_ali_sign_in_button:
self._sub_height += sign_in_button_space
if show_test_sign_in_button:
self._sub_height += sign_in_button_space
if show_device_sign_in_button: if show_device_sign_in_button:
self._sub_height += sign_in_button_space self._sub_height += sign_in_button_space
if show_v2_sign_in_button:
self._sub_height += sign_in_button_space
if show_game_service_button: if show_game_service_button:
self._sub_height += game_service_button_space self._sub_height += game_service_button_space
if show_linked_accounts_text: if show_linked_accounts_text:
@ -462,21 +466,42 @@ class AccountSettingsWindow(ba.Window):
ba.widget(edit=btn, show_buffer_bottom=40, show_buffer_top=100) ba.widget(edit=btn, show_buffer_bottom=40, show_buffer_top=100)
self._sign_in_text = None self._sign_in_text = None
if show_ali_sign_in_button: if show_v2_sign_in_button:
button_width = 350 button_width = 350
v -= sign_in_button_space v -= sign_in_button_space
self._sign_in_ali_button = btn = ba.buttonwidget( self._sign_in_v2_button = btn = ba.buttonwidget(
parent=self._subcontainer, parent=self._subcontainer,
position=((self._sub_width - button_width) * 0.5, v - 20), position=((self._sub_width - button_width) * 0.5, v - 20),
autoselect=True, autoselect=True,
size=(button_width, 60), size=(button_width, 60),
label=ba.Lstr(value='${A}${B}', label='',
subs=[('${A}', on_activate_call=self._v2_sign_in_press)
ba.charstr(ba.SpecialChar.ALIBABA_LOGO)), ba.textwidget(
parent=self._subcontainer,
draw_controller=btn,
h_align='center',
v_align='center',
size=(0, 0),
position=(self._sub_width * 0.5, v + 17),
text=ba.Lstr(
value='${A}${B}',
subs=[('${A}', ba.charstr(ba.SpecialChar.V2_LOGO)),
('${B}', ('${B}',
ba.Lstr(resource=self._r + '.signInText')) ba.Lstr(resource=self._r + '.signInWithV2Text'))]),
]), maxwidth=button_width * 0.8,
on_activate_call=lambda: self._sign_in_press('Ali')) color=(0.75, 1.0, 0.7))
ba.textwidget(parent=self._subcontainer,
draw_controller=btn,
h_align='center',
v_align='center',
size=(0, 0),
position=(self._sub_width * 0.5, v - 4),
text=ba.Lstr(resource=self._r +
'.signInWithV2InfoText'),
flatness=1.0,
scale=0.57,
maxwidth=button_width * 0.9,
color=(0.55, 0.8, 0.5))
if first_selectable is None: if first_selectable is None:
first_selectable = btn first_selectable = btn
if ba.app.ui.use_toolbars: if ba.app.ui.use_toolbars:
@ -532,53 +557,6 @@ class AccountSettingsWindow(ba.Window):
ba.widget(edit=btn, show_buffer_bottom=40, show_buffer_top=100) ba.widget(edit=btn, show_buffer_bottom=40, show_buffer_top=100)
self._sign_in_text = None self._sign_in_text = None
# Old test-account option.
if show_test_sign_in_button:
button_width = 350
v -= sign_in_button_space
self._sign_in_test_button = btn = ba.buttonwidget(
parent=self._subcontainer,
position=((self._sub_width - button_width) * 0.5, v - 20),
autoselect=True,
size=(button_width, 60),
label='',
on_activate_call=lambda: self._sign_in_press('Test'))
ba.textwidget(parent=self._subcontainer,
draw_controller=btn,
h_align='center',
v_align='center',
size=(0, 0),
position=(self._sub_width * 0.5, v + 17),
text=ba.Lstr(
value='${A}${B}',
subs=[('${A}',
ba.charstr(ba.SpecialChar.TEST_ACCOUNT)),
('${B}',
ba.Lstr(resource=self._r +
'.signInWithTestAccountText'))]),
maxwidth=button_width * 0.8,
color=(0.75, 1.0, 0.7))
ba.textwidget(parent=self._subcontainer,
draw_controller=btn,
h_align='center',
v_align='center',
size=(0, 0),
position=(self._sub_width * 0.5, v - 4),
text=ba.Lstr(resource=self._r +
'.signInWithTestAccountInfoText'),
flatness=1.0,
scale=0.57,
maxwidth=button_width * 0.9,
color=(0.55, 0.8, 0.5))
if first_selectable is None:
first_selectable = btn
if ba.app.ui.use_toolbars:
ba.widget(edit=btn,
right_widget=_ba.get_special_widget('party_button'))
ba.widget(edit=btn, left_widget=bbtn)
ba.widget(edit=btn, show_buffer_bottom=40, show_buffer_top=100)
self._sign_in_text = None
if show_player_profiles_button: if show_player_profiles_button:
button_width = 300 button_width = 300
v -= player_profiles_button_space v -= player_profiles_button_space
@ -1051,6 +1029,12 @@ class AccountSettingsWindow(ba.Window):
self._needs_refresh = True self._needs_refresh = True
ba.timer(0.1, ba.WeakCall(self._update), timetype=ba.TimeType.REAL) ba.timer(0.1, ba.WeakCall(self._update), timetype=ba.TimeType.REAL)
def _v2_sign_in_press(self) -> None:
# pylint: disable=cyclic-import
from bastd.ui.account.v2 import V2SignInWindow
assert self._sign_in_v2_button is not None
V2SignInWindow(origin_widget=self._sign_in_v2_button)
def _reset_progress(self) -> None: def _reset_progress(self) -> None:
try: try:
from ba.internal import getcampaign from ba.internal import getcampaign

View file

@ -0,0 +1,92 @@
# Released under the MIT License. See LICENSE for details.
#
"""V2 account ui bits."""
from __future__ import annotations
from typing import TYPE_CHECKING
import ba
import _ba
if TYPE_CHECKING:
from typing import Any, Optional
class V2SignInWindow(ba.Window):
"""A window allowing signing in to a v2 account."""
def __init__(self, origin_widget: ba.Widget):
from ba.internal import is_browser_likely_available
logincode = '1412345'
address = (
f'{_ba.get_master_server_address(version=2)}?login={logincode}')
self._width = 600
self._height = 500
uiscale = ba.app.ui.uiscale
super().__init__(root_widget=ba.containerwidget(
size=(self._width, self._height),
transition='in_scale',
scale_origin_stack_offset=origin_widget.get_screen_space_center(),
scale=(1.25 if uiscale is ba.UIScale.SMALL else
1.0 if uiscale is ba.UIScale.MEDIUM else 0.85)))
ba.textwidget(
parent=self._root_widget,
position=(self._width * 0.5, self._height - 85),
size=(0, 0),
text=ba.Lstr(
resource='accountSettingsWindow.v2LinkInstructionsText'),
color=ba.app.ui.title_color,
maxwidth=self._width * 0.9,
h_align='center',
v_align='center')
button_width = 450
if is_browser_likely_available():
ba.buttonwidget(parent=self._root_widget,
position=((self._width * 0.5 - button_width * 0.5),
self._height - 175),
autoselect=True,
size=(button_width, 60),
label=ba.Lstr(value=address),
color=(0.55, 0.5, 0.6),
textcolor=(0.75, 0.7, 0.8),
on_activate_call=lambda: ba.open_url(address))
qroffs = 0.0
else:
ba.textwidget(parent=self._root_widget,
position=(self._width * 0.5, self._height - 135),
size=(0, 0),
text=ba.Lstr(value=address),
flatness=1.0,
maxwidth=self._width,
scale=0.75,
h_align='center',
v_align='center')
qroffs = 20.0
self._cancel_button = ba.buttonwidget(
parent=self._root_widget,
position=(30, self._height - 55),
size=(130, 50),
scale=0.8,
label=ba.Lstr(resource='cancelText'),
# color=(0.6, 0.5, 0.6),
on_activate_call=self._done,
autoselect=True,
textcolor=(0.75, 0.7, 0.8),
# icon=ba.gettexture('crossOut'),
# iconscale=1.2
)
ba.containerwidget(edit=self._root_widget,
cancel_button=self._cancel_button)
qr_size = 270
ba.imagewidget(parent=self._root_widget,
position=(self._width * 0.5 - qr_size * 0.5,
self._height * 0.34 + qroffs - qr_size * 0.5),
size=(qr_size, qr_size),
texture=_ba.get_qrcode_texture(address))
def _done(self) -> None:
ba.containerwidget(edit=self._root_widget, transition='out_scale')

View file

@ -1209,6 +1209,7 @@ class PublicGatherTab(GatherTab):
def on_public_party_activate(self, party: PartyEntry) -> None: def on_public_party_activate(self, party: PartyEntry) -> None:
"""Called when a party is clicked or otherwise activated.""" """Called when a party is clicked or otherwise activated."""
self.save_state()
if party.queue is not None: if party.queue is not None:
from bastd.ui.partyqueue import PartyQueueWindow from bastd.ui.partyqueue import PartyQueueWindow
ba.playsound(ba.getsound('swish')) ba.playsound(ba.getsound('swish'))

View file

@ -5,7 +5,6 @@
from __future__ import annotations from __future__ import annotations
import math import math
import weakref
from typing import TYPE_CHECKING, cast from typing import TYPE_CHECKING, cast
import _ba import _ba
@ -413,53 +412,3 @@ class PartyWindow(ba.Window):
"""Close the window and make a lovely sound.""" """Close the window and make a lovely sound."""
ba.playsound(ba.getsound('swish')) ba.playsound(ba.getsound('swish'))
self.close() self.close()
def handle_party_invite(name: str, invite_id: str) -> None:
"""Handle an incoming party invitation."""
from bastd import mainmenu
from bastd.ui import confirm
ba.playsound(ba.getsound('fanfare'))
# if we're not in the main menu, just print the invite
# (don't want to screw up an in-progress game)
in_game = not isinstance(_ba.get_foreground_host_session(),
mainmenu.MainMenuSession)
if in_game:
ba.screenmessage(ba.Lstr(
value='${A}\n${B}',
subs=[('${A}',
ba.Lstr(resource='gatherWindow.partyInviteText',
subs=[('${NAME}', name)])),
('${B}',
ba.Lstr(
resource='gatherWindow.partyInviteGooglePlayExtraText'))
]),
color=(0.5, 1, 0))
else:
def do_accept(inv_id: str) -> None:
_ba.accept_party_invitation(inv_id)
conf = confirm.ConfirmWindow(
ba.Lstr(resource='gatherWindow.partyInviteText',
subs=[('${NAME}', name)]),
ba.Call(do_accept, invite_id),
width=500,
height=150,
color=(0.75, 1.0, 0.0),
ok_text=ba.Lstr(resource='gatherWindow.partyInviteAcceptText'),
cancel_text=ba.Lstr(resource='gatherWindow.partyInviteIgnoreText'))
# FIXME: Ugly.
# Let's store the invite-id away on the confirm window so we know if
# we need to kill it later.
conf.party_invite_id = invite_id # type: ignore
# store a weak-ref so we can get at this later
ba.app.invite_confirm_windows.append(weakref.ref(conf))
# go ahead and prune our weak refs while we're here.
ba.app.invite_confirm_windows = [
w for w in ba.app.invite_confirm_windows if w() is not None
]

View file

@ -322,8 +322,8 @@ class AdvancedSettingsWindow(ba.Window):
subs=[('${APP_NAME}', ba.Lstr(resource='titleText')) subs=[('${APP_NAME}', ba.Lstr(resource='titleText'))
]), ]),
autoselect=True, autoselect=True,
on_activate_call=ba.Call(ba.open_url, on_activate_call=ba.Call(
'http://bombsquadgame.com/translate')) ba.open_url, 'https://legacy.ballistica.net/translate'))
self._lang_status_text = ba.textwidget(parent=self._subcontainer, self._lang_status_text = ba.textwidget(parent=self._subcontainer,
position=(self._sub_width * 0.5, position=(self._sub_width * 0.5,

View file

@ -91,21 +91,6 @@ class ControlsSettingsWindow(ba.Window):
else: else:
show_remote = False show_remote = False
show_ps3 = False
# if platform == 'mac':
# show_ps3 = True
# height += spacing
show360 = False
# if platform == 'mac' or is_fire_tv:
# show360 = True
# height += spacing
show_mac_wiimote = False
# if platform == 'mac' and _ba.is_xcode_build():
# show_mac_wiimote = True
# height += spacing
# On windows (outside of oculus/vr), show an option to disable xinput. # On windows (outside of oculus/vr), show an option to disable xinput.
show_xinput_toggle = False show_xinput_toggle = False
if platform == 'windows' and not app.vr_mode: if platform == 'windows' and not app.vr_mode:
@ -152,9 +137,6 @@ class ControlsSettingsWindow(ba.Window):
self._keyboard_button: Optional[ba.Widget] = None self._keyboard_button: Optional[ba.Widget] = None
self._keyboard_2_button: Optional[ba.Widget] = None self._keyboard_2_button: Optional[ba.Widget] = None
self._idevices_button: Optional[ba.Widget] = None self._idevices_button: Optional[ba.Widget] = None
self._ps3_button: Optional[ba.Widget] = None
self._xbox_360_button: Optional[ba.Widget] = None
self._wiimotes_button: Optional[ba.Widget] = None
ba.textwidget(parent=self._root_widget, ba.textwidget(parent=self._root_widget,
position=(0, height - 49), position=(0, height - 49),
@ -261,42 +243,6 @@ class ControlsSettingsWindow(ba.Window):
down_widget=self._idevices_button) down_widget=self._idevices_button)
self._have_selected_child = True self._have_selected_child = True
v -= spacing v -= spacing
if show_ps3:
self._ps3_button = btn = ba.buttonwidget(
parent=self._root_widget,
position=((width - button_width) / 2 + 5, v),
size=(button_width, 43),
autoselect=True,
label=ba.Lstr(resource=self._r + '.ps3Text'),
on_activate_call=self._do_ps3_controllers)
if ba.app.ui.use_toolbars:
ba.widget(edit=btn,
right_widget=_ba.get_special_widget('party_button'))
v -= spacing
if show360:
self._xbox_360_button = btn = ba.buttonwidget(
parent=self._root_widget,
position=((width - button_width) / 2 - 1, v),
size=(button_width, 43),
autoselect=True,
label=ba.Lstr(resource=self._r + '.xbox360Text'),
on_activate_call=self._do_360_controllers)
if ba.app.ui.use_toolbars:
ba.widget(edit=btn,
right_widget=_ba.get_special_widget('party_button'))
v -= spacing
if show_mac_wiimote:
self._wiimotes_button = btn = ba.buttonwidget(
parent=self._root_widget,
position=((width - button_width) / 2 + 5, v),
size=(button_width, 43),
autoselect=True,
label=ba.Lstr(resource=self._r + '.wiimotesText'),
on_activate_call=self._do_wiimotes)
if ba.app.ui.use_toolbars:
ba.widget(edit=btn,
right_widget=_ba.get_special_widget('party_button'))
v -= spacing
if show_xinput_toggle: if show_xinput_toggle:
@ -397,31 +343,6 @@ class ControlsSettingsWindow(ba.Window):
ba.app.ui.set_main_menu_window( ba.app.ui.set_main_menu_window(
RemoteAppSettingsWindow().get_root_widget()) RemoteAppSettingsWindow().get_root_widget())
def _do_ps3_controllers(self) -> None:
# pylint: disable=cyclic-import
from bastd.ui.settings.ps3controller import PS3ControllerSettingsWindow
self._save_state()
ba.containerwidget(edit=self._root_widget, transition='out_left')
ba.app.ui.set_main_menu_window(
PS3ControllerSettingsWindow().get_root_widget())
def _do_360_controllers(self) -> None:
# pylint: disable=cyclic-import
from bastd.ui.settings.xbox360controller import (
XBox360ControllerSettingsWindow)
self._save_state()
ba.containerwidget(edit=self._root_widget, transition='out_left')
ba.app.ui.set_main_menu_window(
XBox360ControllerSettingsWindow().get_root_widget())
def _do_wiimotes(self) -> None:
# pylint: disable=cyclic-import
from bastd.ui.settings.wiimote import WiimoteSettingsWindow
self._save_state()
ba.containerwidget(edit=self._root_widget, transition='out_left')
ba.app.ui.set_main_menu_window(
WiimoteSettingsWindow().get_root_widget())
def _do_gamepads(self) -> None: def _do_gamepads(self) -> None:
# pylint: disable=cyclic-import # pylint: disable=cyclic-import
from bastd.ui.settings.gamepadselect import GamepadSelectWindow from bastd.ui.settings.gamepadselect import GamepadSelectWindow
@ -449,12 +370,6 @@ class ControlsSettingsWindow(ba.Window):
sel_name = 'Keyboard2' sel_name = 'Keyboard2'
elif sel == self._idevices_button: elif sel == self._idevices_button:
sel_name = 'iDevices' sel_name = 'iDevices'
elif sel == self._ps3_button:
sel_name = 'PS3'
elif sel == self._xbox_360_button:
sel_name = 'xbox360'
elif sel == self._wiimotes_button:
sel_name = 'Wiimotes'
else: else:
sel_name = 'Back' sel_name = 'Back'
ba.app.ui.window_states[type(self)] = sel_name ba.app.ui.window_states[type(self)] = sel_name
@ -471,12 +386,6 @@ class ControlsSettingsWindow(ba.Window):
sel = self._keyboard_2_button sel = self._keyboard_2_button
elif sel_name == 'iDevices': elif sel_name == 'iDevices':
sel = self._idevices_button sel = self._idevices_button
elif sel_name == 'PS3':
sel = self._ps3_button
elif sel_name == 'xbox360':
sel = self._xbox_360_button
elif sel_name == 'Wiimotes':
sel = self._wiimotes_button
elif sel_name == 'Back': elif sel_name == 'Back':
sel = self._back_button sel = self._back_button
else: else:

View file

@ -4,11 +4,296 @@
from __future__ import annotations from __future__ import annotations
import time
import copy
import weakref
from threading import Thread
from typing import TYPE_CHECKING
import _ba
import ba import ba
from bastd.ui.settings import testing from bastd.ui.settings.testing import TestingWindow
if TYPE_CHECKING:
from typing import Callable, Any, Optional
class NetTestingWindow(testing.TestingWindow): class NetTestingWindow(ba.Window):
"""Window that runs a networking test suite to help diagnose issues."""
def __init__(self, transition: str = 'in_right'):
self._width = 820
self._height = 500
self._printed_lines: list[str] = []
uiscale = ba.app.ui.uiscale
super().__init__(root_widget=ba.containerwidget(
size=(self._width, self._height),
scale=(1.56 if uiscale is ba.UIScale.SMALL else
1.2 if uiscale is ba.UIScale.MEDIUM else 0.8),
stack_offset=(0.0, -7 if uiscale is ba.UIScale.SMALL else 0.0),
transition=transition))
self._done_button = ba.buttonwidget(parent=self._root_widget,
position=(40, self._height - 77),
size=(120, 60),
scale=0.8,
autoselect=True,
label=ba.Lstr(resource='doneText'),
on_activate_call=self._done)
self._copy_button = ba.buttonwidget(parent=self._root_widget,
position=(self._width - 200,
self._height - 77),
size=(100, 60),
scale=0.8,
autoselect=True,
label=ba.Lstr(resource='copyText'),
on_activate_call=self._copy)
self._settings_button = ba.buttonwidget(
parent=self._root_widget,
position=(self._width - 100, self._height - 77),
size=(60, 60),
scale=0.8,
autoselect=True,
label=ba.Lstr(value='...'),
on_activate_call=self._show_val_testing)
twidth = self._width - 450
ba.textwidget(
parent=self._root_widget,
position=(self._width * 0.5, self._height - 55),
size=(0, 0),
text=ba.Lstr(resource='settingsWindowAdvanced.netTestingText'),
color=(0.8, 0.8, 0.8, 1.0),
h_align='center',
v_align='center',
maxwidth=twidth)
self._scroll = ba.scrollwidget(parent=self._root_widget,
position=(50, 50),
size=(self._width - 100,
self._height - 140),
capture_arrows=True,
autoselect=True)
self._rows = ba.columnwidget(parent=self._scroll)
ba.containerwidget(edit=self._root_widget,
cancel_button=self._done_button)
# Now kick off the tests.
# Pass a weak-ref to this window so we don't keep it alive
# if we back out before it completes. Also set is as daemon
# so it doesn't keep the app running if the user is trying to quit.
Thread(
daemon=True,
target=ba.Call(_run_diagnostics, weakref.ref(self)),
).start()
def print(self, text: str, color: tuple[float, float, float]) -> None:
"""Print text to our console thingie."""
for line in text.splitlines():
txt = ba.textwidget(parent=self._rows,
color=color,
text=line,
scale=0.75,
flatness=1.0,
shadow=0.0,
size=(0, 20))
ba.containerwidget(edit=self._rows, visible_child=txt)
self._printed_lines.append(line)
def _copy(self) -> None:
if not ba.clipboard_is_supported():
ba.screenmessage('Clipboard not supported on this platform.',
color=(1, 0, 0))
return
ba.clipboard_set_text('\n'.join(self._printed_lines))
ba.screenmessage(f'{len(self._printed_lines)} lines copied.')
def _show_val_testing(self) -> None:
ba.app.ui.set_main_menu_window(NetValTestingWindow().get_root_widget())
ba.containerwidget(edit=self._root_widget, transition='out_left')
def _done(self) -> None:
# pylint: disable=cyclic-import
from bastd.ui.settings.advanced import AdvancedSettingsWindow
ba.app.ui.set_main_menu_window(
AdvancedSettingsWindow(transition='in_left').get_root_widget())
ba.containerwidget(edit=self._root_widget, transition='out_right')
def _run_diagnostics(weakwin: weakref.ref[NetTestingWindow]) -> None:
# pylint: disable=too-many-statements
from efro.util import utc_now
have_error = [False]
# We're running in a background thread but UI stuff needs to run
# in the logic thread; give ourself a way to pass stuff to it.
def _print(text: str, color: tuple[float, float, float] = None) -> None:
def _print_in_logic_thread() -> None:
win = weakwin()
if win is not None:
win.print(text, (1.0, 1.0, 1.0) if color is None else color)
ba.pushcall(_print_in_logic_thread, from_other_thread=True)
def _print_test_results(call: Callable[[], Any]) -> None:
"""Run the provided call; return success/fail text & color."""
starttime = time.monotonic()
try:
call()
duration = time.monotonic() - starttime
_print(f'Succeeded in {duration:.2f}s.', color=(0, 1, 0))
except Exception:
import traceback
duration = time.monotonic() - starttime
_print(traceback.format_exc(), color=(1.0, 1.0, 0.3))
_print(f'Failed in {duration:.2f}s.', color=(1, 0, 0))
have_error[0] = True
try:
_print(f'Running network diagnostics...\n'
f'ua: {_ba.app.user_agent_string}\n'
f'time: {utc_now()}.')
if bool(False):
_print('\nRunning dummy success test...')
_print_test_results(_dummy_success)
_print('\nRunning dummy fail test...')
_print_test_results(_dummy_fail)
# V1 ping
baseaddr = _ba.get_master_server_address(source=0, version=1)
_print(f'\nContacting V1 master-server src0 ({baseaddr})...')
_print_test_results(lambda: _test_fetch(baseaddr))
# V1 alternate ping
baseaddr = _ba.get_master_server_address(source=1, version=1)
_print(f'\nContacting V1 master-server src1 ({baseaddr})...')
_print_test_results(lambda: _test_fetch(baseaddr))
_print(f'\nV1-test-log: {ba.app.net.v1_test_log}')
for srcid, result in sorted(ba.app.net.v1_ctest_results.items()):
_print(f'\nV1 src{srcid} result: {result}')
curv1addr = _ba.get_master_server_address(version=1)
_print(f'\nUsing V1 address: {curv1addr}')
_print('\nRunning V1 transaction...')
_print_test_results(_test_v1_transaction)
# V2 ping
baseaddr = _ba.get_master_server_address(version=2)
_print(f'\nContacting V2 master-server ({baseaddr})...')
_print_test_results(lambda: _test_fetch(baseaddr))
# Get V2 nearby region
with ba.app.net.region_pings_lock:
region_pings = copy.deepcopy(ba.app.net.region_pings)
nearest_region = (None if not region_pings else sorted(
region_pings.items(), key=lambda i: i[1])[0])
if nearest_region is not None:
nearstr = f'{nearest_region[0]}: {nearest_region[1]:.0f}ms'
else:
nearstr = '-'
_print(f'\nChecking nearest V2 region ping ({nearstr})...')
_print_test_results(lambda: _test_nearby_region_ping(nearest_region))
if have_error[0]:
_print('\nDiagnostics complete. Some diagnostics failed.',
color=(10, 0, 0))
else:
_print('\nDiagnostics complete. Everything looks good!',
color=(0, 1, 0))
except Exception:
import traceback
_print(
f'An unexpected error occurred during testing;'
f' please report this.\n'
f'{traceback.format_exc()}',
color=(1, 0, 0))
def _dummy_success() -> None:
"""Dummy success test."""
time.sleep(1.2)
def _dummy_fail() -> None:
"""Dummy fail test case."""
raise RuntimeError('fail-test')
def _test_v1_transaction() -> None:
"""Dummy fail test case."""
if _ba.get_account_state() != 'signed_in':
raise RuntimeError('Not signed in.')
starttime = time.monotonic()
# Gets set to True on success or string on error.
results: list[Any] = [False]
def _cb(cbresults: Any) -> None:
# Simply set results here; our other thread acts on them.
if not isinstance(cbresults, dict) or 'party_code' not in cbresults:
results[0] = 'Unexpected transaction response'
return
results[0] = True # Success!
def _do_it() -> None:
# Fire off a transaction with a callback.
_ba.add_transaction(
{
'type': 'PRIVATE_PARTY_QUERY',
'expire_time': time.time() + 20,
},
callback=_cb,
)
_ba.run_transactions()
ba.pushcall(_do_it, from_other_thread=True)
while results[0] is False:
time.sleep(0.01)
if time.monotonic() - starttime > 10.0:
raise RuntimeError('timed out')
# If we got left a string, its an error.
if isinstance(results[0], str):
raise RuntimeError(results[0])
def _test_fetch(baseaddr: str) -> None:
# pylint: disable=consider-using-with
import urllib.request
response = urllib.request.urlopen(urllib.request.Request(
f'{baseaddr}/ping', None, {'User-Agent': _ba.app.user_agent_string}),
timeout=10.0)
if response.getcode() != 200:
raise RuntimeError(
f'Got unexpected response code {response.getcode()}.')
data = response.read()
if data != b'pong':
raise RuntimeError('Got unexpected response data.')
def _test_nearby_region_ping(
nearest_region: Optional[tuple[str, float]]) -> None:
"""Try to ping nearest v2 region."""
if nearest_region is None:
raise RuntimeError('No nearest region.')
if nearest_region[1] > 500:
raise RuntimeError('Ping too high.')
class NetValTestingWindow(TestingWindow):
"""Window to test network related settings.""" """Window to test network related settings."""
def __init__(self, transition: str = 'in_right'): def __init__(self, transition: str = 'in_right'):
@ -35,6 +320,8 @@ class NetTestingWindow(testing.TestingWindow):
'increment': 1 'increment': 1
}, },
] ]
testing.TestingWindow.__init__( super().__init__(
self, ba.Lstr(resource='settingsWindowAdvanced.netTestingText'), title=ba.Lstr(resource='settingsWindowAdvanced.netTestingText'),
entries, transition) entries=entries,
transition=transition,
back_call=lambda: NetTestingWindow(transition='in_left'))

View file

@ -11,7 +11,7 @@ import _ba
import ba import ba
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Any from typing import Any, Callable, Optional
class TestingWindow(ba.Window): class TestingWindow(ba.Window):
@ -20,11 +20,13 @@ class TestingWindow(ba.Window):
def __init__(self, def __init__(self,
title: ba.Lstr, title: ba.Lstr,
entries: list[dict[str, Any]], entries: list[dict[str, Any]],
transition: str = 'in_right'): transition: str = 'in_right',
back_call: Optional[Callable[[], ba.Window]] = None):
uiscale = ba.app.ui.uiscale uiscale = ba.app.ui.uiscale
self._width = 600 self._width = 600
self._height = 324 if uiscale is ba.UIScale.SMALL else 400 self._height = 324 if uiscale is ba.UIScale.SMALL else 400
self._entries = copy.deepcopy(entries) self._entries = copy.deepcopy(entries)
self._back_call = back_call
super().__init__(root_widget=ba.containerwidget( super().__init__(root_widget=ba.containerwidget(
size=(self._width, self._height), size=(self._width, self._height),
transition=transition, transition=transition,
@ -176,8 +178,8 @@ class TestingWindow(ba.Window):
def _do_back(self) -> None: def _do_back(self) -> None:
# pylint: disable=cyclic-import # pylint: disable=cyclic-import
import bastd.ui.settings.advanced from bastd.ui.settings.advanced import AdvancedSettingsWindow
ba.containerwidget(edit=self._root_widget, transition='out_right') ba.containerwidget(edit=self._root_widget, transition='out_right')
ba.app.ui.set_main_menu_window( backwin = (self._back_call() if self._back_call is not None else
bastd.ui.settings.advanced.AdvancedSettingsWindow( AdvancedSettingsWindow(transition='in_left'))
transition='in_left').get_root_widget()) ba.app.ui.set_main_menu_window(backwin.get_root_widget())

View file

@ -7,13 +7,13 @@ from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import ba import ba
from bastd.ui.settings import testing from bastd.ui.settings.testing import TestingWindow
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Any from typing import Any
class VRTestingWindow(testing.TestingWindow): class VRTestingWindow(TestingWindow):
"""Window for testing vr settings.""" """Window for testing vr settings."""
def __init__(self, transition: str = 'in_right'): def __init__(self, transition: str = 'in_right'):

View file

@ -266,8 +266,9 @@ if TYPE_CHECKING:
def Call(*_args: Any, **_keywds: Any) -> Any: def Call(*_args: Any, **_keywds: Any) -> Any:
... ...
# (Type-safe Partial)
# A convenient wrapper around functools.partial which adds type-safety # A convenient wrapper around functools.partial which adds type-safety
# (though it does not support keyword arguments). # (though it does not support keyword arguments).
partial = Call tpartial = Call
else: else:
partial = functools.partial tpartial = functools.partial

View file

@ -10,26 +10,45 @@ data formats in a nondestructive manner.
from __future__ import annotations from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING, TypeVar from typing import TYPE_CHECKING, TypeVar
from efro.dataclassio._outputter import _Outputter from efro.dataclassio._outputter import _Outputter
from efro.dataclassio._inputter import _Inputter from efro.dataclassio._inputter import _Inputter
from efro.dataclassio._base import Codec, IOAttrs from efro.dataclassio._base import Codec, IOAttrs, IOExtendedData
from efro.dataclassio._prep import ioprep, ioprepped, is_ioprepped_dataclass from efro.dataclassio._prep import (ioprep, ioprepped, will_ioprep,
is_ioprepped_dataclass)
from efro.dataclassio._pathcapture import DataclassFieldLookup from efro.dataclassio._pathcapture import DataclassFieldLookup
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Any from typing import Any, Optional
__all__ = [ __all__ = [
'Codec', 'IOAttrs', 'ioprep', 'ioprepped', 'is_ioprepped_dataclass', 'Codec', 'IOAttrs', 'IOExtendedData', 'ioprep', 'ioprepped', 'will_ioprep',
'DataclassFieldLookup', 'dataclass_to_dict', 'dataclass_to_json', 'is_ioprepped_dataclass', 'DataclassFieldLookup', 'dataclass_to_dict',
'dataclass_from_dict', 'dataclass_from_json', 'dataclass_validate' 'dataclass_to_json', 'dataclass_from_dict', 'dataclass_from_json',
'dataclass_validate'
] ]
T = TypeVar('T') T = TypeVar('T')
class JsonStyle(Enum):
"""Different style types for json."""
# Single line, no spaces, no sorting. Not deterministic.
# Use this for most storage purposes.
FAST = 'fast'
# Single line, no spaces, sorted keys. Deterministic.
# Use this when output may be hashed or compared for equality.
SORTED = 'sorted'
# Multiple lines, spaces, sorted keys. Deterministic.
# Use this for pretty human readable output.
PRETTY = 'pretty'
def dataclass_to_dict(obj: Any, def dataclass_to_dict(obj: Any,
codec: Codec = Codec.JSON, codec: Codec = Codec.JSON,
coerce_to_float: bool = True) -> dict: coerce_to_float: bool = True) -> dict:
@ -45,7 +64,7 @@ def dataclass_to_dict(obj: Any,
the ability to do a lossless round-trip with data). the ability to do a lossless round-trip with data).
If coerce_to_float is True, integer values present on float typed fields If coerce_to_float is True, integer values present on float typed fields
will be converted to floats in the dict output. If False, a TypeError will be converted to float in the dict output. If False, a TypeError
will be triggered. will be triggered.
""" """
@ -59,18 +78,23 @@ def dataclass_to_dict(obj: Any,
def dataclass_to_json(obj: Any, def dataclass_to_json(obj: Any,
coerce_to_float: bool = True, coerce_to_float: bool = True,
pretty: bool = False) -> str: pretty: bool = False,
sort_keys: Optional[bool] = None) -> str:
"""Utility function; return a json string from a dataclass instance. """Utility function; return a json string from a dataclass instance.
Basically json.dumps(dataclass_to_dict(...)). Basically json.dumps(dataclass_to_dict(...)).
By default, keys are sorted for pretty output and not otherwise, but
this can be overridden by supplying a value for the 'sort_keys' arg.
""" """
import json import json
jdict = dataclass_to_dict(obj=obj, jdict = dataclass_to_dict(obj=obj,
coerce_to_float=coerce_to_float, coerce_to_float=coerce_to_float,
codec=Codec.JSON) codec=Codec.JSON)
if sort_keys is None:
sort_keys = pretty
if pretty: if pretty:
return json.dumps(jdict, indent=2, sort_keys=True) return json.dumps(jdict, indent=2, sort_keys=sort_keys)
return json.dumps(jdict, separators=(',', ':')) return json.dumps(jdict, separators=(',', ':'), sort_keys=sort_keys)
def dataclass_from_dict(cls: type[T], def dataclass_from_dict(cls: type[T],
@ -94,10 +118,10 @@ def dataclass_from_dict(cls: type[T],
(as this would break the ability to do a lossless round-trip with data). (as this would break the ability to do a lossless round-trip with data).
If coerce_to_float is True, int values passed for float typed fields If coerce_to_float is True, int values passed for float typed fields
will be converted to float values. Otherwise a TypeError is raised. will be converted to float values. Otherwise, a TypeError is raised.
If allow_unknown_attrs is False, AttributeErrors will be raised for If allow_unknown_attrs is False, AttributeErrors will be raised for
attributes present in the dict but not on the data class. Otherwise they attributes present in the dict but not on the data class. Otherwise, they
will be preserved as part of the instance and included if it is will be preserved as part of the instance and included if it is
exported back to a dict, unless discard_unknown_attrs is True, in which exported back to a dict, unless discard_unknown_attrs is True, in which
case they will simply be discarded. case they will simply be discarded.

View file

@ -12,15 +12,6 @@ from typing import TYPE_CHECKING, get_args
# noinspection PyProtectedMember # noinspection PyProtectedMember
from typing import _AnnotatedAlias # type: ignore from typing import _AnnotatedAlias # type: ignore
_pytz_utc: Any
# We don't *require* pytz but we want to support it for tzinfos if available.
try:
import pytz
_pytz_utc = pytz.utc
except ModuleNotFoundError:
_pytz_utc = None # pylint: disable=invalid-name
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Any, Optional from typing import Any, Optional
@ -32,14 +23,6 @@ SIMPLE_TYPES = {int, bool, str, float, type(None)}
EXTRA_ATTRS_ATTR = '_DCIOEXATTRS' EXTRA_ATTRS_ATTR = '_DCIOEXATTRS'
def _ensure_datetime_is_timezone_aware(value: datetime.datetime) -> None:
# We only support timezone-aware utc times.
if (value.tzinfo is not datetime.timezone.utc
and (_pytz_utc is None or value.tzinfo is not _pytz_utc)):
raise ValueError(
'datetime values must have timezone set as timezone.utc')
def _raise_type_error(fieldpath: str, valuetype: type, def _raise_type_error(fieldpath: str, valuetype: type,
expected: tuple[type, ...]) -> None: expected: tuple[type, ...]) -> None:
"""Raise an error when a field value's type does not match expected.""" """Raise an error when a field value's type does not match expected."""
@ -67,6 +50,24 @@ class Codec(Enum):
FIRESTORE = 'firestore' FIRESTORE = 'firestore'
class IOExtendedData:
"""A class that data types can inherit from for extra functionality."""
def will_output(self) -> None:
"""Called before data is sent to an outputter.
Can be overridden to validate or filter data before
sending it on its way.
"""
@classmethod
def will_input(cls, data: dict) -> None:
"""Called on raw data before a class instance is created from it.
Can be overridden to migrate old data formats to new, etc.
"""
def _is_valid_for_codec(obj: Any, codec: Codec) -> bool: def _is_valid_for_codec(obj: Any, codec: Codec) -> bool:
"""Return whether a value consists solely of json-supported types. """Return whether a value consists solely of json-supported types.
@ -126,7 +127,7 @@ class IOAttrs:
# Turning off store_default requires the field to have either # Turning off store_default requires the field to have either
# a default_factory or a default # a default_factory or a default
if not self.store_default: if not self.store_default:
default_factory: Any = field.default_factory # type: ignore default_factory: Any = field.default_factory
if (default_factory is dataclasses.MISSING if (default_factory is dataclasses.MISSING
and field.default is dataclasses.MISSING): and field.default is dataclasses.MISSING):
raise TypeError(f'Field {field.name} of {cls} has' raise TypeError(f'Field {field.name} of {cls} has'
@ -163,7 +164,7 @@ def _get_origin(anntype: Any) -> Any:
def _parse_annotated(anntype: Any) -> tuple[Any, Optional[IOAttrs]]: def _parse_annotated(anntype: Any) -> tuple[Any, Optional[IOAttrs]]:
"""Parse Annotated() constructs, returning annotated type & IOAttrs.""" """Parse Annotated() constructs, returning annotated type & IOAttrs."""
# If we get an Annotated[foo, bar, eep] we take # If we get an Annotated[foo, bar, eep] we take
# foo as the actual type and we look for IOAttrs instances in # foo as the actual type, and we look for IOAttrs instances in
# bar/eep to affect our behavior. # bar/eep to affect our behavior.
ioattrs: Optional[IOAttrs] = None ioattrs: Optional[IOAttrs] = None
if isinstance(anntype, _AnnotatedAlias): if isinstance(anntype, _AnnotatedAlias):

View file

@ -14,11 +14,11 @@ import typing
import datetime import datetime
from typing import TYPE_CHECKING, Generic, TypeVar from typing import TYPE_CHECKING, Generic, TypeVar
from efro.util import enum_by_value from efro.util import enum_by_value, check_utc
from efro.dataclassio._base import (Codec, _parse_annotated, EXTRA_ATTRS_ATTR, from efro.dataclassio._base import (Codec, _parse_annotated, EXTRA_ATTRS_ATTR,
_is_valid_for_codec, _get_origin, _is_valid_for_codec, _get_origin,
SIMPLE_TYPES, _raise_type_error, SIMPLE_TYPES, _raise_type_error,
_ensure_datetime_is_timezone_aware) IOExtendedData)
from efro.dataclassio._prep import PrepSession from efro.dataclassio._prep import PrepSession
if TYPE_CHECKING: if TYPE_CHECKING:
@ -48,6 +48,12 @@ class _Inputter(Generic[T]):
def run(self, values: dict) -> T: def run(self, values: dict) -> T:
"""Do the thing.""" """Do the thing."""
# For special extended data types, call their 'will_output' callback.
tcls = self._cls
if issubclass(tcls, IOExtendedData):
tcls.will_input(values)
out = self._dataclass_from_input(self._cls, '', values) out = self._dataclass_from_input(self._cls, '', values)
assert isinstance(out, self._cls) assert isinstance(out, self._cls)
return out return out
@ -70,14 +76,14 @@ class _Inputter(Generic[T]):
return value return value
if origin is typing.Union: if origin is typing.Union:
# Currently the only unions we support are None/Value # Currently, the only unions we support are None/Value
# (translated from Optional), which we verified on prep. # (translated from Optional), which we verified on prep.
# So let's treat this as a simple optional case. # So let's treat this as a simple optional case.
if value is None: if value is None:
return None return None
childanntypes_l = [ childanntypes_l = [
c for c in typing.get_args(anntype) if c is not type(None) c for c in typing.get_args(anntype) if c is not type(None)
] ] # noqa (pycodestyle complains about *is* with type)
assert len(childanntypes_l) == 1 assert len(childanntypes_l) == 1
return self._value_from_input(cls, fieldpath, childanntypes_l[0], return self._value_from_input(cls, fieldpath, childanntypes_l[0],
value, ioattrs) value, ioattrs)
@ -127,7 +133,7 @@ class _Inputter(Generic[T]):
"""Given input data, returns bytes.""" """Given input data, returns bytes."""
import base64 import base64
# For firestore, bytes are passed as-is. Otherwise they're encoded # For firestore, bytes are passed as-is. Otherwise, they're encoded
# as base64. # as base64.
if self._codec is Codec.FIRESTORE: if self._codec is Codec.FIRESTORE:
if not isinstance(value, bytes): if not isinstance(value, bytes):
@ -159,6 +165,7 @@ class _Inputter(Generic[T]):
prep = PrepSession(explicit=False).prep_dataclass(cls, prep = PrepSession(explicit=False).prep_dataclass(cls,
recursion_level=0) recursion_level=0)
assert prep is not None
extra_attrs = {} extra_attrs = {}
@ -268,7 +275,7 @@ class _Inputter(Generic[T]):
cls, fieldpath, valanntype, val, ioattrs) cls, fieldpath, valanntype, val, ioattrs)
elif issubclass(keyanntype, Enum): elif issubclass(keyanntype, Enum):
# In prep we verified that all these enums' values have # In prep, we verified that all these enums' values have
# the same type, so we can just look at the first to see if # the same type, so we can just look at the first to see if
# this is a string enum or an int enum. # this is a string enum or an int enum.
enumvaltype = type(next(iter(keyanntype)).value) enumvaltype = type(next(iter(keyanntype)).value)
@ -344,7 +351,7 @@ class _Inputter(Generic[T]):
f'Invalid input value for "{fieldpath}" on' f'Invalid input value for "{fieldpath}" on'
f' "{cls.__name__}";' f' "{cls.__name__}";'
f' expected a datetime, got a {type(value).__name__}') f' expected a datetime, got a {type(value).__name__}')
_ensure_datetime_is_timezone_aware(value) check_utc(value)
return value return value
assert self._codec is Codec.JSON assert self._codec is Codec.JSON
@ -355,9 +362,9 @@ class _Inputter(Generic[T]):
f'Invalid input value for "{fieldpath}" on "{cls.__name__}";' f'Invalid input value for "{fieldpath}" on "{cls.__name__}";'
f' expected a list, got a {type(value).__name__}') f' expected a list, got a {type(value).__name__}')
if len(value) != 7 or not all(isinstance(x, int) for x in value): if len(value) != 7 or not all(isinstance(x, int) for x in value):
raise TypeError( raise ValueError(
f'Invalid input value for "{fieldpath}" on "{cls.__name__}";' f'Invalid input value for "{fieldpath}" on "{cls.__name__}";'
f' expected a list of 7 ints.') f' expected a list of 7 ints, got {[type(v) for v in value]}.')
out = datetime.datetime( # type: ignore out = datetime.datetime( # type: ignore
*value, tzinfo=datetime.timezone.utc) *value, tzinfo=datetime.timezone.utc)
if ioattrs is not None: if ioattrs is not None:
@ -380,7 +387,7 @@ class _Inputter(Generic[T]):
assert childanntypes assert childanntypes
if len(value) != len(childanntypes): if len(value) != len(childanntypes):
raise TypeError(f'Invalid tuple input for "{fieldpath}";' raise ValueError(f'Invalid tuple input for "{fieldpath}";'
f' expected {len(childanntypes)} values,' f' expected {len(childanntypes)} values,'
f' found {len(value)}.') f' found {len(value)}.')

View file

@ -14,10 +14,11 @@ import typing
import datetime import datetime
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from efro.util import check_utc
from efro.dataclassio._base import (Codec, _parse_annotated, EXTRA_ATTRS_ATTR, from efro.dataclassio._base import (Codec, _parse_annotated, EXTRA_ATTRS_ATTR,
_is_valid_for_codec, _get_origin, _is_valid_for_codec, _get_origin,
SIMPLE_TYPES, _raise_type_error, SIMPLE_TYPES, _raise_type_error,
_ensure_datetime_is_timezone_aware) IOExtendedData)
from efro.dataclassio._prep import PrepSession from efro.dataclassio._prep import PrepSession
if TYPE_CHECKING: if TYPE_CHECKING:
@ -37,6 +38,11 @@ class _Outputter:
def run(self) -> Any: def run(self) -> Any:
"""Do the thing.""" """Do the thing."""
# For special extended data types, call their 'will_output' callback.
if isinstance(self._obj, IOExtendedData):
self._obj.will_output()
return self._process_dataclass(type(self._obj), self._obj, '') return self._process_dataclass(type(self._obj), self._obj, '')
def _process_dataclass(self, cls: type, obj: Any, fieldpath: str) -> Any: def _process_dataclass(self, cls: type, obj: Any, fieldpath: str) -> Any:
@ -44,6 +50,7 @@ class _Outputter:
# pylint: disable=too-many-branches # pylint: disable=too-many-branches
prep = PrepSession(explicit=False).prep_dataclass(type(obj), prep = PrepSession(explicit=False).prep_dataclass(type(obj),
recursion_level=0) recursion_level=0)
assert prep is not None
fields = dataclasses.fields(obj) fields = dataclasses.fields(obj)
out: Optional[dict[str, Any]] = {} if self._create else None out: Optional[dict[str, Any]] = {} if self._create else None
for field in fields: for field in fields:
@ -60,7 +67,7 @@ class _Outputter:
# If we're not storing default values for this fella, # If we're not storing default values for this fella,
# we can skip all output processing if we've got a default value. # we can skip all output processing if we've got a default value.
if ioattrs is not None and not ioattrs.store_default: if ioattrs is not None and not ioattrs.store_default:
default_factory: Any = field.default_factory # type: ignore default_factory: Any = field.default_factory
if default_factory is not dataclasses.MISSING: if default_factory is not dataclasses.MISSING:
if default_factory() == value: if default_factory() == value:
continue continue
@ -113,14 +120,14 @@ class _Outputter:
return value if self._create else None return value if self._create else None
if origin is typing.Union: if origin is typing.Union:
# Currently the only unions we support are None/Value # Currently, the only unions we support are None/Value
# (translated from Optional), which we verified on prep. # (translated from Optional), which we verified on prep.
# So let's treat this as a simple optional case. # So let's treat this as a simple optional case.
if value is None: if value is None:
return None return None
childanntypes_l = [ childanntypes_l = [
c for c in typing.get_args(anntype) if c is not type(None) c for c in typing.get_args(anntype) if c is not type(None)
] ] # noqa (pycodestyle complains about *is* with type)
assert len(childanntypes_l) == 1 assert len(childanntypes_l) == 1
return self._process_value(cls, fieldpath, childanntypes_l[0], return self._process_value(cls, fieldpath, childanntypes_l[0],
value, ioattrs) value, ioattrs)
@ -242,7 +249,7 @@ class _Outputter:
if not isinstance(value, origin): if not isinstance(value, origin):
raise TypeError(f'Expected a {origin} for {fieldpath};' raise TypeError(f'Expected a {origin} for {fieldpath};'
f' found a {type(value)}.') f' found a {type(value)}.')
_ensure_datetime_is_timezone_aware(value) check_utc(value)
if ioattrs is not None: if ioattrs is not None:
ioattrs.validate_datetime(value, fieldpath) ioattrs.validate_datetime(value, fieldpath)
if self._codec is Codec.FIRESTORE: if self._codec is Codec.FIRESTORE:

View file

@ -36,6 +36,7 @@ class _PathCapture:
prep = PrepSession(explicit=False).prep_dataclass(self._cls, prep = PrepSession(explicit=False).prep_dataclass(self._cls,
recursion_level=0) recursion_level=0)
assert prep is not None
try: try:
anntype = prep.annotations[name] anntype = prep.annotations[name]
except KeyError as exc: except KeyError as exc:
@ -75,7 +76,7 @@ class DataclassFieldLookup(Generic[T]):
# We tell the type system that we are returning an instance # We tell the type system that we are returning an instance
# of our class, which allows it to perform type checking on # of our class, which allows it to perform type checking on
# member lookups. In reality, however, we are providing a # member lookups. In reality, however, we are providing a
# special object which captures path lookups so we can build # special object which captures path lookups, so we can build
# a string from them. # a string from them.
if not TYPE_CHECKING: if not TYPE_CHECKING:
out = callback(_PathCapture(self.cls)) out = callback(_PathCapture(self.cls))

View file

@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, TypeVar, get_type_hints
from efro.dataclassio._base import _parse_annotated, _get_origin, SIMPLE_TYPES from efro.dataclassio._base import _parse_annotated, _get_origin, SIMPLE_TYPES
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Any from typing import Any, Optional
T = TypeVar('T') T = TypeVar('T')
@ -27,11 +27,15 @@ T = TypeVar('T')
# (basically for detecting recursive types) # (basically for detecting recursive types)
MAX_RECURSION = 10 MAX_RECURSION = 10
# Attr name for data we store on dataclass types as part of prep. # Attr name for data we store on dataclass types that have been prepped.
PREP_ATTR = '_DCIOPREP' PREP_ATTR = '_DCIOPREP'
# We also store the prep-session while the prep is in progress.
# (necessary to support recursive types).
PREP_SESSION_ATTR = '_DCIOPREPSESSION'
def ioprep(cls: type) -> None:
def ioprep(cls: type, globalns: dict = None) -> None:
"""Prep a dataclass type for use with this module's functionality. """Prep a dataclass type for use with this module's functionality.
Prepping ensures that all types contained in a data class as well as Prepping ensures that all types contained in a data class as well as
@ -45,10 +49,14 @@ def ioprep(cls: type) -> None:
Prepping a dataclass involves evaluating its type annotations, which, Prepping a dataclass involves evaluating its type annotations, which,
as of PEP 563, are stored simply as strings. This evaluation is done as of PEP 563, are stored simply as strings. This evaluation is done
in the module namespace containing the class, so all referenced types with localns set to the class dict (so that types defined in the class
must be defined at that level. can be used) and globalns set to the containing module's class.
It is possible to override globalns for special cases such as when
prepping happens as part of an execed string instead of within a
module.
""" """
PrepSession(explicit=True).prep_dataclass(cls, recursion_level=0) PrepSession(explicit=True,
globalns=globalns).prep_dataclass(cls, recursion_level=0)
def ioprepped(cls: type[T]) -> type[T]: def ioprepped(cls: type[T]) -> type[T]:
@ -64,6 +72,23 @@ def ioprepped(cls: type[T]) -> type[T]:
return cls return cls
def will_ioprep(cls: type[T]) -> type[T]:
"""Class decorator hinting that we will prep a class later.
In some cases (such as recursive types) we cannot use the @ioprepped
decorator and must instead call ioprep() explicitly later. However,
some of our custom pylint checking behaves differently when the
@ioprepped decorator is present, in that case requiring type annotations
to be present and not simply forward declared under an "if TYPE_CHECKING"
block. (since they are used at runtime).
The @will_ioprep decorator triggers the same pylint behavior
differences as @ioprepped (which are necessary for the later ioprep() call
to work correctly) but without actually running any prep itself.
"""
return cls
def is_ioprepped_dataclass(obj: Any) -> bool: def is_ioprepped_dataclass(obj: Any) -> bool:
"""Return whether the obj is an ioprepped dataclass type or instance.""" """Return whether the obj is an ioprepped dataclass type or instance."""
cls = obj if isinstance(obj, type) else type(obj) cls = obj if isinstance(obj, type) else type(obj)
@ -87,11 +112,19 @@ class PrepData:
class PrepSession: class PrepSession:
"""Context for a prep.""" """Context for a prep."""
def __init__(self, explicit: bool): def __init__(self, explicit: bool, globalns: Optional[dict] = None):
self.explicit = explicit self.explicit = explicit
self.globalns = globalns
def prep_dataclass(self, cls: type, recursion_level: int) -> PrepData: def prep_dataclass(self, cls: type,
"""Run prep on a dataclass if necessary and return its prep data.""" recursion_level: int) -> Optional[PrepData]:
"""Run prep on a dataclass if necessary and return its prep data.
The only case where this will return None is for recursive types
if the type is already being prepped higher in the call order.
"""
# pylint: disable=too-many-locals
# pylint: disable=too-many-branches
# We should only need to do this once per dataclass. # We should only need to do this once per dataclass.
existing_data = getattr(cls, PREP_ATTR, None) existing_data = getattr(cls, PREP_ATTR, None)
@ -99,8 +132,9 @@ class PrepSession:
assert isinstance(existing_data, PrepData) assert isinstance(existing_data, PrepData)
return existing_data return existing_data
# If we run into classes containing themselves, we may have # Sanity check.
# to do something smarter to handle it. # Note that we now support recursive types via the PREP_SESSION_ATTR,
# so we theoretically shouldn't run into this this.
if recursion_level > MAX_RECURSION: if recursion_level > MAX_RECURSION:
raise RuntimeError('Max recursion exceeded.') raise RuntimeError('Max recursion exceeded.')
@ -108,6 +142,18 @@ class PrepSession:
if not isinstance(cls, type) or not dataclasses.is_dataclass(cls): if not isinstance(cls, type) or not dataclasses.is_dataclass(cls):
raise TypeError(f'Passed arg {cls} is not a dataclass type.') raise TypeError(f'Passed arg {cls} is not a dataclass type.')
# Add a pointer to the prep-session while doing the prep.
# This way we can ignore types that we're already in the process
# of prepping and can support recursive types.
existing_prep = getattr(cls, PREP_SESSION_ATTR, None)
if existing_prep is not None:
if existing_prep is self:
return None
# We shouldn't need to support failed preps
# or preps from multiple threads at once.
raise RuntimeError('Found existing in-progress prep.')
setattr(cls, PREP_SESSION_ATTR, self)
# Generate a warning on non-explicit preps; we prefer prep to # Generate a warning on non-explicit preps; we prefer prep to
# happen explicitly at runtime so errors can be detected early on. # happen explicitly at runtime so errors can be detected early on.
if not self.explicit: if not self.explicit:
@ -123,10 +169,10 @@ class PrepSession:
# which allows us to pick up nested classes, etc. # which allows us to pick up nested classes, etc.
resolved_annotations = get_type_hints(cls, resolved_annotations = get_type_hints(cls,
localns=vars(cls), localns=vars(cls),
globalns=self.globalns,
include_extras=True) include_extras=True)
# pylint: enable=unexpected-keyword-arg # pylint: enable=unexpected-keyword-arg
except Exception as exc: except Exception as exc:
print('GOT', cls.__dict__)
raise TypeError( raise TypeError(
f'dataclassio prep for {cls} failed with error: {exc}.' f'dataclassio prep for {cls} failed with error: {exc}.'
f' Make sure all types used in annotations are defined' f' Make sure all types used in annotations are defined'
@ -175,6 +221,10 @@ class PrepSession:
annotations=resolved_annotations, annotations=resolved_annotations,
storage_names_to_attr_names=storage_names_to_attr_names) storage_names_to_attr_names=storage_names_to_attr_names)
setattr(cls, PREP_ATTR, prepdata) setattr(cls, PREP_ATTR, prepdata)
# Clear our prep-session tag.
assert getattr(cls, PREP_SESSION_ATTR, None) is self
delattr(cls, PREP_SESSION_ATTR)
return prepdata return prepdata
def prep_type(self, cls: type, attrname: str, anntype: Any, def prep_type(self, cls: type, attrname: str, anntype: Any,
@ -303,7 +353,7 @@ class PrepSession:
"""Run prep on a Union type.""" """Run prep on a Union type."""
typeargs = typing.get_args(anntype) typeargs = typing.get_args(anntype)
if (len(typeargs) != 2 if (len(typeargs) != 2
or len([c for c in typeargs if c is type(None)]) != 1): or len([c for c in typeargs if c is type(None)]) != 1): # noqa
raise TypeError(f'Union {anntype} for attr \'{attrname}\' on' raise TypeError(f'Union {anntype} for attr \'{attrname}\' on'
f' {cls.__name__} is not supported by dataclassio;' f' {cls.__name__} is not supported by dataclassio;'
f' only 2 member Unions with one type being None' f' only 2 member Unions with one type being None'

View file

@ -84,6 +84,12 @@ def is_urllib_network_error(exc: BaseException) -> bool:
exc, exc,
(urllib.error.URLError, ConnectionError, http.client.IncompleteRead, (urllib.error.URLError, ConnectionError, http.client.IncompleteRead,
http.client.BadStatusLine, socket.timeout)): http.client.BadStatusLine, socket.timeout)):
# Special case: although an HTTPError is a subclass of URLError,
# we don't return True for it. It means we have successfully
# communicated with the server but what we are asking for is
# not there/etc.
if isinstance(exc, urllib.error.HTTPError):
return False
return True return True
if isinstance(exc, OSError): if isinstance(exc, OSError):
if exc.errno == 10051: # Windows unreachable network error. if exc.errno == 10051: # Windows unreachable network error.

View file

@ -766,7 +766,7 @@ class MessageReceiver:
# Return type of None translates to EmptyResponse. # Return type of None translates to EmptyResponse.
responsetypes = tuple(EmptyResponse if r is type(None) else r responsetypes = tuple(EmptyResponse if r is type(None) else r
for r in responsetypes) for r in responsetypes) # noqa
# Make sure our protocol has this message type registered and our # Make sure our protocol has this message type registered and our
# return types exactly match. (Technically we could return a subset # return types exactly match. (Technically we could return a subset

View file

@ -75,6 +75,10 @@ def _default_color_enabled() -> bool:
if not sys.__stdout__.isatty(): if not sys.__stdout__.isatty():
return False return False
# Another common way to say the terminal can't do fancy stuff like color:
if os.environ.get('TERM') == 'dumb':
return False
# On windows, try to enable ANSI color mode. # On windows, try to enable ANSI color mode.
if platform.system() == 'Windows': if platform.system() == 'Windows':
return _windows_enable_color() return _windows_enable_color()

View file

@ -11,10 +11,19 @@ import functools
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, cast, TypeVar, Generic from typing import TYPE_CHECKING, cast, TypeVar, Generic
_pytz_utc: Any
# We don't *require* pytz, but we want to support it for tzinfos if available.
try:
import pytz
_pytz_utc = pytz.utc
except ModuleNotFoundError:
_pytz_utc = None # pylint: disable=invalid-name
if TYPE_CHECKING: if TYPE_CHECKING:
import asyncio import asyncio
from efro.call import Call as Call # 'as Call' so we re-export. from efro.call import Call as Call # 'as Call' so we re-export.
from typing import Any, Callable, Optional from typing import Any, Callable, Optional, NoReturn
T = TypeVar('T') T = TypeVar('T')
TVAL = TypeVar('TVAL') TVAL = TypeVar('TVAL')
@ -62,6 +71,14 @@ def enum_by_value(cls: type[TENUM], value: Any) -> TENUM:
(value, cls.__name__)) from None (value, cls.__name__)) from None
def check_utc(value: datetime.datetime) -> None:
"""Ensure a datetime value is timezone-aware utc."""
if (value.tzinfo is not datetime.timezone.utc
and (_pytz_utc is None or value.tzinfo is not _pytz_utc)):
raise ValueError('datetime value does not have timezone set as'
' datetime.timezone.utc')
def utc_now() -> datetime.datetime: def utc_now() -> datetime.datetime:
"""Get offset-aware current utc time. """Get offset-aware current utc time.
@ -240,7 +257,8 @@ class DispatchMethodWrapper(Generic[TARG, TRET]):
pass pass
@staticmethod @staticmethod
def register(func: Callable[[Any, Any], TRET]) -> Callable: def register(
func: Callable[[Any, Any], TRET]) -> Callable[[Any, Any], TRET]:
"""Register a new dispatch handler for this dispatch-method.""" """Register a new dispatch handler for this dispatch-method."""
registry: dict[Any, Callable] registry: dict[Any, Callable]
@ -312,12 +330,16 @@ class ValueDispatcher(Generic[TVAL, TRET]):
return handler() return handler()
return self._base_call(value) return self._base_call(value)
def _add_handler(self, value: TVAL, call: Callable[[], TRET]) -> None: def _add_handler(self, value: TVAL,
call: Callable[[], TRET]) -> Callable[[], TRET]:
if value in self._handlers: if value in self._handlers:
raise RuntimeError(f'Duplicate handlers added for {value}') raise RuntimeError(f'Duplicate handlers added for {value}')
self._handlers[value] = call self._handlers[value] = call
return call
def register(self, value: TVAL) -> Callable[[Callable[[], TRET]], None]: def register(
self,
value: TVAL) -> Callable[[Callable[[], TRET]], Callable[[], TRET]]:
"""Add a handler to the dispatcher.""" """Add a handler to the dispatcher."""
from functools import partial from functools import partial
return partial(self._add_handler, value) return partial(self._add_handler, value)
@ -343,13 +365,16 @@ class ValueDispatcher1Arg(Generic[TVAL, TARG, TRET]):
return handler(arg) return handler(arg)
return self._base_call(value, arg) return self._base_call(value, arg)
def _add_handler(self, value: TVAL, call: Callable[[TARG], TRET]) -> None: def _add_handler(self, value: TVAL,
call: Callable[[TARG], TRET]) -> Callable[[TARG], TRET]:
if value in self._handlers: if value in self._handlers:
raise RuntimeError(f'Duplicate handlers added for {value}') raise RuntimeError(f'Duplicate handlers added for {value}')
self._handlers[value] = call self._handlers[value] = call
return call
def register(self, def register(
value: TVAL) -> Callable[[Callable[[TARG], TRET]], None]: self, value: TVAL
) -> Callable[[Callable[[TARG], TRET]], Callable[[TARG], TRET]]:
"""Add a handler to the dispatcher.""" """Add a handler to the dispatcher."""
from functools import partial from functools import partial
return partial(self._add_handler, value) return partial(self._add_handler, value)
@ -363,8 +388,9 @@ if TYPE_CHECKING:
def __call__(self, value: TVAL) -> TRET: def __call__(self, value: TVAL) -> TRET:
... ...
def register(self, def register(
value: TVAL) -> Callable[[Callable[[TSELF], TRET]], None]: self, value: TVAL
) -> Callable[[Callable[[TSELF], TRET]], Callable[[TSELF], TRET]]:
"""Add a handler to the dispatcher.""" """Add a handler to the dispatcher."""
... ...
@ -579,6 +605,8 @@ def human_readable_compact_id(num: int) -> str:
'o' is excluded due to similarity to '0'. 'o' is excluded due to similarity to '0'.
'z' is excluded due to similarity to '2'. 'z' is excluded due to similarity to '2'.
Therefore for n chars this can store values of 21^n.
When reading human input consisting of these IDs, it may be desirable When reading human input consisting of these IDs, it may be desirable
to map the disallowed chars to their corresponding allowed ones to map the disallowed chars to their corresponding allowed ones
('o' -> '0', etc). ('o' -> '0', etc).
@ -599,8 +627,39 @@ def compact_id(num: int) -> str:
friendly to humans due to using both capital and lowercase letters, friendly to humans due to using both capital and lowercase letters,
both 'O' and '0', etc. both 'O' and '0', etc.
Therefore for n chars this can store values of 62^n.
Sort order for these ids is the same as the original numbers. Sort order for these ids is the same as the original numbers.
""" """
return _compact_id( return _compact_id(
num, '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ' num, '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ'
'abcdefghijklmnopqrstuvwxyz') 'abcdefghijklmnopqrstuvwxyz')
def assert_never(value: NoReturn) -> NoReturn:
"""Trick for checking exhaustive handling of Enums, etc.
See https://github.com/python/typing/issues/735
"""
assert False, f'Unhandled value: {value} ({type(value).__name__})'
def unchanging_hostname() -> str:
"""Return an unchanging name for the local device.
Similar to the `hostname` call (or os.uname().nodename in Python)
except attempts to give a name that doesn't change depending on
network conditions. (A Mac will tend to go from Foo to Foo.local,
Foo.lan etc. throughout its various adventures)
"""
import os
import platform
import subprocess
# On Mac, this should give the computer name assigned in System Prefs.
if platform.system() == 'Darwin':
return subprocess.run(
['scutil', '--get', 'ComputerName'],
check=True,
capture_output=True).stdout.decode().strip().replace(' ', '-')
return os.uname().nodename

Binary file not shown.