bug fix, update

Ayush Saini 2022-10-05 03:11:34 +05:30
parent 3e039cf40f
commit da4d57b0b1
15 changed files with 523 additions and 122 deletions

View file

@@ -45,7 +45,7 @@ def bootstrap() -> None:
# Give a soft warning if we're being used with a different binary
# version than we expect.
expected_build = 20882
expected_build = 20887
running_build: int = env['build_number']
if running_build != expected_build:
print(

dist/ba_data/python/efro/debug.py (vendored, new file, 241 lines added)
View file

@@ -0,0 +1,241 @@
# Released under the MIT License. See LICENSE for details.
#
"""Utilities for debugging memory leaks or other issues."""
from __future__ import annotations
import gc
import sys
import types
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Any, TextIO
ABS_MAX_LEVEL = 10
# NOTE: In general we want this toolset to allow us to explore
# which objects are holding references to others so we can diagnose
# leaks/etc. It is a bit tricky to do that, however, without
# affecting the objects we are looking at by adding temporary references
# from module dicts, function scopes, etc. So we need to try to be
# careful about cleaning up after ourselves and explicitly avoiding
# returning these temporary references wherever possible.
# A good test is running printrefs() repeatedly on some object that is
# known to be static. If the list of references or the ids of any of
# the listed references changes with each run, it's a good sign that
# we're showing some temporary objects that we should be ignoring.
def getobjs(cls: type | str, contains: str | None = None) -> list[Any]:
"""Return all garbage-collected objects matching criteria.
'type' can be an actual type or a string in which case objects
whose types contain that string will be returned.
If 'contains' is provided, objects will be filtered to those
containing that in their str() representations.
"""
# Don't wanna return stuff waiting to be garbage-collected.
gc.collect()
if not isinstance(cls, type | str):
raise TypeError('Expected a type or string for cls')
if not isinstance(contains, str | None):
raise TypeError('Expected a string or None for contains')
if isinstance(cls, str):
objs = [o for o in gc.get_objects() if cls in str(type(o))]
else:
objs = [o for o in gc.get_objects() if isinstance(o, cls)]
if contains is not None:
objs = [o for o in objs if contains in str(o)]
return objs
def getobj(objid: int) -> Any:
"""Return a garbage-collected object by its id.
Remember that this is VERY inefficient and should only ever be used
for debugging.
"""
if not isinstance(objid, int):
raise TypeError(f'Expected an int for objid; got a {type(objid)}.')
# Don't wanna return stuff waiting to be garbage-collected.
gc.collect()
for obj in gc.get_objects():
if id(obj) == objid:
return obj
raise RuntimeError(f'Object with id {objid} not found.')
def getrefs(obj: Any) -> list[Any]:
"""Given an object, return things referencing it."""
v = vars() # Ignore ref coming from locals.
return [o for o in gc.get_referrers(obj) if o is not v]
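As a rough usage sketch showing how these helpers compose (not part of the diff; the 'config' filter below is purely illustrative):

from efro.debug import getobjs, getobj, getrefs

# Grab all live dicts whose str() mentions 'config'.
candidates = getobjs(dict, contains='config')
print(f'{len(candidates)} candidate objects found')

if candidates:
    # Stash an id now and look the object back up later by that id
    # (VERY inefficient, so debugging only).
    objid = id(candidates[0])
    obj = getobj(objid)

    # See what is currently keeping it alive.
    for ref in getrefs(obj):
        print(type(ref))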
def printfiles(file: TextIO | None = None) -> None:
"""Print info about open files in the current app."""
import io
file = sys.stderr if file is None else file
try:
import psutil
except ImportError:
print(
"Error: printfiles requires the 'psutil' module to be installed.",
file=file)
return
proc = psutil.Process()
# Let's grab all Python file handles so we can associate raw files
# with their Python objects when possible.
fileio_ids = {obj.fileno(): obj for obj in getobjs(io.FileIO)}
textio_ids = {obj.fileno(): obj for obj in getobjs(io.TextIOWrapper)}
# FIXME: we could do a more limited version of this when psutil is
# not present that simply includes Python's files.
print('Files open by this app (not limited to Python\'s):', file=file)
for i, ofile in enumerate(proc.open_files()):
# Mypy doesn't know about mode apparently.
# (and can't use type: ignore because we don't require psutil
# and then mypy complains about an unused ignore comment when it's
# not present)
mode = getattr(ofile, 'mode')
assert isinstance(mode, str)
textio = textio_ids.get(ofile.fd)
textio_s = id(textio) if textio is not None else '<not found>'
fileio = fileio_ids.get(ofile.fd)
fileio_s = id(fileio) if fileio is not None else '<not found>'
print(f'#{i+1}: path={ofile.path!r},'
f' fd={ofile.fd}, mode={mode!r}, TextIOWrapper={textio_s},'
f' FileIO={fileio_s}', file=file)
def printrefs(obj: Any,
max_level: int = 2,
exclude_objs: list[Any] | None = None,
expand_ids: list[int] | None = None,
file: TextIO | None = None) -> None:
"""Print human readable list of objects referring to an object.
'max_level' specifies how many levels of recursion are printed.
'exclude_objs' can be a list of exact objects to skip if found in the
referrers list. This can be useful to avoid printing the local context
where the object was passed in from (locals(), etc).
'expand_ids' can be a list of object ids; if that particular object is
found, it will always be expanded even if max_level has been reached.
"""
_printrefs(obj,
level=0,
max_level=max_level,
exclude_objs=[] if exclude_objs is None else exclude_objs,
expand_ids=[] if expand_ids is None else expand_ids,
file=sys.stderr if file is None else file)
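And a minimal leak-hunting sketch built on printrefs (MyWidget is a hypothetical suspected-leaked type; per the NOTE at the top of this module, re-running the call on a known-static object helps spot transient references):

from efro.debug import getobjs, printrefs

target = getobjs('MyWidget')[0]  # hypothetical suspect object
printrefs(
    target,
    max_level=3,
    # Keep this calling frame's locals() out of the referrer listing.
    exclude_objs=[locals()],
    # Always expand this particular dict, even past max_level.
    expand_ids=[id(vars(target))],
)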
def printtypes(limit: int = 50, file: TextIO | None = None) -> None:
"""Print a human readable list of which types have the most instances."""
assert limit > 0
objtypes: dict[str, int] = {}
gc.collect() # Recommended before get_objects().
allobjs = gc.get_objects()
allobjc = len(allobjs)
for obj in allobjs:
modname = type(obj).__module__
tpname = type(obj).__qualname__
if modname != 'builtins':
tpname = f'{modname}.{tpname}'
objtypes[tpname] = objtypes.get(tpname, 0) + 1
# Presumably allobjs contains stack-frame/dict type stuff
# from this function call which in turn contain refs to allobjs.
# Let's try to prevent these huge lists from accumulating until
# the cyclical collector (hopefully) gets to them.
allobjs.clear()
del allobjs
print(f'Types most allocated ({allobjc} total objects):', file=file)
for i, tpitem in enumerate(
sorted(objtypes.items(), key=lambda x: x[1],
reverse=True)[:limit]):
tpname, tpval = tpitem
percent = tpval / allobjc * 100.0
print(f'{i+1}: {tpname}: {tpval} ({percent:.2f}%)', file=file)
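Output follows the format strings above; a call like the one below would print something along these lines (counts and percentages are invented for illustration):

from efro.debug import printtypes

printtypes(limit=3)
# Types most allocated (68234 total objects):
# 1: dict: 18211 (26.69%)
# 2: tuple: 14005 (20.52%)
# 3: function: 9876 (14.47%)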
def _desctype(obj: Any) -> str:
cls = type(obj)
if cls is types.ModuleType:
return f'{type(obj).__name__} {obj.__name__}'
if cls is types.MethodType:
bnd = 'bound' if hasattr(obj, '__self__') else 'unbound'
return f'{bnd} {type(obj).__name__} {obj.__name__}'
return f'{type(obj).__name__}'
def _desc(obj: Any) -> str:
extra: str | None = None
if isinstance(obj, list | tuple):
# Print length and the first few types.
tps = [_desctype(i) for i in obj[:3]]
tpsj = ', '.join(tps)
tpss = (f', contains [{tpsj}, ...]'
if len(obj) > 3 else f', contains [{tpsj}]' if tps else '')
extra = f' (len {len(obj)}{tpss})'
elif isinstance(obj, dict):
# If it seems to be the vars() for a type or module,
# try to identify what.
for ref in getrefs(obj):
if hasattr(ref, '__dict__') and vars(ref) is obj:
extra = f' (vars for {_desctype(ref)} @ {id(ref)})'
# Generic dict: print length and the first few key:type pairs.
if extra is None:
pairs = [
f'{repr(n)}: {_desctype(v)}' for n, v in list(obj.items())[:3]
]
pairsj = ', '.join(pairs)
pairss = (f', contains {{{pairsj}, ...}}' if len(obj) > 3 else
f', contains {{{pairsj}}}' if pairs else '')
extra = f' (len {len(obj)}{pairss})'
if extra is None:
extra = ''
return f'{_desctype(obj)} @ {id(obj)}{extra}'
def _printrefs(obj: Any, level: int, max_level: int, exclude_objs: list,
expand_ids: list[int], file: TextIO) -> None:
ind = ' ' * level
print(ind + _desc(obj), file=file)
v = vars()
if level < max_level or (id(obj) in expand_ids and level < ABS_MAX_LEVEL):
refs = getrefs(obj)
for ref in refs:
# It seems we tend to get a transient cell object with contents
# set to obj. Would be nice to understand why that happens
# but just ignoring it for now.
if isinstance(ref, types.CellType) and ref.cell_contents is obj:
continue
# Ignore anything we were asked to ignore.
if exclude_objs is not None:
if any(ref is eobj for eobj in exclude_objs):
continue
# Ignore references from our locals.
if ref is v:
continue
# The 'refs' list we just made will be listed as a referrer
# of this obj, so explicitly exclude it from the obj's listing.
_printrefs(ref,
level=level + 1,
max_level=max_level,
exclude_objs=exclude_objs + [refs],
expand_ids=expand_ids,
file=file)

View file

@@ -62,9 +62,17 @@ class RemoteError(Exception):
as a catch-all.
"""
def __init__(self, msg: str, peer_desc: str):
super().__init__(msg)
self._peer_desc = peer_desc
def __str__(self) -> str:
s = ''.join(str(arg) for arg in self.args)
return f'Remote Exception Follows:\n{s}'
# Indent so we can more easily tell what is the remote part when
# this is in the middle of a long exception chain.
padding = ' '
s = ''.join(padding + line for line in s.splitlines(keepends=True))
return f'The following occurred on {self._peer_desc}:\n{s}'
class IntegrityError(ValueError):
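To illustrate the effect of the new __str__ formatting, a small sketch (the module path, error text, and peer description are assumed for illustration; the indent width follows the padding constant above):

from efro.error import RemoteError  # module path assumed

err = RemoteError('ValueError: bad data\n  in handler()',
                  peer_desc='the game server')
print(err)
# The following occurred on the game server:
#   ValueError: bad data
#     in handler()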

View file

@@ -48,6 +48,7 @@ class MessageSender:
None] | None = None
self._decode_filter_call: Callable[
[Any, Message, dict, Response | SysResponse], None] | None = None
self._peer_desc_call: Callable[[Any], str] | None = None
def send_method(
self, call: Callable[[Any, str],
@@ -102,9 +103,20 @@ class MessageSender:
self._decode_filter_call = call
return call
def peer_desc_method(self, call: Callable[[Any],
str]) -> Callable[[Any], str]:
"""Function decorator for defining peer descriptions.
These are included in error messages or other diagnostics.
"""
assert self._peer_desc_call is None
self._peer_desc_call = call
return call
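A structural sketch of where the new decorator slots in alongside the existing send_method decorator (the class, sender instance, and method names are hypothetical, and the MessageSender/protocol construction is omitted, so this is illustrative rather than directly runnable):

class MyConnection:
    msg = MyMessageSender()  # hypothetical MessageSender instance

    @msg.send_method
    def _send_raw_message(self, data: str) -> str:
        ...  # ship the encoded message to the peer; return its encoded reply

    @msg.peer_desc_method
    def _describe_peer(self) -> str:
        # Shows up in RemoteError output as
        # 'The following occurred on <this string>: ...'
        return 'game server v1.2.3'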
def send(self, bound_obj: Any, message: Message) -> Response | None:
"""Send a message synchronously."""
return self.send_split_part_2(
bound_obj=bound_obj,
message=message,
raw_response=self.send_split_part_1(
bound_obj=bound_obj,
@@ -116,6 +128,7 @@ class MessageSender:
message: Message) -> Response | None:
"""Send a message asynchronously."""
return self.send_split_part_2(
bound_obj=bound_obj,
message=message,
raw_response=await self.send_split_part_1_async(
bound_obj=bound_obj,
@@ -178,7 +191,7 @@ class MessageSender:
return self._decode_raw_response(bound_obj, message, response_encoded)
def send_split_part_2(
self, message: Message,
self, bound_obj: Any, message: Message,
raw_response: Response | SysResponse) -> Response | None:
"""Complete message sending (both sync and async).
@@ -186,7 +199,7 @@ class MessageSender:
for when message sending and response handling need to happen
in different contexts/threads.
"""
response = self._unpack_raw_response(raw_response)
response = self._unpack_raw_response(bound_obj, raw_response)
assert (response is None
or type(response) in type(message).get_response_types())
return response
@@ -228,7 +241,8 @@ class MessageSender:
return response
def _unpack_raw_response(
self, raw_response: Response | SysResponse) -> Response | None:
self, bound_obj: Any,
raw_response: Response | SysResponse) -> Response | None:
"""Given a raw Response, unpacks to special values or Exceptions.
The result of this call is what should be passed to users.
@@ -259,7 +273,9 @@ class MessageSender:
raise CleanError(raw_response.error_message)
# Everything else gets lumped in as a remote error.
raise RemoteError(raw_response.error_message)
raise RemoteError(raw_response.error_message,
peer_desc=('peer' if self._peer_desc_call is None
else self._peer_desc_call(bound_obj)))
assert isinstance(raw_response, Response)
return raw_response
@@ -309,5 +325,6 @@ class BoundMessageSender:
self, message: Message,
raw_response: Response | SysResponse) -> Response | None:
"""Split send (part 2 of 2)."""
return self._sender.send_split_part_2(message=message,
return self._sender.send_split_part_2(bound_obj=self._obj,
message=message,
raw_response=raw_response)

View file

@@ -61,6 +61,59 @@ class _PeerInfo:
OUR_PROTOCOL = 2
def ssl_stream_writer_underlying_transport_info(
writer: asyncio.StreamWriter) -> str:
"""For debugging SSL Stream connections; returns raw transport info."""
# Note: accessing internals here so just returning info and not
# actual objs to reduce potential for breakage.
transport = getattr(writer, '_transport', None)
if transport is not None:
sslproto = getattr(transport, '_ssl_protocol', None)
if sslproto is not None:
raw_transport = getattr(sslproto, '_transport', None)
if raw_transport is not None:
return str(raw_transport)
return '(not found)'
def ssl_stream_writer_force_close_check(writer: asyncio.StreamWriter) -> None:
"""Ensure a writer is closed; hacky workaround for odd hang."""
from efro.call import tpartial
from threading import Thread
# Hopefully can remove this in Python 3.11?...
# see issue with is_closing() below for more details.
transport = getattr(writer, '_transport', None)
if transport is not None:
sslproto = getattr(transport, '_ssl_protocol', None)
if sslproto is not None:
raw_transport = getattr(sslproto, '_transport', None)
if raw_transport is not None:
Thread(
target=tpartial(
_do_writer_force_close_check,
weakref.ref(raw_transport),
),
daemon=True,
).start()
def _do_writer_force_close_check(transport_weak: weakref.ref) -> None:
try:
# Attempt to bail as soon as the obj dies.
# If it hasn't done so by our timeout, force-kill it.
starttime = time.monotonic()
while time.monotonic() - starttime < 10.0:
time.sleep(0.1)
if transport_weak() is None:
return
transport = transport_weak()
if transport is not None:
logging.info('Forcing abort on stuck transport %s.', transport)
transport.abort()
except Exception:
logging.warning('Error in writer-force-close-check', exc_info=True)
class _InFlightMessage:
"""Represents a message that is out on the wire."""
@@ -138,6 +191,9 @@ class RPCEndpoint:
self._peer_info: _PeerInfo | None = None
self._keepalive_interval = keepalive_interval
self._keepalive_timeout = keepalive_timeout
self._did_close_writer = False
self._did_wait_closed_writer = False
self._did_out_packets_buildup_warning = False
# Need to hold weak-refs to these; otherwise it creates dep-loops
# which keep us alive.
@@ -156,11 +212,39 @@ class RPCEndpoint:
self._debug_print_call(
f'{self._label}: connected to {peername} at {self._tm()}.')
def __del__(self) -> None:
if self._run_called:
if not self._did_close_writer:
logging.warning(
'RPCEndpoint %d dying with run'
' called but writer not closed (transport=%s).', id(self),
ssl_stream_writer_underlying_transport_info(self._writer))
elif not self._did_wait_closed_writer:
logging.warning(
'RPCEndpoint %d dying with run called'
' but writer not wait-closed (transport=%s).', id(self),
ssl_stream_writer_underlying_transport_info(self._writer))
# Currently seeing rare issue where sockets don't go down;
# let's add a timer to force the issue until we can figure it out.
ssl_stream_writer_force_close_check(self._writer)
async def run(self) -> None:
"""Run the endpoint until the connection is lost or closed.
Handles closing the provided reader/writer on close.
"""
try:
await self._do_run()
except asyncio.CancelledError:
# We aren't really designed to be cancelled so let's warn
# if it happens.
logging.warning('RPCEndpoint.run got CancelledError;'
' want to try and avoid this.')
raise
async def _do_run(self) -> None:
self._check_env()
if self._run_called:
@@ -186,9 +270,13 @@ class RPCEndpoint:
# We want to know if any errors happened aside from CancelledError
# (which are BaseExceptions, not Exception).
if isinstance(result, Exception):
if self._debug_print:
logging.error('Got unexpected error from %s core task: %s',
self._label, result)
logging.warning('Got unexpected error from %s core task: %s',
self._label, result)
if not all(task.done() for task in core_tasks):
logging.warning(
'RPCEndpoint %d: not all core tasks marked done after gather.',
id(self))
# Shut ourself down.
try:
@@ -228,6 +316,9 @@ class RPCEndpoint:
message_id = self._next_message_id
self._next_message_id = (self._next_message_id + 1) % 65536
# FIXME - should handle backpressure (waiting here if there are
# enough packets already enqueued).
if len(message) > 65535:
# Payload consists of type (1b), message_id (2b),
# len (4b), and data.
@@ -261,6 +352,9 @@ class RPCEndpoint:
try:
return await asyncio.wait_for(msgobj.wait_task, timeout=timeout)
except asyncio.CancelledError as exc:
# Question: we assume this means the above wait_for() was
# cancelled; what happens if a task running *us* is cancelled
# though?
if self._debug_print:
self._debug_print_call(
f'{self._label}: message {message_id} was cancelled.')
@@ -297,9 +391,12 @@ class RPCEndpoint:
for task in self._get_live_tasks():
task.cancel()
# Close our writer.
assert not self._did_close_writer
if self._debug_print:
self._debug_print_call(f'{self._label}: closing writer...')
self._writer.close()
self._did_close_writer = True
# We don't need this anymore and it is likely to be creating a
# dependency loop.
@@ -311,6 +408,7 @@ class RPCEndpoint:
async def wait_closed(self) -> None:
"""I said seagulls; mmmm; stop it now."""
# pylint: disable=too-many-branches
self._check_env()
# Make sure we only *enter* this call once.
@@ -321,6 +419,10 @@ class RPCEndpoint:
if not self._closing:
raise RuntimeError('Must be called after close()')
if not self._did_close_writer:
logging.warning('RPCEndpoint wait_closed() called but never'
' explicitly closed writer.')
live_tasks = self._get_live_tasks()
if self._debug_print:
self._debug_print_call(
@@ -333,10 +435,13 @@ class RPCEndpoint:
# We want to know if any errors happened aside from CancelledError
# (which are BaseExceptions, not Exception).
if isinstance(result, Exception):
if self._debug_print:
logging.error(
'Got unexpected error cleaning up %s task: %s',
self._label, result)
logging.warning('Got unexpected error cleaning up %s task: %s',
self._label, result)
if not all(task.done() for task in live_tasks):
logging.warning(
'RPCEndpoint %d: not all live tasks marked done after gather.',
id(self))
if self._debug_print:
self._debug_print_call(
@@ -354,10 +459,12 @@ class RPCEndpoint:
# indefinitely. See https://github.com/python/cpython/issues/83939
# It sounds like this should be fixed in 3.11 but for now just
# forcing the issue with a timeout here.
await asyncio.wait_for(self._writer.wait_closed(), timeout=10.0)
await asyncio.wait_for(self._writer.wait_closed(), timeout=30.0)
except asyncio.TimeoutError:
logging.info('Timeout on _writer.wait_closed() for %s.',
self._label)
logging.info(
'Timeout on _writer.wait_closed() for %s rpc (transport=%s).',
self._label,
ssl_stream_writer_underlying_transport_info(self._writer))
if self._debug_print:
self._debug_print_call(
f'{self._label}: got timeout in _writer.wait_closed();'
@@ -370,6 +477,12 @@ class RPCEndpoint:
self._debug_print_call(
f'{self._label}: silently ignoring error in'
f' _writer.wait_closed(): {exc}.')
except asyncio.CancelledError:
logging.warning('RPCEndpoint.wait_closed()'
' got asyncio.CancelledError; not expected.')
raise
assert not self._did_wait_closed_writer
self._did_wait_closed_writer = True
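Taken together, the new _did_close_writer / _did_wait_closed_writer flags and the __del__ warnings above assume an endpoint lifecycle roughly like the following (a call-order sketch, not code from this commit; construction arguments are omitted):

endpoint = RPCEndpoint(...)  # reader/writer/handler args omitted
run_task = asyncio.create_task(endpoint.run())

# ... send/receive messages ...

endpoint.close()              # cancels tasks and closes the writer
await endpoint.wait_closed()  # waits for tasks and the writer to fully close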
def _tm(self) -> str:
"""Simple readable time value for debugging."""
@@ -494,7 +607,21 @@ class RPCEndpoint:
self._have_out_packets.clear()
self._writer.write(data)
# await self._writer.drain()
# This should keep our writer from buffering huge amounts
# of outgoing data. We must remember though that we also
# need to prevent _out_packets from growing too large and
# that part's on us.
await self._writer.drain()
# For now we're not applying backpressure, but let's make
# noise if this gets out of hand.
if len(self._out_packets) > 200:
if not self._did_out_packets_buildup_warning:
logging.warning(
'_out_packets building up too'
' much on RPCEndpoint %s.', id(self))
self._did_out_packets_buildup_warning = True
async def _run_keepalive_task(self) -> None:
"""Send periodic keepalive packets."""