1.6.10 update

This commit is contained in:
Ayush Saini 2022-03-13 17:33:09 +05:30
parent 81485da646
commit 240155bce3
37 changed files with 809 additions and 350 deletions

View file

@ -3,8 +3,9 @@
"""Functionality related to the high level state of the app."""
from __future__ import annotations
from enum import Enum
import random
import logging
from enum import Enum
from typing import TYPE_CHECKING
import _ba
@ -184,6 +185,9 @@ class App:
self.state = self.State.LAUNCHING
self._app_launched = False
self._app_paused = False
# Config.
self.config_file_healthy = False
@ -348,27 +352,6 @@ class App:
for key in ('lc14173', 'lc14292'):
cfg.setdefault(key, launch_count)
# Debugging - make note if we're using the local test server so we
# don't accidentally leave it on in a release.
# FIXME - should move this to the native layer.
server_addr = _ba.get_master_server_address()
if 'localhost' in server_addr:
_ba.timer(2.0,
lambda: _ba.screenmessage(
'Note: using local server',
(1, 1, 0),
log=True,
),
timetype=TimeType.REAL)
elif 'test' in server_addr:
_ba.timer(2.0,
lambda: _ba.screenmessage(
'Note: using test server-module',
(1, 1, 0),
log=True,
),
timetype=TimeType.REAL)
cfg['launchCount'] = launch_count
cfg.commit()
@ -389,20 +372,40 @@ class App:
self.accounts.on_app_launch()
self.plugins.on_app_launch()
self.state = self.State.RUNNING
# See note below in on_app_pause.
if self.state != self.State.LAUNCHING:
logging.error('on_app_launch found state %s; expected LAUNCHING.',
self.state)
self._app_launched = True
self._update_state()
# from ba._dependency import test_depset
# test_depset()
if bool(False):
self._test_https()
def _update_state(self) -> None:
if self._app_paused:
self.state = self.State.PAUSED
else:
if self._app_launched:
self.state = self.State.RUNNING
else:
self.state = self.State.LAUNCHING
def on_app_pause(self) -> None:
"""Called when the app goes to a suspended state."""
self.state = self.State.PAUSED
self._app_paused = True
self._update_state()
self.plugins.on_app_pause()
def on_app_resume(self) -> None:
"""Run when the app resumes from a suspended state."""
self.state = self.State.RUNNING
self._app_paused = False
self._update_state()
self.fg_state += 1
self.accounts.on_app_resume()
self.music.on_app_resume()
@ -586,6 +589,8 @@ class App:
try:
with urllib.request.urlopen('https://example.com') as url:
val = url.read()
_ba.screenmessage('HTTPS SUCCESS!')
print('HTTPS TEST SUCCESS', len(val))
except Exception as exc:
_ba.screenmessage('HTTPS FAIL.')
print('HTTPS TEST FAIL:', exc)

View file

@ -60,6 +60,8 @@ def setup_asyncio() -> None:
timetype=TimeType.REAL,
repeat=True)
if bool(False):
async def aio_test() -> None:
print('TEST AIO TASK STARTING')
assert _asyncio_event_loop is not None
@ -67,5 +69,4 @@ def setup_asyncio() -> None:
await asyncio.sleep(2.0)
print('TEST AIO TASK ENDING')
if bool(False):
_asyncio_event_loop.create_task(aio_test())

View file

@ -148,7 +148,7 @@ class DependencyEntry:
instance = self.cls.__new__(self.cls)
# pylint: disable=protected-access
instance._dep_entry = weakref.ref(self)
instance.__init__()
instance.__init__() # type: ignore
assert self.depset
depset = self.depset()

View file

@ -196,3 +196,4 @@ class SpecialChar(Enum):
FLAG_PHILIPPINES = 87
FLAG_CHILE = 88
MIKIROG = 89
V2_LOGO = 90

View file

@ -196,3 +196,4 @@ class SpecialChar(Enum):
FLAG_PHILIPPINES = 87
FLAG_CHILE = 88
MIKIROG = 89
V2_LOGO = 90

View file

@ -119,6 +119,11 @@ def gear_vr_controller_warning() -> None:
color=(1, 0, 0))
def uuid_str() -> str:
import uuid
return str(uuid.uuid4())
def orientation_reset_cb_message() -> None:
from ba._language import Lstr
_ba.screenmessage(
@ -370,3 +375,13 @@ def get_player_icon(sessionplayer: ba.SessionPlayer) -> dict[str, Any]:
'tint_color': info['tint_color'],
'tint2_color': info['tint2_color']
}
def hash_strings(inputs: list[str]) -> str:
"""Hash provided strings into a short output string."""
import hashlib
sha = hashlib.sha1()
for inp in inputs:
sha.update(inp.encode())
return sha.hexdigest()

View file

@ -87,6 +87,7 @@ class LanguageSubsystem:
'vec': 'Venetian',
'hi': 'Hindi',
'ta': 'Tamil',
'fil': 'Filipino',
}
# Special case for Chinese: map specific variations to traditional.
@ -373,7 +374,7 @@ class Lstr:
currently-active language.
To see available resource keys, look at any of the bs_language_*.py files
in the game or the translations pages at bombsquadgame.com/translate.
in the game or the translations pages at legacy.ballistica.net/translate.
# EXAMPLE 1: specify a string from a resource path
mynode.text = ba.Lstr(resource='audioSettingsWindow.titleText')

View file

@ -25,8 +25,15 @@ class NetworkSubsystem:
"""Network related app subsystem."""
def __init__(self) -> None:
# Anyone accessing/modifying region_pings should hold this lock.
self.region_pings_lock = threading.Lock()
self.region_pings: dict[str, float] = {}
# For debugging.
self.v1_test_log: str = ''
self.v1_ctest_results: dict[int, str] = {}
def get_ip_address_type(addr: str) -> socket.AddressFamily:
"""Return socket.AF_INET6 or socket.AF_INET4 for the provided address."""

View file

@ -38,8 +38,8 @@ class AssetType(Enum):
@dataclass
class AssetPackageFlavorManifest:
"""A manifest of asset info for a specific flavor of an asset package."""
assetfiles: Annotated[dict[str, str],
IOAttrs('assetfiles')] = field(default_factory=dict)
cloudfiles: Annotated[dict[str, str],
IOAttrs('cloudfiles')] = field(default_factory=dict)
@ioprepped

33
dist/ba_data/python/bacommon/build.py vendored Normal file
View file

@ -0,0 +1,33 @@
# Released under the MIT License. See LICENSE for details.
#
"""Functionality related to game builds."""
from __future__ import annotations
import datetime
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Annotated
from efro.dataclassio import ioprepped, IOAttrs
if TYPE_CHECKING:
pass
@ioprepped
@dataclass
class BuildInfoSet:
"""Set of build infos."""
@dataclass
class Entry:
"""Info about a particular build."""
filename: Annotated[str, IOAttrs('fname')]
size: Annotated[int, IOAttrs('size')]
version: Annotated[str, IOAttrs('version')]
build_number: Annotated[int, IOAttrs('build')]
checksum: Annotated[str, IOAttrs('checksum')]
createtime: Annotated[datetime.datetime, IOAttrs('createtime')]
builds: Annotated[list[Entry],
IOAttrs('builds')] = field(default_factory=list)

View file

@ -104,7 +104,7 @@ class ServerConfig:
# if ${ACCOUNT} is present in the string, it will be replaced by the
# currently-signed-in account's id. To fetch info about an account,
# your back-end server can use the following url:
# http://bombsquadgame.com/accountquery?id=ACCOUNT_ID_HERE
# https://legacy.ballistica.net/accountquery?id=ACCOUNT_ID_HERE
stats_url: Optional[str] = None
# If present, the server subprocess will attempt to gracefully exit after

View file

@ -39,15 +39,21 @@ class Spawner:
The spawn position.
"""
def __init__(self, spawner: Spawner, data: Any, pt: Sequence[float]):
def __init__(
self,
spawner: Spawner,
data: Any,
pt: Sequence[float], # pylint: disable=invalid-name
):
"""Instantiate with the given values."""
self.spawner = spawner
self.data = data
self.pt = pt # pylint: disable=invalid-name
def __init__(self,
def __init__(
self,
data: Any = None,
pt: Sequence[float] = (0, 0, 0),
pt: Sequence[float] = (0, 0, 0), # pylint: disable=invalid-name
spawn_time: float = 1.0,
send_spawn_message: bool = True,
spawn_callback: Callable[[], Any] = None):

View file

@ -199,6 +199,7 @@ class Spaz(ba.Actor):
self.equip_boxing_gloves()
self.last_punch_time_ms = -9999
self.last_pickup_time_ms = -9999
self.last_jump_time_ms = -9999
self.last_run_time_ms = -9999
self._last_run_value = 0.0
self.last_bomb_time_ms = -9999
@ -363,7 +364,11 @@ class Spaz(ba.Actor):
"""
if not self.node:
return
t_ms = ba.time(timeformat=ba.TimeFormat.MILLISECONDS)
assert isinstance(t_ms, int)
if t_ms - self.last_jump_time_ms >= self._jump_cooldown:
self.node.jump_pressed = True
self.last_jump_time_ms = t_ms
self._turbo_filter_add_press('jump')
def on_jump_release(self) -> None:

View file

@ -26,7 +26,7 @@ class SharedObjects:
def __init__(self) -> None:
activity = ba.getactivity()
if hasattr(activity, self._STORENAME):
if self._STORENAME in activity.customdata:
raise RuntimeError('Use SharedObjects.get() to fetch the'
' shared instance for this activity.')
self._object_material: Optional[ba.Material] = None

View file

@ -60,7 +60,7 @@ class MainMenuActivity(ba.Activity[ba.Player, ba.Team]):
'scale': scale,
'position': (0, 10),
'vr_depth': -10,
'text': '\xa9 2011-2021 Eric Froemling'
'text': '\xa9 2011-2022 Eric Froemling'
}))
# Throw up some text that only clients can see so they know that the

View file

@ -10,17 +10,17 @@ import ba
def show_sign_in_prompt(account_type: str = None) -> None:
"""Bring up a prompt telling the user they must sign in."""
from bastd.ui import confirm
from bastd.ui.confirm import ConfirmWindow
from bastd.ui.account import settings
if account_type == 'Google Play':
confirm.ConfirmWindow(
ConfirmWindow(
ba.Lstr(resource='notSignedInGooglePlayErrorText'),
lambda: _ba.sign_in('Google Play'),
ok_text=ba.Lstr(resource='accountSettingsWindow.signInText'),
width=460,
height=130)
else:
confirm.ConfirmWindow(
ConfirmWindow(
ba.Lstr(resource='notSignedInErrorText'),
lambda: settings.AccountSettingsWindow(modal=True,
close_once_signed_in=True),

View file

@ -25,6 +25,10 @@ class AccountSettingsWindow(ba.Window):
close_once_signed_in: bool = False):
# pylint: disable=too-many-statements
self._sign_in_game_circle_button: Optional[ba.Widget] = None
self._sign_in_v2_button: Optional[ba.Widget] = None
self._sign_in_device_button: Optional[ba.Widget] = None
self._close_once_signed_in = close_once_signed_in
ba.set_analytics_screen('Account Window')
@ -86,6 +90,10 @@ class AccountSettingsWindow(ba.Window):
# exceptions.
self._show_sign_in_buttons.append('Local')
# Ditto with shiny new V2 ones.
if bool(False):
self._show_sign_in_buttons.append('V2')
top_extra = 15 if uiscale is ba.UIScale.SMALL else 0
super().__init__(root_widget=ba.containerwidget(
size=(self._width, self._height + top_extra),
@ -206,12 +214,10 @@ class AccountSettingsWindow(ba.Window):
show_game_circle_sign_in_button = (account_state == 'signed_out'
and 'Game Circle'
in self._show_sign_in_buttons)
show_ali_sign_in_button = (account_state == 'signed_out'
and 'Ali' in self._show_sign_in_buttons)
show_test_sign_in_button = (account_state == 'signed_out'
and 'Test' in self._show_sign_in_buttons)
show_device_sign_in_button = (account_state == 'signed_out' and 'Local'
in self._show_sign_in_buttons)
show_v2_sign_in_button = (account_state == 'signed_out'
and 'V2' in self._show_sign_in_buttons)
sign_in_button_space = 70.0
show_game_service_button = (self._signed_in and account_type
@ -223,9 +229,9 @@ class AccountSettingsWindow(ba.Window):
'allowAccountLinking2', False))
linked_accounts_text_space = 60.0
show_achievements_button = (self._signed_in and account_type
in ('Google Play', 'Alibaba', 'Local',
'OUYA', 'Test'))
show_achievements_button = (
self._signed_in
and account_type in ('Google Play', 'Alibaba', 'Local', 'OUYA'))
achievements_button_space = 60.0
show_achievements_text = (self._signed_in
@ -255,8 +261,8 @@ class AccountSettingsWindow(ba.Window):
show_unlink_accounts_button = show_link_accounts_button
unlink_accounts_button_space = 90.0
show_sign_out_button = (self._signed_in and account_type
in ['Test', 'Local', 'Google Play'])
show_sign_out_button = (self._signed_in
and account_type in ['Local', 'Google Play'])
sign_out_button_space = 70.0
if self._subcontainer is not None:
@ -272,12 +278,10 @@ class AccountSettingsWindow(ba.Window):
self._sub_height += sign_in_button_space
if show_game_circle_sign_in_button:
self._sub_height += sign_in_button_space
if show_ali_sign_in_button:
self._sub_height += sign_in_button_space
if show_test_sign_in_button:
self._sub_height += sign_in_button_space
if show_device_sign_in_button:
self._sub_height += sign_in_button_space
if show_v2_sign_in_button:
self._sub_height += sign_in_button_space
if show_game_service_button:
self._sub_height += game_service_button_space
if show_linked_accounts_text:
@ -462,21 +466,42 @@ class AccountSettingsWindow(ba.Window):
ba.widget(edit=btn, show_buffer_bottom=40, show_buffer_top=100)
self._sign_in_text = None
if show_ali_sign_in_button:
if show_v2_sign_in_button:
button_width = 350
v -= sign_in_button_space
self._sign_in_ali_button = btn = ba.buttonwidget(
self._sign_in_v2_button = btn = ba.buttonwidget(
parent=self._subcontainer,
position=((self._sub_width - button_width) * 0.5, v - 20),
autoselect=True,
size=(button_width, 60),
label=ba.Lstr(value='${A}${B}',
subs=[('${A}',
ba.charstr(ba.SpecialChar.ALIBABA_LOGO)),
label='',
on_activate_call=self._v2_sign_in_press)
ba.textwidget(
parent=self._subcontainer,
draw_controller=btn,
h_align='center',
v_align='center',
size=(0, 0),
position=(self._sub_width * 0.5, v + 17),
text=ba.Lstr(
value='${A}${B}',
subs=[('${A}', ba.charstr(ba.SpecialChar.V2_LOGO)),
('${B}',
ba.Lstr(resource=self._r + '.signInText'))
]),
on_activate_call=lambda: self._sign_in_press('Ali'))
ba.Lstr(resource=self._r + '.signInWithV2Text'))]),
maxwidth=button_width * 0.8,
color=(0.75, 1.0, 0.7))
ba.textwidget(parent=self._subcontainer,
draw_controller=btn,
h_align='center',
v_align='center',
size=(0, 0),
position=(self._sub_width * 0.5, v - 4),
text=ba.Lstr(resource=self._r +
'.signInWithV2InfoText'),
flatness=1.0,
scale=0.57,
maxwidth=button_width * 0.9,
color=(0.55, 0.8, 0.5))
if first_selectable is None:
first_selectable = btn
if ba.app.ui.use_toolbars:
@ -532,53 +557,6 @@ class AccountSettingsWindow(ba.Window):
ba.widget(edit=btn, show_buffer_bottom=40, show_buffer_top=100)
self._sign_in_text = None
# Old test-account option.
if show_test_sign_in_button:
button_width = 350
v -= sign_in_button_space
self._sign_in_test_button = btn = ba.buttonwidget(
parent=self._subcontainer,
position=((self._sub_width - button_width) * 0.5, v - 20),
autoselect=True,
size=(button_width, 60),
label='',
on_activate_call=lambda: self._sign_in_press('Test'))
ba.textwidget(parent=self._subcontainer,
draw_controller=btn,
h_align='center',
v_align='center',
size=(0, 0),
position=(self._sub_width * 0.5, v + 17),
text=ba.Lstr(
value='${A}${B}',
subs=[('${A}',
ba.charstr(ba.SpecialChar.TEST_ACCOUNT)),
('${B}',
ba.Lstr(resource=self._r +
'.signInWithTestAccountText'))]),
maxwidth=button_width * 0.8,
color=(0.75, 1.0, 0.7))
ba.textwidget(parent=self._subcontainer,
draw_controller=btn,
h_align='center',
v_align='center',
size=(0, 0),
position=(self._sub_width * 0.5, v - 4),
text=ba.Lstr(resource=self._r +
'.signInWithTestAccountInfoText'),
flatness=1.0,
scale=0.57,
maxwidth=button_width * 0.9,
color=(0.55, 0.8, 0.5))
if first_selectable is None:
first_selectable = btn
if ba.app.ui.use_toolbars:
ba.widget(edit=btn,
right_widget=_ba.get_special_widget('party_button'))
ba.widget(edit=btn, left_widget=bbtn)
ba.widget(edit=btn, show_buffer_bottom=40, show_buffer_top=100)
self._sign_in_text = None
if show_player_profiles_button:
button_width = 300
v -= player_profiles_button_space
@ -1051,6 +1029,12 @@ class AccountSettingsWindow(ba.Window):
self._needs_refresh = True
ba.timer(0.1, ba.WeakCall(self._update), timetype=ba.TimeType.REAL)
def _v2_sign_in_press(self) -> None:
# pylint: disable=cyclic-import
from bastd.ui.account.v2 import V2SignInWindow
assert self._sign_in_v2_button is not None
V2SignInWindow(origin_widget=self._sign_in_v2_button)
def _reset_progress(self) -> None:
try:
from ba.internal import getcampaign

View file

@ -0,0 +1,92 @@
# Released under the MIT License. See LICENSE for details.
#
"""V2 account ui bits."""
from __future__ import annotations
from typing import TYPE_CHECKING
import ba
import _ba
if TYPE_CHECKING:
from typing import Any, Optional
class V2SignInWindow(ba.Window):
"""A window allowing signing in to a v2 account."""
def __init__(self, origin_widget: ba.Widget):
from ba.internal import is_browser_likely_available
logincode = '1412345'
address = (
f'{_ba.get_master_server_address(version=2)}?login={logincode}')
self._width = 600
self._height = 500
uiscale = ba.app.ui.uiscale
super().__init__(root_widget=ba.containerwidget(
size=(self._width, self._height),
transition='in_scale',
scale_origin_stack_offset=origin_widget.get_screen_space_center(),
scale=(1.25 if uiscale is ba.UIScale.SMALL else
1.0 if uiscale is ba.UIScale.MEDIUM else 0.85)))
ba.textwidget(
parent=self._root_widget,
position=(self._width * 0.5, self._height - 85),
size=(0, 0),
text=ba.Lstr(
resource='accountSettingsWindow.v2LinkInstructionsText'),
color=ba.app.ui.title_color,
maxwidth=self._width * 0.9,
h_align='center',
v_align='center')
button_width = 450
if is_browser_likely_available():
ba.buttonwidget(parent=self._root_widget,
position=((self._width * 0.5 - button_width * 0.5),
self._height - 175),
autoselect=True,
size=(button_width, 60),
label=ba.Lstr(value=address),
color=(0.55, 0.5, 0.6),
textcolor=(0.75, 0.7, 0.8),
on_activate_call=lambda: ba.open_url(address))
qroffs = 0.0
else:
ba.textwidget(parent=self._root_widget,
position=(self._width * 0.5, self._height - 135),
size=(0, 0),
text=ba.Lstr(value=address),
flatness=1.0,
maxwidth=self._width,
scale=0.75,
h_align='center',
v_align='center')
qroffs = 20.0
self._cancel_button = ba.buttonwidget(
parent=self._root_widget,
position=(30, self._height - 55),
size=(130, 50),
scale=0.8,
label=ba.Lstr(resource='cancelText'),
# color=(0.6, 0.5, 0.6),
on_activate_call=self._done,
autoselect=True,
textcolor=(0.75, 0.7, 0.8),
# icon=ba.gettexture('crossOut'),
# iconscale=1.2
)
ba.containerwidget(edit=self._root_widget,
cancel_button=self._cancel_button)
qr_size = 270
ba.imagewidget(parent=self._root_widget,
position=(self._width * 0.5 - qr_size * 0.5,
self._height * 0.34 + qroffs - qr_size * 0.5),
size=(qr_size, qr_size),
texture=_ba.get_qrcode_texture(address))
def _done(self) -> None:
ba.containerwidget(edit=self._root_widget, transition='out_scale')

View file

@ -1209,6 +1209,7 @@ class PublicGatherTab(GatherTab):
def on_public_party_activate(self, party: PartyEntry) -> None:
"""Called when a party is clicked or otherwise activated."""
self.save_state()
if party.queue is not None:
from bastd.ui.partyqueue import PartyQueueWindow
ba.playsound(ba.getsound('swish'))

View file

@ -5,7 +5,6 @@
from __future__ import annotations
import math
import weakref
from typing import TYPE_CHECKING, cast
import _ba
@ -413,53 +412,3 @@ class PartyWindow(ba.Window):
"""Close the window and make a lovely sound."""
ba.playsound(ba.getsound('swish'))
self.close()
def handle_party_invite(name: str, invite_id: str) -> None:
"""Handle an incoming party invitation."""
from bastd import mainmenu
from bastd.ui import confirm
ba.playsound(ba.getsound('fanfare'))
# if we're not in the main menu, just print the invite
# (don't want to screw up an in-progress game)
in_game = not isinstance(_ba.get_foreground_host_session(),
mainmenu.MainMenuSession)
if in_game:
ba.screenmessage(ba.Lstr(
value='${A}\n${B}',
subs=[('${A}',
ba.Lstr(resource='gatherWindow.partyInviteText',
subs=[('${NAME}', name)])),
('${B}',
ba.Lstr(
resource='gatherWindow.partyInviteGooglePlayExtraText'))
]),
color=(0.5, 1, 0))
else:
def do_accept(inv_id: str) -> None:
_ba.accept_party_invitation(inv_id)
conf = confirm.ConfirmWindow(
ba.Lstr(resource='gatherWindow.partyInviteText',
subs=[('${NAME}', name)]),
ba.Call(do_accept, invite_id),
width=500,
height=150,
color=(0.75, 1.0, 0.0),
ok_text=ba.Lstr(resource='gatherWindow.partyInviteAcceptText'),
cancel_text=ba.Lstr(resource='gatherWindow.partyInviteIgnoreText'))
# FIXME: Ugly.
# Let's store the invite-id away on the confirm window so we know if
# we need to kill it later.
conf.party_invite_id = invite_id # type: ignore
# store a weak-ref so we can get at this later
ba.app.invite_confirm_windows.append(weakref.ref(conf))
# go ahead and prune our weak refs while we're here.
ba.app.invite_confirm_windows = [
w for w in ba.app.invite_confirm_windows if w() is not None
]

View file

@ -322,8 +322,8 @@ class AdvancedSettingsWindow(ba.Window):
subs=[('${APP_NAME}', ba.Lstr(resource='titleText'))
]),
autoselect=True,
on_activate_call=ba.Call(ba.open_url,
'http://bombsquadgame.com/translate'))
on_activate_call=ba.Call(
ba.open_url, 'https://legacy.ballistica.net/translate'))
self._lang_status_text = ba.textwidget(parent=self._subcontainer,
position=(self._sub_width * 0.5,

View file

@ -91,21 +91,6 @@ class ControlsSettingsWindow(ba.Window):
else:
show_remote = False
show_ps3 = False
# if platform == 'mac':
# show_ps3 = True
# height += spacing
show360 = False
# if platform == 'mac' or is_fire_tv:
# show360 = True
# height += spacing
show_mac_wiimote = False
# if platform == 'mac' and _ba.is_xcode_build():
# show_mac_wiimote = True
# height += spacing
# On windows (outside of oculus/vr), show an option to disable xinput.
show_xinput_toggle = False
if platform == 'windows' and not app.vr_mode:
@ -152,9 +137,6 @@ class ControlsSettingsWindow(ba.Window):
self._keyboard_button: Optional[ba.Widget] = None
self._keyboard_2_button: Optional[ba.Widget] = None
self._idevices_button: Optional[ba.Widget] = None
self._ps3_button: Optional[ba.Widget] = None
self._xbox_360_button: Optional[ba.Widget] = None
self._wiimotes_button: Optional[ba.Widget] = None
ba.textwidget(parent=self._root_widget,
position=(0, height - 49),
@ -261,42 +243,6 @@ class ControlsSettingsWindow(ba.Window):
down_widget=self._idevices_button)
self._have_selected_child = True
v -= spacing
if show_ps3:
self._ps3_button = btn = ba.buttonwidget(
parent=self._root_widget,
position=((width - button_width) / 2 + 5, v),
size=(button_width, 43),
autoselect=True,
label=ba.Lstr(resource=self._r + '.ps3Text'),
on_activate_call=self._do_ps3_controllers)
if ba.app.ui.use_toolbars:
ba.widget(edit=btn,
right_widget=_ba.get_special_widget('party_button'))
v -= spacing
if show360:
self._xbox_360_button = btn = ba.buttonwidget(
parent=self._root_widget,
position=((width - button_width) / 2 - 1, v),
size=(button_width, 43),
autoselect=True,
label=ba.Lstr(resource=self._r + '.xbox360Text'),
on_activate_call=self._do_360_controllers)
if ba.app.ui.use_toolbars:
ba.widget(edit=btn,
right_widget=_ba.get_special_widget('party_button'))
v -= spacing
if show_mac_wiimote:
self._wiimotes_button = btn = ba.buttonwidget(
parent=self._root_widget,
position=((width - button_width) / 2 + 5, v),
size=(button_width, 43),
autoselect=True,
label=ba.Lstr(resource=self._r + '.wiimotesText'),
on_activate_call=self._do_wiimotes)
if ba.app.ui.use_toolbars:
ba.widget(edit=btn,
right_widget=_ba.get_special_widget('party_button'))
v -= spacing
if show_xinput_toggle:
@ -397,31 +343,6 @@ class ControlsSettingsWindow(ba.Window):
ba.app.ui.set_main_menu_window(
RemoteAppSettingsWindow().get_root_widget())
def _do_ps3_controllers(self) -> None:
# pylint: disable=cyclic-import
from bastd.ui.settings.ps3controller import PS3ControllerSettingsWindow
self._save_state()
ba.containerwidget(edit=self._root_widget, transition='out_left')
ba.app.ui.set_main_menu_window(
PS3ControllerSettingsWindow().get_root_widget())
def _do_360_controllers(self) -> None:
# pylint: disable=cyclic-import
from bastd.ui.settings.xbox360controller import (
XBox360ControllerSettingsWindow)
self._save_state()
ba.containerwidget(edit=self._root_widget, transition='out_left')
ba.app.ui.set_main_menu_window(
XBox360ControllerSettingsWindow().get_root_widget())
def _do_wiimotes(self) -> None:
# pylint: disable=cyclic-import
from bastd.ui.settings.wiimote import WiimoteSettingsWindow
self._save_state()
ba.containerwidget(edit=self._root_widget, transition='out_left')
ba.app.ui.set_main_menu_window(
WiimoteSettingsWindow().get_root_widget())
def _do_gamepads(self) -> None:
# pylint: disable=cyclic-import
from bastd.ui.settings.gamepadselect import GamepadSelectWindow
@ -449,12 +370,6 @@ class ControlsSettingsWindow(ba.Window):
sel_name = 'Keyboard2'
elif sel == self._idevices_button:
sel_name = 'iDevices'
elif sel == self._ps3_button:
sel_name = 'PS3'
elif sel == self._xbox_360_button:
sel_name = 'xbox360'
elif sel == self._wiimotes_button:
sel_name = 'Wiimotes'
else:
sel_name = 'Back'
ba.app.ui.window_states[type(self)] = sel_name
@ -471,12 +386,6 @@ class ControlsSettingsWindow(ba.Window):
sel = self._keyboard_2_button
elif sel_name == 'iDevices':
sel = self._idevices_button
elif sel_name == 'PS3':
sel = self._ps3_button
elif sel_name == 'xbox360':
sel = self._xbox_360_button
elif sel_name == 'Wiimotes':
sel = self._wiimotes_button
elif sel_name == 'Back':
sel = self._back_button
else:

View file

@ -4,11 +4,296 @@
from __future__ import annotations
import time
import copy
import weakref
from threading import Thread
from typing import TYPE_CHECKING
import _ba
import ba
from bastd.ui.settings import testing
from bastd.ui.settings.testing import TestingWindow
if TYPE_CHECKING:
from typing import Callable, Any, Optional
class NetTestingWindow(testing.TestingWindow):
class NetTestingWindow(ba.Window):
"""Window that runs a networking test suite to help diagnose issues."""
def __init__(self, transition: str = 'in_right'):
self._width = 820
self._height = 500
self._printed_lines: list[str] = []
uiscale = ba.app.ui.uiscale
super().__init__(root_widget=ba.containerwidget(
size=(self._width, self._height),
scale=(1.56 if uiscale is ba.UIScale.SMALL else
1.2 if uiscale is ba.UIScale.MEDIUM else 0.8),
stack_offset=(0.0, -7 if uiscale is ba.UIScale.SMALL else 0.0),
transition=transition))
self._done_button = ba.buttonwidget(parent=self._root_widget,
position=(40, self._height - 77),
size=(120, 60),
scale=0.8,
autoselect=True,
label=ba.Lstr(resource='doneText'),
on_activate_call=self._done)
self._copy_button = ba.buttonwidget(parent=self._root_widget,
position=(self._width - 200,
self._height - 77),
size=(100, 60),
scale=0.8,
autoselect=True,
label=ba.Lstr(resource='copyText'),
on_activate_call=self._copy)
self._settings_button = ba.buttonwidget(
parent=self._root_widget,
position=(self._width - 100, self._height - 77),
size=(60, 60),
scale=0.8,
autoselect=True,
label=ba.Lstr(value='...'),
on_activate_call=self._show_val_testing)
twidth = self._width - 450
ba.textwidget(
parent=self._root_widget,
position=(self._width * 0.5, self._height - 55),
size=(0, 0),
text=ba.Lstr(resource='settingsWindowAdvanced.netTestingText'),
color=(0.8, 0.8, 0.8, 1.0),
h_align='center',
v_align='center',
maxwidth=twidth)
self._scroll = ba.scrollwidget(parent=self._root_widget,
position=(50, 50),
size=(self._width - 100,
self._height - 140),
capture_arrows=True,
autoselect=True)
self._rows = ba.columnwidget(parent=self._scroll)
ba.containerwidget(edit=self._root_widget,
cancel_button=self._done_button)
# Now kick off the tests.
# Pass a weak-ref to this window so we don't keep it alive
# if we back out before it completes. Also set is as daemon
# so it doesn't keep the app running if the user is trying to quit.
Thread(
daemon=True,
target=ba.Call(_run_diagnostics, weakref.ref(self)),
).start()
def print(self, text: str, color: tuple[float, float, float]) -> None:
"""Print text to our console thingie."""
for line in text.splitlines():
txt = ba.textwidget(parent=self._rows,
color=color,
text=line,
scale=0.75,
flatness=1.0,
shadow=0.0,
size=(0, 20))
ba.containerwidget(edit=self._rows, visible_child=txt)
self._printed_lines.append(line)
def _copy(self) -> None:
if not ba.clipboard_is_supported():
ba.screenmessage('Clipboard not supported on this platform.',
color=(1, 0, 0))
return
ba.clipboard_set_text('\n'.join(self._printed_lines))
ba.screenmessage(f'{len(self._printed_lines)} lines copied.')
def _show_val_testing(self) -> None:
ba.app.ui.set_main_menu_window(NetValTestingWindow().get_root_widget())
ba.containerwidget(edit=self._root_widget, transition='out_left')
def _done(self) -> None:
# pylint: disable=cyclic-import
from bastd.ui.settings.advanced import AdvancedSettingsWindow
ba.app.ui.set_main_menu_window(
AdvancedSettingsWindow(transition='in_left').get_root_widget())
ba.containerwidget(edit=self._root_widget, transition='out_right')
def _run_diagnostics(weakwin: weakref.ref[NetTestingWindow]) -> None:
# pylint: disable=too-many-statements
from efro.util import utc_now
have_error = [False]
# We're running in a background thread but UI stuff needs to run
# in the logic thread; give ourself a way to pass stuff to it.
def _print(text: str, color: tuple[float, float, float] = None) -> None:
def _print_in_logic_thread() -> None:
win = weakwin()
if win is not None:
win.print(text, (1.0, 1.0, 1.0) if color is None else color)
ba.pushcall(_print_in_logic_thread, from_other_thread=True)
def _print_test_results(call: Callable[[], Any]) -> None:
"""Run the provided call; return success/fail text & color."""
starttime = time.monotonic()
try:
call()
duration = time.monotonic() - starttime
_print(f'Succeeded in {duration:.2f}s.', color=(0, 1, 0))
except Exception:
import traceback
duration = time.monotonic() - starttime
_print(traceback.format_exc(), color=(1.0, 1.0, 0.3))
_print(f'Failed in {duration:.2f}s.', color=(1, 0, 0))
have_error[0] = True
try:
_print(f'Running network diagnostics...\n'
f'ua: {_ba.app.user_agent_string}\n'
f'time: {utc_now()}.')
if bool(False):
_print('\nRunning dummy success test...')
_print_test_results(_dummy_success)
_print('\nRunning dummy fail test...')
_print_test_results(_dummy_fail)
# V1 ping
baseaddr = _ba.get_master_server_address(source=0, version=1)
_print(f'\nContacting V1 master-server src0 ({baseaddr})...')
_print_test_results(lambda: _test_fetch(baseaddr))
# V1 alternate ping
baseaddr = _ba.get_master_server_address(source=1, version=1)
_print(f'\nContacting V1 master-server src1 ({baseaddr})...')
_print_test_results(lambda: _test_fetch(baseaddr))
_print(f'\nV1-test-log: {ba.app.net.v1_test_log}')
for srcid, result in sorted(ba.app.net.v1_ctest_results.items()):
_print(f'\nV1 src{srcid} result: {result}')
curv1addr = _ba.get_master_server_address(version=1)
_print(f'\nUsing V1 address: {curv1addr}')
_print('\nRunning V1 transaction...')
_print_test_results(_test_v1_transaction)
# V2 ping
baseaddr = _ba.get_master_server_address(version=2)
_print(f'\nContacting V2 master-server ({baseaddr})...')
_print_test_results(lambda: _test_fetch(baseaddr))
# Get V2 nearby region
with ba.app.net.region_pings_lock:
region_pings = copy.deepcopy(ba.app.net.region_pings)
nearest_region = (None if not region_pings else sorted(
region_pings.items(), key=lambda i: i[1])[0])
if nearest_region is not None:
nearstr = f'{nearest_region[0]}: {nearest_region[1]:.0f}ms'
else:
nearstr = '-'
_print(f'\nChecking nearest V2 region ping ({nearstr})...')
_print_test_results(lambda: _test_nearby_region_ping(nearest_region))
if have_error[0]:
_print('\nDiagnostics complete. Some diagnostics failed.',
color=(10, 0, 0))
else:
_print('\nDiagnostics complete. Everything looks good!',
color=(0, 1, 0))
except Exception:
import traceback
_print(
f'An unexpected error occurred during testing;'
f' please report this.\n'
f'{traceback.format_exc()}',
color=(1, 0, 0))
def _dummy_success() -> None:
"""Dummy success test."""
time.sleep(1.2)
def _dummy_fail() -> None:
"""Dummy fail test case."""
raise RuntimeError('fail-test')
def _test_v1_transaction() -> None:
"""Dummy fail test case."""
if _ba.get_account_state() != 'signed_in':
raise RuntimeError('Not signed in.')
starttime = time.monotonic()
# Gets set to True on success or string on error.
results: list[Any] = [False]
def _cb(cbresults: Any) -> None:
# Simply set results here; our other thread acts on them.
if not isinstance(cbresults, dict) or 'party_code' not in cbresults:
results[0] = 'Unexpected transaction response'
return
results[0] = True # Success!
def _do_it() -> None:
# Fire off a transaction with a callback.
_ba.add_transaction(
{
'type': 'PRIVATE_PARTY_QUERY',
'expire_time': time.time() + 20,
},
callback=_cb,
)
_ba.run_transactions()
ba.pushcall(_do_it, from_other_thread=True)
while results[0] is False:
time.sleep(0.01)
if time.monotonic() - starttime > 10.0:
raise RuntimeError('timed out')
# If we got left a string, its an error.
if isinstance(results[0], str):
raise RuntimeError(results[0])
def _test_fetch(baseaddr: str) -> None:
# pylint: disable=consider-using-with
import urllib.request
response = urllib.request.urlopen(urllib.request.Request(
f'{baseaddr}/ping', None, {'User-Agent': _ba.app.user_agent_string}),
timeout=10.0)
if response.getcode() != 200:
raise RuntimeError(
f'Got unexpected response code {response.getcode()}.')
data = response.read()
if data != b'pong':
raise RuntimeError('Got unexpected response data.')
def _test_nearby_region_ping(
nearest_region: Optional[tuple[str, float]]) -> None:
"""Try to ping nearest v2 region."""
if nearest_region is None:
raise RuntimeError('No nearest region.')
if nearest_region[1] > 500:
raise RuntimeError('Ping too high.')
class NetValTestingWindow(TestingWindow):
"""Window to test network related settings."""
def __init__(self, transition: str = 'in_right'):
@ -35,6 +320,8 @@ class NetTestingWindow(testing.TestingWindow):
'increment': 1
},
]
testing.TestingWindow.__init__(
self, ba.Lstr(resource='settingsWindowAdvanced.netTestingText'),
entries, transition)
super().__init__(
title=ba.Lstr(resource='settingsWindowAdvanced.netTestingText'),
entries=entries,
transition=transition,
back_call=lambda: NetTestingWindow(transition='in_left'))

View file

@ -11,7 +11,7 @@ import _ba
import ba
if TYPE_CHECKING:
from typing import Any
from typing import Any, Callable, Optional
class TestingWindow(ba.Window):
@ -20,11 +20,13 @@ class TestingWindow(ba.Window):
def __init__(self,
title: ba.Lstr,
entries: list[dict[str, Any]],
transition: str = 'in_right'):
transition: str = 'in_right',
back_call: Optional[Callable[[], ba.Window]] = None):
uiscale = ba.app.ui.uiscale
self._width = 600
self._height = 324 if uiscale is ba.UIScale.SMALL else 400
self._entries = copy.deepcopy(entries)
self._back_call = back_call
super().__init__(root_widget=ba.containerwidget(
size=(self._width, self._height),
transition=transition,
@ -176,8 +178,8 @@ class TestingWindow(ba.Window):
def _do_back(self) -> None:
# pylint: disable=cyclic-import
import bastd.ui.settings.advanced
from bastd.ui.settings.advanced import AdvancedSettingsWindow
ba.containerwidget(edit=self._root_widget, transition='out_right')
ba.app.ui.set_main_menu_window(
bastd.ui.settings.advanced.AdvancedSettingsWindow(
transition='in_left').get_root_widget())
backwin = (self._back_call() if self._back_call is not None else
AdvancedSettingsWindow(transition='in_left'))
ba.app.ui.set_main_menu_window(backwin.get_root_widget())

View file

@ -7,13 +7,13 @@ from __future__ import annotations
from typing import TYPE_CHECKING
import ba
from bastd.ui.settings import testing
from bastd.ui.settings.testing import TestingWindow
if TYPE_CHECKING:
from typing import Any
class VRTestingWindow(testing.TestingWindow):
class VRTestingWindow(TestingWindow):
"""Window for testing vr settings."""
def __init__(self, transition: str = 'in_right'):

View file

@ -266,8 +266,9 @@ if TYPE_CHECKING:
def Call(*_args: Any, **_keywds: Any) -> Any:
...
# (Type-safe Partial)
# A convenient wrapper around functools.partial which adds type-safety
# (though it does not support keyword arguments).
partial = Call
tpartial = Call
else:
partial = functools.partial
tpartial = functools.partial

View file

@ -10,26 +10,45 @@ data formats in a nondestructive manner.
from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING, TypeVar
from efro.dataclassio._outputter import _Outputter
from efro.dataclassio._inputter import _Inputter
from efro.dataclassio._base import Codec, IOAttrs
from efro.dataclassio._prep import ioprep, ioprepped, is_ioprepped_dataclass
from efro.dataclassio._base import Codec, IOAttrs, IOExtendedData
from efro.dataclassio._prep import (ioprep, ioprepped, will_ioprep,
is_ioprepped_dataclass)
from efro.dataclassio._pathcapture import DataclassFieldLookup
if TYPE_CHECKING:
from typing import Any
from typing import Any, Optional
__all__ = [
'Codec', 'IOAttrs', 'ioprep', 'ioprepped', 'is_ioprepped_dataclass',
'DataclassFieldLookup', 'dataclass_to_dict', 'dataclass_to_json',
'dataclass_from_dict', 'dataclass_from_json', 'dataclass_validate'
'Codec', 'IOAttrs', 'IOExtendedData', 'ioprep', 'ioprepped', 'will_ioprep',
'is_ioprepped_dataclass', 'DataclassFieldLookup', 'dataclass_to_dict',
'dataclass_to_json', 'dataclass_from_dict', 'dataclass_from_json',
'dataclass_validate'
]
T = TypeVar('T')
class JsonStyle(Enum):
"""Different style types for json."""
# Single line, no spaces, no sorting. Not deterministic.
# Use this for most storage purposes.
FAST = 'fast'
# Single line, no spaces, sorted keys. Deterministic.
# Use this when output may be hashed or compared for equality.
SORTED = 'sorted'
# Multiple lines, spaces, sorted keys. Deterministic.
# Use this for pretty human readable output.
PRETTY = 'pretty'
def dataclass_to_dict(obj: Any,
codec: Codec = Codec.JSON,
coerce_to_float: bool = True) -> dict:
@ -45,7 +64,7 @@ def dataclass_to_dict(obj: Any,
the ability to do a lossless round-trip with data).
If coerce_to_float is True, integer values present on float typed fields
will be converted to floats in the dict output. If False, a TypeError
will be converted to float in the dict output. If False, a TypeError
will be triggered.
"""
@ -59,18 +78,23 @@ def dataclass_to_dict(obj: Any,
def dataclass_to_json(obj: Any,
coerce_to_float: bool = True,
pretty: bool = False) -> str:
pretty: bool = False,
sort_keys: Optional[bool] = None) -> str:
"""Utility function; return a json string from a dataclass instance.
Basically json.dumps(dataclass_to_dict(...)).
By default, keys are sorted for pretty output and not otherwise, but
this can be overridden by supplying a value for the 'sort_keys' arg.
"""
import json
jdict = dataclass_to_dict(obj=obj,
coerce_to_float=coerce_to_float,
codec=Codec.JSON)
if sort_keys is None:
sort_keys = pretty
if pretty:
return json.dumps(jdict, indent=2, sort_keys=True)
return json.dumps(jdict, separators=(',', ':'))
return json.dumps(jdict, indent=2, sort_keys=sort_keys)
return json.dumps(jdict, separators=(',', ':'), sort_keys=sort_keys)
def dataclass_from_dict(cls: type[T],
@ -94,10 +118,10 @@ def dataclass_from_dict(cls: type[T],
(as this would break the ability to do a lossless round-trip with data).
If coerce_to_float is True, int values passed for float typed fields
will be converted to float values. Otherwise a TypeError is raised.
will be converted to float values. Otherwise, a TypeError is raised.
If allow_unknown_attrs is False, AttributeErrors will be raised for
attributes present in the dict but not on the data class. Otherwise they
attributes present in the dict but not on the data class. Otherwise, they
will be preserved as part of the instance and included if it is
exported back to a dict, unless discard_unknown_attrs is True, in which
case they will simply be discarded.

View file

@ -12,15 +12,6 @@ from typing import TYPE_CHECKING, get_args
# noinspection PyProtectedMember
from typing import _AnnotatedAlias # type: ignore
_pytz_utc: Any
# We don't *require* pytz but we want to support it for tzinfos if available.
try:
import pytz
_pytz_utc = pytz.utc
except ModuleNotFoundError:
_pytz_utc = None # pylint: disable=invalid-name
if TYPE_CHECKING:
from typing import Any, Optional
@ -32,14 +23,6 @@ SIMPLE_TYPES = {int, bool, str, float, type(None)}
EXTRA_ATTRS_ATTR = '_DCIOEXATTRS'
def _ensure_datetime_is_timezone_aware(value: datetime.datetime) -> None:
# We only support timezone-aware utc times.
if (value.tzinfo is not datetime.timezone.utc
and (_pytz_utc is None or value.tzinfo is not _pytz_utc)):
raise ValueError(
'datetime values must have timezone set as timezone.utc')
def _raise_type_error(fieldpath: str, valuetype: type,
expected: tuple[type, ...]) -> None:
"""Raise an error when a field value's type does not match expected."""
@ -67,6 +50,24 @@ class Codec(Enum):
FIRESTORE = 'firestore'
class IOExtendedData:
"""A class that data types can inherit from for extra functionality."""
def will_output(self) -> None:
"""Called before data is sent to an outputter.
Can be overridden to validate or filter data before
sending it on its way.
"""
@classmethod
def will_input(cls, data: dict) -> None:
"""Called on raw data before a class instance is created from it.
Can be overridden to migrate old data formats to new, etc.
"""
def _is_valid_for_codec(obj: Any, codec: Codec) -> bool:
"""Return whether a value consists solely of json-supported types.
@ -126,7 +127,7 @@ class IOAttrs:
# Turning off store_default requires the field to have either
# a default_factory or a default
if not self.store_default:
default_factory: Any = field.default_factory # type: ignore
default_factory: Any = field.default_factory
if (default_factory is dataclasses.MISSING
and field.default is dataclasses.MISSING):
raise TypeError(f'Field {field.name} of {cls} has'
@ -163,7 +164,7 @@ def _get_origin(anntype: Any) -> Any:
def _parse_annotated(anntype: Any) -> tuple[Any, Optional[IOAttrs]]:
"""Parse Annotated() constructs, returning annotated type & IOAttrs."""
# If we get an Annotated[foo, bar, eep] we take
# foo as the actual type and we look for IOAttrs instances in
# foo as the actual type, and we look for IOAttrs instances in
# bar/eep to affect our behavior.
ioattrs: Optional[IOAttrs] = None
if isinstance(anntype, _AnnotatedAlias):

View file

@ -14,11 +14,11 @@ import typing
import datetime
from typing import TYPE_CHECKING, Generic, TypeVar
from efro.util import enum_by_value
from efro.util import enum_by_value, check_utc
from efro.dataclassio._base import (Codec, _parse_annotated, EXTRA_ATTRS_ATTR,
_is_valid_for_codec, _get_origin,
SIMPLE_TYPES, _raise_type_error,
_ensure_datetime_is_timezone_aware)
IOExtendedData)
from efro.dataclassio._prep import PrepSession
if TYPE_CHECKING:
@ -48,6 +48,12 @@ class _Inputter(Generic[T]):
def run(self, values: dict) -> T:
"""Do the thing."""
# For special extended data types, call their 'will_output' callback.
tcls = self._cls
if issubclass(tcls, IOExtendedData):
tcls.will_input(values)
out = self._dataclass_from_input(self._cls, '', values)
assert isinstance(out, self._cls)
return out
@ -70,14 +76,14 @@ class _Inputter(Generic[T]):
return value
if origin is typing.Union:
# Currently the only unions we support are None/Value
# Currently, the only unions we support are None/Value
# (translated from Optional), which we verified on prep.
# So let's treat this as a simple optional case.
if value is None:
return None
childanntypes_l = [
c for c in typing.get_args(anntype) if c is not type(None)
]
] # noqa (pycodestyle complains about *is* with type)
assert len(childanntypes_l) == 1
return self._value_from_input(cls, fieldpath, childanntypes_l[0],
value, ioattrs)
@ -127,7 +133,7 @@ class _Inputter(Generic[T]):
"""Given input data, returns bytes."""
import base64
# For firestore, bytes are passed as-is. Otherwise they're encoded
# For firestore, bytes are passed as-is. Otherwise, they're encoded
# as base64.
if self._codec is Codec.FIRESTORE:
if not isinstance(value, bytes):
@ -159,6 +165,7 @@ class _Inputter(Generic[T]):
prep = PrepSession(explicit=False).prep_dataclass(cls,
recursion_level=0)
assert prep is not None
extra_attrs = {}
@ -268,7 +275,7 @@ class _Inputter(Generic[T]):
cls, fieldpath, valanntype, val, ioattrs)
elif issubclass(keyanntype, Enum):
# In prep we verified that all these enums' values have
# In prep, we verified that all these enums' values have
# the same type, so we can just look at the first to see if
# this is a string enum or an int enum.
enumvaltype = type(next(iter(keyanntype)).value)
@ -344,7 +351,7 @@ class _Inputter(Generic[T]):
f'Invalid input value for "{fieldpath}" on'
f' "{cls.__name__}";'
f' expected a datetime, got a {type(value).__name__}')
_ensure_datetime_is_timezone_aware(value)
check_utc(value)
return value
assert self._codec is Codec.JSON
@ -355,9 +362,9 @@ class _Inputter(Generic[T]):
f'Invalid input value for "{fieldpath}" on "{cls.__name__}";'
f' expected a list, got a {type(value).__name__}')
if len(value) != 7 or not all(isinstance(x, int) for x in value):
raise TypeError(
raise ValueError(
f'Invalid input value for "{fieldpath}" on "{cls.__name__}";'
f' expected a list of 7 ints.')
f' expected a list of 7 ints, got {[type(v) for v in value]}.')
out = datetime.datetime( # type: ignore
*value, tzinfo=datetime.timezone.utc)
if ioattrs is not None:
@ -380,7 +387,7 @@ class _Inputter(Generic[T]):
assert childanntypes
if len(value) != len(childanntypes):
raise TypeError(f'Invalid tuple input for "{fieldpath}";'
raise ValueError(f'Invalid tuple input for "{fieldpath}";'
f' expected {len(childanntypes)} values,'
f' found {len(value)}.')

View file

@ -14,10 +14,11 @@ import typing
import datetime
from typing import TYPE_CHECKING
from efro.util import check_utc
from efro.dataclassio._base import (Codec, _parse_annotated, EXTRA_ATTRS_ATTR,
_is_valid_for_codec, _get_origin,
SIMPLE_TYPES, _raise_type_error,
_ensure_datetime_is_timezone_aware)
IOExtendedData)
from efro.dataclassio._prep import PrepSession
if TYPE_CHECKING:
@ -37,6 +38,11 @@ class _Outputter:
def run(self) -> Any:
"""Do the thing."""
# For special extended data types, call their 'will_output' callback.
if isinstance(self._obj, IOExtendedData):
self._obj.will_output()
return self._process_dataclass(type(self._obj), self._obj, '')
def _process_dataclass(self, cls: type, obj: Any, fieldpath: str) -> Any:
@ -44,6 +50,7 @@ class _Outputter:
# pylint: disable=too-many-branches
prep = PrepSession(explicit=False).prep_dataclass(type(obj),
recursion_level=0)
assert prep is not None
fields = dataclasses.fields(obj)
out: Optional[dict[str, Any]] = {} if self._create else None
for field in fields:
@ -60,7 +67,7 @@ class _Outputter:
# If we're not storing default values for this fella,
# we can skip all output processing if we've got a default value.
if ioattrs is not None and not ioattrs.store_default:
default_factory: Any = field.default_factory # type: ignore
default_factory: Any = field.default_factory
if default_factory is not dataclasses.MISSING:
if default_factory() == value:
continue
@ -113,14 +120,14 @@ class _Outputter:
return value if self._create else None
if origin is typing.Union:
# Currently the only unions we support are None/Value
# Currently, the only unions we support are None/Value
# (translated from Optional), which we verified on prep.
# So let's treat this as a simple optional case.
if value is None:
return None
childanntypes_l = [
c for c in typing.get_args(anntype) if c is not type(None)
]
] # noqa (pycodestyle complains about *is* with type)
assert len(childanntypes_l) == 1
return self._process_value(cls, fieldpath, childanntypes_l[0],
value, ioattrs)
@ -242,7 +249,7 @@ class _Outputter:
if not isinstance(value, origin):
raise TypeError(f'Expected a {origin} for {fieldpath};'
f' found a {type(value)}.')
_ensure_datetime_is_timezone_aware(value)
check_utc(value)
if ioattrs is not None:
ioattrs.validate_datetime(value, fieldpath)
if self._codec is Codec.FIRESTORE:

View file

@ -36,6 +36,7 @@ class _PathCapture:
prep = PrepSession(explicit=False).prep_dataclass(self._cls,
recursion_level=0)
assert prep is not None
try:
anntype = prep.annotations[name]
except KeyError as exc:
@ -75,7 +76,7 @@ class DataclassFieldLookup(Generic[T]):
# We tell the type system that we are returning an instance
# of our class, which allows it to perform type checking on
# member lookups. In reality, however, we are providing a
# special object which captures path lookups so we can build
# special object which captures path lookups, so we can build
# a string from them.
if not TYPE_CHECKING:
out = callback(_PathCapture(self.cls))

View file

@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, TypeVar, get_type_hints
from efro.dataclassio._base import _parse_annotated, _get_origin, SIMPLE_TYPES
if TYPE_CHECKING:
from typing import Any
from typing import Any, Optional
T = TypeVar('T')
@ -27,11 +27,15 @@ T = TypeVar('T')
# (basically for detecting recursive types)
MAX_RECURSION = 10
# Attr name for data we store on dataclass types as part of prep.
# Attr name for data we store on dataclass types that have been prepped.
PREP_ATTR = '_DCIOPREP'
# We also store the prep-session while the prep is in progress.
# (necessary to support recursive types).
PREP_SESSION_ATTR = '_DCIOPREPSESSION'
def ioprep(cls: type) -> None:
def ioprep(cls: type, globalns: dict = None) -> None:
"""Prep a dataclass type for use with this module's functionality.
Prepping ensures that all types contained in a data class as well as
@ -45,10 +49,14 @@ def ioprep(cls: type) -> None:
Prepping a dataclass involves evaluating its type annotations, which,
as of PEP 563, are stored simply as strings. This evaluation is done
in the module namespace containing the class, so all referenced types
must be defined at that level.
with localns set to the class dict (so that types defined in the class
can be used) and globalns set to the containing module's class.
It is possible to override globalns for special cases such as when
prepping happens as part of an execed string instead of within a
module.
"""
PrepSession(explicit=True).prep_dataclass(cls, recursion_level=0)
PrepSession(explicit=True,
globalns=globalns).prep_dataclass(cls, recursion_level=0)
def ioprepped(cls: type[T]) -> type[T]:
@ -64,6 +72,23 @@ def ioprepped(cls: type[T]) -> type[T]:
return cls
def will_ioprep(cls: type[T]) -> type[T]:
"""Class decorator hinting that we will prep a class later.
In some cases (such as recursive types) we cannot use the @ioprepped
decorator and must instead call ioprep() explicitly later. However,
some of our custom pylint checking behaves differently when the
@ioprepped decorator is present, in that case requiring type annotations
to be present and not simply forward declared under an "if TYPE_CHECKING"
block. (since they are used at runtime).
The @will_ioprep decorator triggers the same pylint behavior
differences as @ioprepped (which are necessary for the later ioprep() call
to work correctly) but without actually running any prep itself.
"""
return cls
def is_ioprepped_dataclass(obj: Any) -> bool:
"""Return whether the obj is an ioprepped dataclass type or instance."""
cls = obj if isinstance(obj, type) else type(obj)
@ -87,11 +112,19 @@ class PrepData:
class PrepSession:
"""Context for a prep."""
def __init__(self, explicit: bool):
def __init__(self, explicit: bool, globalns: Optional[dict] = None):
self.explicit = explicit
self.globalns = globalns
def prep_dataclass(self, cls: type, recursion_level: int) -> PrepData:
"""Run prep on a dataclass if necessary and return its prep data."""
def prep_dataclass(self, cls: type,
recursion_level: int) -> Optional[PrepData]:
"""Run prep on a dataclass if necessary and return its prep data.
The only case where this will return None is for recursive types
if the type is already being prepped higher in the call order.
"""
# pylint: disable=too-many-locals
# pylint: disable=too-many-branches
# We should only need to do this once per dataclass.
existing_data = getattr(cls, PREP_ATTR, None)
@ -99,8 +132,9 @@ class PrepSession:
assert isinstance(existing_data, PrepData)
return existing_data
# If we run into classes containing themselves, we may have
# to do something smarter to handle it.
# Sanity check.
# Note that we now support recursive types via the PREP_SESSION_ATTR,
# so we theoretically shouldn't run into this this.
if recursion_level > MAX_RECURSION:
raise RuntimeError('Max recursion exceeded.')
@ -108,6 +142,18 @@ class PrepSession:
if not isinstance(cls, type) or not dataclasses.is_dataclass(cls):
raise TypeError(f'Passed arg {cls} is not a dataclass type.')
# Add a pointer to the prep-session while doing the prep.
# This way we can ignore types that we're already in the process
# of prepping and can support recursive types.
existing_prep = getattr(cls, PREP_SESSION_ATTR, None)
if existing_prep is not None:
if existing_prep is self:
return None
# We shouldn't need to support failed preps
# or preps from multiple threads at once.
raise RuntimeError('Found existing in-progress prep.')
setattr(cls, PREP_SESSION_ATTR, self)
# Generate a warning on non-explicit preps; we prefer prep to
# happen explicitly at runtime so errors can be detected early on.
if not self.explicit:
@ -123,10 +169,10 @@ class PrepSession:
# which allows us to pick up nested classes, etc.
resolved_annotations = get_type_hints(cls,
localns=vars(cls),
globalns=self.globalns,
include_extras=True)
# pylint: enable=unexpected-keyword-arg
except Exception as exc:
print('GOT', cls.__dict__)
raise TypeError(
f'dataclassio prep for {cls} failed with error: {exc}.'
f' Make sure all types used in annotations are defined'
@ -175,6 +221,10 @@ class PrepSession:
annotations=resolved_annotations,
storage_names_to_attr_names=storage_names_to_attr_names)
setattr(cls, PREP_ATTR, prepdata)
# Clear our prep-session tag.
assert getattr(cls, PREP_SESSION_ATTR, None) is self
delattr(cls, PREP_SESSION_ATTR)
return prepdata
def prep_type(self, cls: type, attrname: str, anntype: Any,
@ -303,7 +353,7 @@ class PrepSession:
"""Run prep on a Union type."""
typeargs = typing.get_args(anntype)
if (len(typeargs) != 2
or len([c for c in typeargs if c is type(None)]) != 1):
or len([c for c in typeargs if c is type(None)]) != 1): # noqa
raise TypeError(f'Union {anntype} for attr \'{attrname}\' on'
f' {cls.__name__} is not supported by dataclassio;'
f' only 2 member Unions with one type being None'

View file

@ -84,6 +84,12 @@ def is_urllib_network_error(exc: BaseException) -> bool:
exc,
(urllib.error.URLError, ConnectionError, http.client.IncompleteRead,
http.client.BadStatusLine, socket.timeout)):
# Special case: although an HTTPError is a subclass of URLError,
# we don't return True for it. It means we have successfully
# communicated with the server but what we are asking for is
# not there/etc.
if isinstance(exc, urllib.error.HTTPError):
return False
return True
if isinstance(exc, OSError):
if exc.errno == 10051: # Windows unreachable network error.

View file

@ -766,7 +766,7 @@ class MessageReceiver:
# Return type of None translates to EmptyResponse.
responsetypes = tuple(EmptyResponse if r is type(None) else r
for r in responsetypes)
for r in responsetypes) # noqa
# Make sure our protocol has this message type registered and our
# return types exactly match. (Technically we could return a subset

View file

@ -75,6 +75,10 @@ def _default_color_enabled() -> bool:
if not sys.__stdout__.isatty():
return False
# Another common way to say the terminal can't do fancy stuff like color:
if os.environ.get('TERM') == 'dumb':
return False
# On windows, try to enable ANSI color mode.
if platform.system() == 'Windows':
return _windows_enable_color()

View file

@ -11,10 +11,19 @@ import functools
from enum import Enum
from typing import TYPE_CHECKING, cast, TypeVar, Generic
_pytz_utc: Any
# We don't *require* pytz, but we want to support it for tzinfos if available.
try:
import pytz
_pytz_utc = pytz.utc
except ModuleNotFoundError:
_pytz_utc = None # pylint: disable=invalid-name
if TYPE_CHECKING:
import asyncio
from efro.call import Call as Call # 'as Call' so we re-export.
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, NoReturn
T = TypeVar('T')
TVAL = TypeVar('TVAL')
@ -62,6 +71,14 @@ def enum_by_value(cls: type[TENUM], value: Any) -> TENUM:
(value, cls.__name__)) from None
def check_utc(value: datetime.datetime) -> None:
"""Ensure a datetime value is timezone-aware utc."""
if (value.tzinfo is not datetime.timezone.utc
and (_pytz_utc is None or value.tzinfo is not _pytz_utc)):
raise ValueError('datetime value does not have timezone set as'
' datetime.timezone.utc')
def utc_now() -> datetime.datetime:
"""Get offset-aware current utc time.
@ -240,7 +257,8 @@ class DispatchMethodWrapper(Generic[TARG, TRET]):
pass
@staticmethod
def register(func: Callable[[Any, Any], TRET]) -> Callable:
def register(
func: Callable[[Any, Any], TRET]) -> Callable[[Any, Any], TRET]:
"""Register a new dispatch handler for this dispatch-method."""
registry: dict[Any, Callable]
@ -312,12 +330,16 @@ class ValueDispatcher(Generic[TVAL, TRET]):
return handler()
return self._base_call(value)
def _add_handler(self, value: TVAL, call: Callable[[], TRET]) -> None:
def _add_handler(self, value: TVAL,
call: Callable[[], TRET]) -> Callable[[], TRET]:
if value in self._handlers:
raise RuntimeError(f'Duplicate handlers added for {value}')
self._handlers[value] = call
return call
def register(self, value: TVAL) -> Callable[[Callable[[], TRET]], None]:
def register(
self,
value: TVAL) -> Callable[[Callable[[], TRET]], Callable[[], TRET]]:
"""Add a handler to the dispatcher."""
from functools import partial
return partial(self._add_handler, value)
@ -343,13 +365,16 @@ class ValueDispatcher1Arg(Generic[TVAL, TARG, TRET]):
return handler(arg)
return self._base_call(value, arg)
def _add_handler(self, value: TVAL, call: Callable[[TARG], TRET]) -> None:
def _add_handler(self, value: TVAL,
call: Callable[[TARG], TRET]) -> Callable[[TARG], TRET]:
if value in self._handlers:
raise RuntimeError(f'Duplicate handlers added for {value}')
self._handlers[value] = call
return call
def register(self,
value: TVAL) -> Callable[[Callable[[TARG], TRET]], None]:
def register(
self, value: TVAL
) -> Callable[[Callable[[TARG], TRET]], Callable[[TARG], TRET]]:
"""Add a handler to the dispatcher."""
from functools import partial
return partial(self._add_handler, value)
@ -363,8 +388,9 @@ if TYPE_CHECKING:
def __call__(self, value: TVAL) -> TRET:
...
def register(self,
value: TVAL) -> Callable[[Callable[[TSELF], TRET]], None]:
def register(
self, value: TVAL
) -> Callable[[Callable[[TSELF], TRET]], Callable[[TSELF], TRET]]:
"""Add a handler to the dispatcher."""
...
@ -579,6 +605,8 @@ def human_readable_compact_id(num: int) -> str:
'o' is excluded due to similarity to '0'.
'z' is excluded due to similarity to '2'.
Therefore for n chars this can store values of 21^n.
When reading human input consisting of these IDs, it may be desirable
to map the disallowed chars to their corresponding allowed ones
('o' -> '0', etc).
@ -599,8 +627,39 @@ def compact_id(num: int) -> str:
friendly to humans due to using both capital and lowercase letters,
both 'O' and '0', etc.
Therefore for n chars this can store values of 62^n.
Sort order for these ids is the same as the original numbers.
"""
return _compact_id(
num, '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ'
'abcdefghijklmnopqrstuvwxyz')
def assert_never(value: NoReturn) -> NoReturn:
"""Trick for checking exhaustive handling of Enums, etc.
See https://github.com/python/typing/issues/735
"""
assert False, f'Unhandled value: {value} ({type(value).__name__})'
def unchanging_hostname() -> str:
"""Return an unchanging name for the local device.
Similar to the `hostname` call (or os.uname().nodename in Python)
except attempts to give a name that doesn't change depending on
network conditions. (A Mac will tend to go from Foo to Foo.local,
Foo.lan etc. throughout its various adventures)
"""
import os
import platform
import subprocess
# On Mac, this should give the computer name assigned in System Prefs.
if platform.system() == 'Darwin':
return subprocess.run(
['scutil', '--get', 'ComputerName'],
check=True,
capture_output=True).stdout.decode().strip().replace(' ', '-')
return os.uname().nodename

Binary file not shown.