Bombsquad-Ballistica-Modded.../dist/ba_data/python/efro/rpc.py

768 lines
30 KiB
Python
Raw Normal View History

2022-06-09 01:26:46 +05:30
# Released under the MIT License. See LICENSE for details.
#
"""Remote procedure call related functionality."""
from __future__ import annotations
import time
import asyncio
import logging
import weakref
from enum import Enum
from dataclasses import dataclass
from threading import current_thread
from typing import TYPE_CHECKING, Annotated
from efro.util import assert_never
2022-07-16 17:59:14 +05:30
from efro.error import (CommunicationError,
is_asyncio_streams_communication_error)
2022-06-09 01:26:46 +05:30
from efro.dataclassio import (dataclass_to_json, dataclass_from_json,
ioprepped, IOAttrs)
if TYPE_CHECKING:
2022-06-30 00:31:52 +05:30
from typing import Literal, Awaitable, Callable
2022-06-09 01:26:46 +05:30
# Terminology:
# Packet: A chunk of data consisting of a type and some type-dependent
# payload. Even though we use streams we organize our transmission
# into 'packets'.
# Message: User data which we transmit using one or more packets.
class _PacketType(Enum):
HANDSHAKE = 0
KEEPALIVE = 1
MESSAGE = 2
RESPONSE = 3
2022-06-30 00:31:52 +05:30
MESSAGE_BIG = 4
RESPONSE_BIG = 5
2022-06-09 01:26:46 +05:30
# Byte order passed to every int.to_bytes()/int.from_bytes() call used
# for packet framing in this module.
_BYTE_ORDER: Literal['big'] = 'big'
@ioprepped
@dataclass
class _PeerInfo:
    """Handshake payload; serialized to JSON and sent once by each side."""

    # So we can gracefully evolve how we communicate in the future.
    protocol: Annotated[int, IOAttrs('p')]

    # How often we'll be sending out keepalives (in seconds).
    keepalive_interval: Annotated[float, IOAttrs('k')]
2022-06-30 00:31:52 +05:30
# Note: we are expected to be forward and backward compatible; we can
# increment protocol freely and expect everyone else to still talk to us.
# Likewise we should retain logic to communicate with older protocols.
#
# Protocol history:
# 1 - initial release
# 2 - gained big (32-bit len val) message/response packets
OUR_PROTOCOL = 2
2022-06-09 01:26:46 +05:30
2022-10-05 03:11:34 +05:30
def ssl_stream_writer_underlying_transport_info(
    writer: asyncio.StreamWriter,
) -> str:
    """Describe the raw transport beneath an SSL stream writer.

    Intended for debugging SSL stream connections. Returns ``str()`` of
    the raw transport, or ``'(not found)'`` if any link in the private
    attribute chain is missing or None.
    """
    # Note: we are poking at asyncio internals here, so we hand back a
    # string rather than the objects themselves to limit breakage.
    # Chain walked: writer._transport._ssl_protocol._transport
    node: object = writer
    for attrname in ('_transport', '_ssl_protocol', '_transport'):
        node = getattr(node, attrname, None)
        if node is None:
            return '(not found)'
    return str(node)
def ssl_stream_writer_force_close_check(writer: asyncio.StreamWriter) -> None:
    """Ensure a writer is closed; hacky workaround for odd hang.

    If the writer has a raw transport beneath its SSL protocol, spawn a
    daemon thread that watches a weak reference to that transport and
    aborts it if it is still alive after a timeout. Does nothing if the
    private attribute chain cannot be followed.
    """
    from efro.call import tpartial
    from threading import Thread

    # Hopefully can remove this in Python 3.11?...
    # see issue with is_closing() below for more details.
    raw_transport: object = writer
    for attrname in ('_transport', '_ssl_protocol', '_transport'):
        raw_transport = getattr(raw_transport, attrname, None)
        if raw_transport is None:
            return

    # Only a weak ref goes to the thread so the watcher itself does not
    # keep the transport alive.
    watcher = Thread(
        target=tpartial(
            _do_writer_force_close_check,
            weakref.ref(raw_transport),
        ),
        daemon=True,
    )
    watcher.start()
def _do_writer_force_close_check(transport_weak: weakref.ref) -> None:
try:
# Attempt to bail as soon as the obj dies.
# If it hasn't done so by our timeout, force-kill it.
starttime = time.monotonic()
while time.monotonic() - starttime < 10.0:
time.sleep(0.1)
if transport_weak() is None:
return
transport = transport_weak()
if transport is not None:
logging.info('Forcing abort on stuck transport %s.', transport)
transport.abort()
except Exception:
logging.warning('Error in writer-force-close-check', exc_info=True)
2022-06-09 01:26:46 +05:30
class _InFlightMessage:
"""Represents a message that is out on the wire."""
def __init__(self) -> None:
2022-06-30 00:31:52 +05:30
self._response: bytes | None = None
2022-06-09 01:26:46 +05:30
self._got_response = asyncio.Event()
self.wait_task = asyncio.create_task(self._wait())
async def _wait(self) -> bytes:
await self._got_response.wait()
assert self._response is not None
return self._response
def set_response(self, data: bytes) -> None:
"""Set response data."""
assert self._response is None
self._response = data
self._got_response.set()
class _KeepaliveTimeoutError(Exception):
    """Raised if we time out due to not receiving keepalives.

    Raised by the keepalive task and treated as an expected connection
    error by the endpoint.
    """
class RPCEndpoint:
    """Facilitates asynchronous multiplexed remote procedure calls.

    Be aware that, while multiple calls can be in flight in either direction
    simultaneously, packets are still sent serially in a single
    stream. So excessively long messages/responses will delay all other
    communication. If/when this becomes an issue we can look into breaking up
    long messages into multiple packets.
    """

    # Set to True on an instance to test keepalive failures.
    # (While set, the keepalive task stops enqueueing keepalive packets.)
    test_suppress_keepalives: bool = False

    # How long we should wait before giving up on a message by default.
    # Note this includes processing time on the other end.
    DEFAULT_MESSAGE_TIMEOUT = 60.0

    # How often we send out keepalive packets by default.
    DEFAULT_KEEPALIVE_INTERVAL = 10.73  # (avoid too regular of values)

    # How long we can go without receiving a keepalive packet before we
    # disconnect.
    DEFAULT_KEEPALIVE_TIMEOUT = 30.0
def __init__(self,
handle_raw_message_call: Callable[[bytes], Awaitable[bytes]],
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
label: str,
debug_print: bool = False,
debug_print_io: bool = False,
2022-07-02 01:43:52 +05:30
debug_print_call: Callable[[str], None] | None = None,
2022-06-09 01:26:46 +05:30
keepalive_interval: float = DEFAULT_KEEPALIVE_INTERVAL,
keepalive_timeout: float = DEFAULT_KEEPALIVE_TIMEOUT) -> None:
self._handle_raw_message_call = handle_raw_message_call
self._reader = reader
self._writer = writer
self._debug_print = debug_print
self._debug_print_io = debug_print_io
if debug_print_call is None:
debug_print_call = print
self._debug_print_call: Callable[[str], None] = debug_print_call
self._label = label
self._thread = current_thread()
self._closing = False
self._did_wait_closed = False
self._event_loop = asyncio.get_running_loop()
self._out_packets: list[bytes] = []
self._have_out_packets = asyncio.Event()
self._run_called = False
2022-06-30 00:31:52 +05:30
self._peer_info: _PeerInfo | None = None
2022-06-09 01:26:46 +05:30
self._keepalive_interval = keepalive_interval
self._keepalive_timeout = keepalive_timeout
2022-10-05 03:11:34 +05:30
self._did_close_writer = False
self._did_wait_closed_writer = False
self._did_out_packets_buildup_warning = False
2022-06-09 01:26:46 +05:30
# Need to hold weak-refs to these otherwise it creates dep-loops
# which keeps us alive.
self._tasks: list[weakref.ref[asyncio.Task]] = []
# When we last got a keepalive or equivalent (time.monotonic value)
2022-06-30 00:31:52 +05:30
self._last_keepalive_receive_time: float | None = None
2022-06-09 01:26:46 +05:30
# (Start near the end to make sure our looping logic is sound).
self._next_message_id = 65530
self._in_flight_messages: dict[int, _InFlightMessage] = {}
if self._debug_print:
peername = self._writer.get_extra_info('peername')
self._debug_print_call(
f'{self._label}: connected to {peername} at {self._tm()}.')
2022-10-05 03:11:34 +05:30
def __del__(self) -> None:
if self._run_called:
if not self._did_close_writer:
logging.warning(
'RPCEndpoint %d dying with run'
' called but writer not closed (transport=%s).', id(self),
ssl_stream_writer_underlying_transport_info(self._writer))
elif not self._did_wait_closed_writer:
logging.warning(
'RPCEndpoint %d dying with run called'
' but writer not wait-closed (transport=%s).', id(self),
ssl_stream_writer_underlying_transport_info(self._writer))
# Currently seeing rare issue where sockets don't go down;
# let's add a timer to force the issue until we can figure it out.
ssl_stream_writer_force_close_check(self._writer)
2022-06-09 01:26:46 +05:30
async def run(self) -> None:
    """Run the endpoint until the connection is lost or closed.

    Handles closing the provided reader/writer on close. A
    CancelledError is logged as a warning and then re-raised.
    """
    try:
        await self._do_run()
    except asyncio.CancelledError:
        # We aren't really designed to be cancelled so let's warn
        # if it happens.
        logging.warning(
            'RPCEndpoint.run got CancelledError;'
            ' want to try and avoid this.'
        )
        raise
async def _do_run(self) -> None:
2022-06-09 01:26:46 +05:30
self._check_env()
if self._run_called:
raise RuntimeError('Run can be called only once per endpoint.')
self._run_called = True
core_tasks = [
asyncio.create_task(
self._run_core_task('keepalive', self._run_keepalive_task())),
asyncio.create_task(
self._run_core_task('read', self._run_read_task())),
asyncio.create_task(
self._run_core_task('write', self._run_write_task()))
]
self._tasks += [weakref.ref(t) for t in core_tasks]
# Run our core tasks until they all complete.
results = await asyncio.gather(*core_tasks, return_exceptions=True)
# Core tasks should handle their own errors; the only ones
# we expect to bubble up are CancelledError.
for result in results:
# We want to know if any errors happened aside from CancelledError
# (which are BaseExceptions, not Exception).
if isinstance(result, Exception):
2022-10-05 03:11:34 +05:30
logging.warning('Got unexpected error from %s core task: %s',
self._label, result)
if not all(task.done() for task in core_tasks):
logging.warning(
'RPCEndpoint %d: not all core tasks marked done after gather.',
id(self))
2022-06-09 01:26:46 +05:30
# Shut ourself down.
try:
self.close()
await self.wait_closed()
except Exception:
logging.exception('Error closing %s.', self._label)
if self._debug_print:
self._debug_print_call(f'{self._label}: finished.')
async def send_message(self,
                       message: bytes,
                       timeout: float | None = None) -> bytes:
    """Send a message to the peer and return a response.

    If timeout is not provided, the default will be used.

    Args:
        message: Raw message bytes to deliver to the peer.
        timeout: Seconds to wait for the response; None means
            DEFAULT_MESSAGE_TIMEOUT.

    Returns:
        The raw response bytes sent back by the peer.

    Raises:
        CommunicationError: If the round trip is not completed for any
            reason (endpoint closed, timeout, or cancellation).
        RuntimeError: If the peer speaks protocol 1 and the message is
            larger than 65535 bytes.
    """
    self._check_env()

    if self._closing:
        raise CommunicationError('Endpoint is closed')

    # We need to know their protocol, so if we haven't gotten a handshake
    # from them yet, just wait. If the endpoint starts closing while we
    # wait, bail; no message can be sent once we are closing, and
    # without this check we would spin here forever if the handshake
    # never arrives.
    while self._peer_info is None:
        await asyncio.sleep(0.01)
        if self._closing:
            raise CommunicationError('Endpoint is closed')
    assert self._peer_info is not None

    if self._peer_info.protocol == 1:
        if len(message) > 65535:
            raise RuntimeError('Message cannot be larger than 65535 bytes')

    # message_id is a 16 bit looping value.
    message_id = self._next_message_id
    self._next_message_id = (self._next_message_id + 1) % 65536

    # FIXME - should handle backpressure (waiting here if there are
    # enough packets already enqueued).
    if len(message) > 65535:
        # Payload consists of type (1b), message_id (2b),
        # len (4b), and data.
        self._enqueue_outgoing_packet(
            _PacketType.MESSAGE_BIG.value.to_bytes(1, _BYTE_ORDER) +
            message_id.to_bytes(2, _BYTE_ORDER) +
            len(message).to_bytes(4, _BYTE_ORDER) + message)
    else:
        # Payload consists of type (1b), message_id (2b),
        # len (2b), and data.
        self._enqueue_outgoing_packet(
            _PacketType.MESSAGE.value.to_bytes(1, _BYTE_ORDER) +
            message_id.to_bytes(2, _BYTE_ORDER) +
            len(message).to_bytes(2, _BYTE_ORDER) + message)

    # Make an entry so we know this message is out there.
    assert message_id not in self._in_flight_messages
    msgobj = self._in_flight_messages[message_id] = _InFlightMessage()

    # Also add its task to our list so we properly cancel it if we die.
    self._prune_tasks()  # Keep our list from filling with dead tasks.
    self._tasks.append(weakref.ref(msgobj.wait_task))

    # Note: we always want to incorporate a timeout. Individual
    # messages may hang or error on the other end and this ensures
    # we won't build up lots of zombie tasks waiting around for
    # responses that will never arrive.
    if timeout is None:
        timeout = self.DEFAULT_MESSAGE_TIMEOUT
    assert timeout is not None
    try:
        return await asyncio.wait_for(msgobj.wait_task, timeout=timeout)
    except asyncio.CancelledError as exc:
        # Question: we assume this means the above wait_for() was
        # cancelled; what happens if a task running *us* is cancelled
        # though?
        if self._debug_print:
            self._debug_print_call(
                f'{self._label}: message {message_id} was cancelled.')
        raise CommunicationError() from exc
    except asyncio.TimeoutError as exc:
        if self._debug_print:
            self._debug_print_call(
                f'{self._label}: message {message_id} timed out.')

        # Stop waiting on the response.
        msgobj.wait_task.cancel()

        # Let the user know something went wrong.
        raise CommunicationError() from exc
    finally:
        # Remove the record of this message however the round trip
        # ended (response, timeout or cancellation). Message ids are
        # a looping 16 bit value, so a record left behind would
        # collide with a later message once the counter wraps.
        if self._in_flight_messages.get(message_id) is msgobj:
            del self._in_flight_messages[message_id]
def close(self) -> None:
    """Begin shutting the endpoint down.

    Marks the endpoint as closing, cancels all of its live tasks and
    closes the writer. Calling this again once closing has begun is a
    no-op. Follow up with wait_closed() to wait for shutdown to finish.
    """
    self._check_env()

    # Only the first call does any work.
    if self._closing:
        return

    if self._debug_print:
        self._debug_print_call(f'{self._label}: closing...')
    self._closing = True

    # Kill all of our in-flight tasks.
    if self._debug_print:
        self._debug_print_call(f'{self._label}: cancelling tasks...')
    for live_task in self._get_live_tasks():
        live_task.cancel()

    # Close our writer.
    assert not self._did_close_writer
    if self._debug_print:
        self._debug_print_call(f'{self._label}: closing writer...')
    self._writer.close()
    self._did_close_writer = True

    # We don't need this anymore and it is likely to be creating a
    # dependency loop.
    del self._handle_raw_message_call
def is_closing(self) -> bool:
    """Have we begun the process of closing?

    Returns True once close() has started doing its work.
    """
    return self._closing
async def wait_closed(self) -> None:
    """Wait for a shutdown started by close() to finish.

    Waits for all live tasks to wrap up and then for the writer to
    finish closing (giving up on the writer after 30 seconds). Only the
    first call does any work; later calls return immediately.

    Raises:
        RuntimeError: If called before close().
    """
    # pylint: disable=too-many-branches
    self._check_env()

    # Make sure we only *enter* this call once.
    if self._did_wait_closed:
        return
    self._did_wait_closed = True

    if not self._closing:
        raise RuntimeError('Must be called after close()')

    if not self._did_close_writer:
        logging.warning('RPCEndpoint wait_closed() called but never'
                        ' explicitly closed writer.')

    live_tasks = self._get_live_tasks()
    if self._debug_print:
        self._debug_print_call(
            f'{self._label}: waiting for tasks to finish: '
            f' ({live_tasks=})...')

    # Wait for all of our in-flight tasks to wrap up.
    results = await asyncio.gather(*live_tasks, return_exceptions=True)

    # We want to know if any errors happened aside from CancelledError
    # (which are BaseExceptions, not Exception).
    for failure in (r for r in results if isinstance(r, Exception)):
        logging.warning('Got unexpected error cleaning up %s task: %s',
                        self._label, failure)

    if any(not task.done() for task in live_tasks):
        logging.warning(
            'RPCEndpoint %d: not all live tasks marked done after gather.',
            id(self))

    if self._debug_print:
        self._debug_print_call(
            f'{self._label}: tasks finished; waiting for writer close...')

    # Now wait for our writer to finish going down.
    # When we close our writer it generally triggers errors
    # in our current blocked read/writes. However that same
    # error is also sometimes returned from _writer.wait_closed().
    # See connection_lost() in asyncio/streams.py to see why.
    # So let's silently ignore it when that happens.
    assert self._writer.is_closing()
    try:
        # It seems that as of Python 3.9.x it is possible for this to hang
        # indefinitely. See https://github.com/python/cpython/issues/83939
        # It sounds like this should be fixed in 3.11 but for now just
        # forcing the issue with a timeout here.
        await asyncio.wait_for(self._writer.wait_closed(), timeout=30.0)
    except asyncio.TimeoutError:
        logging.info(
            'Timeout on _writer.wait_closed() for %s rpc (transport=%s).',
            self._label,
            ssl_stream_writer_underlying_transport_info(self._writer))
        if self._debug_print:
            self._debug_print_call(
                f'{self._label}: got timeout in _writer.wait_closed();'
                ' This should be fixed in future Python versions.')
    except Exception as exc:
        if self._is_expected_connection_error(exc):
            if self._debug_print:
                self._debug_print_call(
                    f'{self._label}: silently ignoring error in'
                    f' _writer.wait_closed(): {exc}.')
        else:
            logging.exception('Error closing _writer for %s.', self._label)
    except asyncio.CancelledError:
        logging.warning('RPCEndpoint.wait_closed()'
                        ' got asyncio.CancelledError; not expected.')
        raise

    assert not self._did_wait_closed_writer
    self._did_wait_closed_writer = True
2022-06-09 01:26:46 +05:30
def _tm(self) -> str:
"""Simple readable time value for debugging."""
tval = time.time() % 100.0
return f'{tval:.2f}'
async def _run_read_task(self) -> None:
    """Read from the peer.

    Reads the peer's handshake first, then loops forever dispatching
    incoming packets by type. Only exits by raising.

    Raises:
        RuntimeError: If a handshake packet type shows up after the
            initial handshake.
    """
    self._check_env()
    assert self._peer_info is None

    # The first thing they should send us is their handshake; then
    # we'll know if/how we can talk to them. It arrives as a 32-bit
    # length followed by that many bytes of JSON.
    handshake_len = await self._read_int_32()
    handshake = await self._reader.readexactly(handshake_len)
    self._peer_info = dataclass_from_json(_PeerInfo, handshake.decode())

    # Receiving the handshake counts as hearing from the peer.
    self._last_keepalive_receive_time = time.monotonic()
    if self._debug_print:
        self._debug_print_call(
            f'{self._label}: received handshake at {self._tm()}.')

    # Now just sit and handle stuff as it comes in.
    while True:
        assert not self._closing

        # Read message type.
        mtype = _PacketType(await self._read_int_8())

        if mtype is _PacketType.HANDSHAKE:
            raise RuntimeError('Got multiple handshakes')

        if mtype is _PacketType.KEEPALIVE:
            if self._debug_print_io:
                self._debug_print_call(f'{self._label}: received keepalive'
                                       f' at {self._tm()}.')
            self._last_keepalive_receive_time = time.monotonic()
            continue

        if mtype is _PacketType.MESSAGE:
            await self._handle_message_packet(big=False)
            continue

        if mtype is _PacketType.MESSAGE_BIG:
            await self._handle_message_packet(big=True)
            continue

        if mtype is _PacketType.RESPONSE:
            await self._handle_response_packet(big=False)
            continue

        if mtype is _PacketType.RESPONSE_BIG:
            await self._handle_response_packet(big=True)
            continue

        # Every packet type should have been handled above.
        assert_never(mtype)
2022-06-30 00:31:52 +05:30
async def _handle_message_packet(self, big: bool) -> None:
assert self._peer_info is not None
2022-06-09 01:26:46 +05:30
msgid = await self._read_int_16()
2022-06-30 00:31:52 +05:30
if big:
msglen = await self._read_int_32()
else:
msglen = await self._read_int_16()
2022-06-09 01:26:46 +05:30
msg = await self._reader.readexactly(msglen)
if self._debug_print_io:
self._debug_print_call(f'{self._label}: received message {msgid}'
f' of size {msglen} at {self._tm()}.')
# Create a message-task to handle this message and return
# a response (we don't want to block while that happens).
assert not self._closing
self._prune_tasks() # Keep from filling with dead tasks.
self._tasks.append(
weakref.ref(
asyncio.create_task(
self._handle_raw_message(message_id=msgid, message=msg))))
2022-10-01 14:51:35 +05:30
if self._debug_print:
self._debug_print_call(
f'{self._label}: done handling message at {self._tm()}.')
2022-06-09 01:26:46 +05:30
2022-06-30 00:31:52 +05:30
async def _handle_response_packet(self, big: bool) -> None:
assert self._peer_info is not None
2022-06-09 01:26:46 +05:30
msgid = await self._read_int_16()
2022-06-30 00:31:52 +05:30
# Protocol 2 gained 32 bit data lengths.
if big:
rsplen = await self._read_int_32()
else:
rsplen = await self._read_int_16()
2022-06-09 01:26:46 +05:30
if self._debug_print_io:
self._debug_print_call(f'{self._label}: received response {msgid}'
f' of size {rsplen} at {self._tm()}.')
rsp = await self._reader.readexactly(rsplen)
msgobj = self._in_flight_messages.get(msgid)
if msgobj is None:
# It's possible for us to get a response to a message
# that has timed out. In this case we will have no local
# record of it.
if self._debug_print:
self._debug_print_call(
f'{self._label}: got response for nonexistent'
f' message id {msgid}; perhaps it timed out?')
else:
msgobj.set_response(rsp)
async def _run_write_task(self) -> None:
    """Write to the peer.

    Sends our handshake and then loops forever, writing queued outgoing
    packets one at a time as they become available. Only exits by
    raising.
    """
    self._check_env()

    # Introduce ourself so our peer knows how it can talk to us.
    # The handshake goes out as a 32-bit length followed by JSON.
    our_info = _PeerInfo(protocol=OUR_PROTOCOL,
                         keepalive_interval=self._keepalive_interval)
    handshake = dataclass_to_json(our_info).encode()
    self._writer.write(len(handshake).to_bytes(4, _BYTE_ORDER) + handshake)

    # Now just write out-messages as they come in.
    while True:
        # Wait until some data comes in.
        await self._have_out_packets.wait()

        assert self._out_packets
        packet = self._out_packets.pop(0)

        # Important: only clear this once all packets are sent.
        if not self._out_packets:
            self._have_out_packets.clear()

        self._writer.write(packet)

        # This should keep our writer from buffering huge amounts
        # of outgoing data. We must remember though that we also
        # need to prevent _out_packets from growing too large and
        # that part's on us.
        await self._writer.drain()

        # For now we're not applying backpressure, but let's make
        # noise (once) if this gets out of hand.
        if (len(self._out_packets) > 200
                and not self._did_out_packets_buildup_warning):
            logging.warning(
                '_out_packets building up too'
                ' much on RPCEndpoint %s.', id(self))
            self._did_out_packets_buildup_warning = True
2022-06-09 01:26:46 +05:30
async def _run_keepalive_task(self) -> None:
    """Send periodic keepalive packets.

    Also checks, once per keepalive interval, how long it has been since
    we last heard from the peer.

    Raises:
        _KeepaliveTimeoutError: If that gap exceeds the keepalive
            timeout.
    """
    self._check_env()

    # We explicitly send our own keepalive packets so we can stay
    # more on top of the connection state and possibly decide to
    # kill it when contact is lost more quickly than the OS would
    # do itself (or at least keep the user informed that the
    # connection is lagging). It sounds like we could have the TCP
    # layer do this sort of thing itself but that might be
    # OS-specific so gonna go this way for now.
    while True:
        assert not self._closing
        await asyncio.sleep(self._keepalive_interval)

        if not self.test_suppress_keepalives:
            self._enqueue_outgoing_packet(
                _PacketType.KEEPALIVE.value.to_bytes(1, _BYTE_ORDER))

        # Also go ahead and handle dropping the connection if we
        # haven't heard from the peer in a while.
        # NOTE: perhaps we want to do something more exact than
        # this which only checks once per keepalive-interval?..
        now = time.monotonic()
        last_heard = self._last_keepalive_receive_time
        if last_heard is None:
            # Nothing received yet, so nothing to measure against.
            continue
        if now - last_heard > self._keepalive_timeout:
            if self._debug_print:
                since = now - last_heard
                self._debug_print_call(
                    f'{self._label}: reached keepalive time-out'
                    f' ({since:.1f}s).')
            raise _KeepaliveTimeoutError()
async def _run_core_task(self, tasklabel: str, call: Awaitable) -> None:
try:
await call
except Exception as exc:
# We expect connection errors to put us here, but make noise
# if something else does.
if not self._is_expected_connection_error(exc):
logging.exception('Unexpected error in rpc %s %s task.',
self._label, tasklabel)
else:
if self._debug_print:
self._debug_print_call(
f'{self._label}: {tasklabel} task will exit cleanly'
f' due to {exc!r}.')
finally:
# Any core task exiting triggers shutdown.
if self._debug_print:
self._debug_print_call(
f'{self._label}: {tasklabel} task exiting...')
self.close()
async def _handle_raw_message(self, message_id: int,
                              message: bytes) -> None:
    """Run the user handler on a message and enqueue its response.

    Args:
        message_id: Id from the incoming packet; echoed in the response.
        message: Raw message bytes to pass to the handler.

    Raises:
        RuntimeError: If the peer speaks protocol 1 and the response is
            larger than 65535 bytes.
    """
    try:
        response = await self._handle_raw_message_call(message)
    except Exception:
        # We expect local message handler to always succeed.
        # If that doesn't happen, make a fuss so we know to fix it.
        # The other end will simply never get a response to this
        # message.
        logging.exception('Error handling raw rpc message')
        return

    assert self._peer_info is not None

    is_big = len(response) > 65535
    if self._peer_info.protocol == 1 and is_big:
        raise RuntimeError(
            'Response cannot be larger than 65535 bytes')

    # Now send back our response.
    # Payload consists of type (1b), msgid (2b), len (2b, or 4b for the
    # big variant), and data.
    if is_big:
        packet_type = _PacketType.RESPONSE_BIG
        len_size = 4
    else:
        packet_type = _PacketType.RESPONSE
        len_size = 2
    self._enqueue_outgoing_packet(
        packet_type.value.to_bytes(1, _BYTE_ORDER) +
        message_id.to_bytes(2, _BYTE_ORDER) +
        len(response).to_bytes(len_size, _BYTE_ORDER) + response)
2022-06-09 01:26:46 +05:30
async def _read_int_8(self) -> int:
    """Read exactly one byte from the reader and return it as an int."""
    raw = await self._reader.readexactly(1)
    return int.from_bytes(raw, _BYTE_ORDER)
async def _read_int_16(self) -> int:
    """Read exactly two bytes from the reader and return them as an int."""
    raw = await self._reader.readexactly(2)
    return int.from_bytes(raw, _BYTE_ORDER)
async def _read_int_32(self) -> int:
    """Read exactly four bytes from the reader and return them as an int."""
    raw = await self._reader.readexactly(4)
    return int.from_bytes(raw, _BYTE_ORDER)
@classmethod
def _is_expected_connection_error(cls, exc: Exception) -> bool:
    """Stuff we expect to end our connection in normal circumstances.

    Our own keepalive timeout always counts; anything else is judged by
    is_asyncio_streams_communication_error().
    """
    return (isinstance(exc, _KeepaliveTimeoutError)
            or is_asyncio_streams_communication_error(exc))
2022-06-09 01:26:46 +05:30
def _check_env(self) -> None:
# I was seeing that asyncio stuff wasn't working as expected if
# created in one thread and used in another, so let's enforce
# a single thread for all use of an instance.
if current_thread() is not self._thread:
raise RuntimeError('This must be called from the same thread'
' that the endpoint was created in.')
# This should always be the case if thread is the same.
assert asyncio.get_running_loop() is self._event_loop
def _enqueue_outgoing_packet(self, data: bytes) -> None:
"""Enqueue a raw packet to be sent. Must be called from our loop."""
self._check_env()
if self._debug_print_io:
self._debug_print_call(f'{self._label}: enqueueing outgoing packet'
f' {data[:50]!r} at {self._tm()}.')
# Add the data and let our write task know about it.
self._out_packets.append(data)
self._have_out_packets.set()
def _prune_tasks(self) -> None:
out: list[weakref.ref[asyncio.Task]] = []
for task_weak_ref in self._tasks:
task = task_weak_ref()
if task is not None and not task.done():
out.append(task_weak_ref)
self._tasks = out
def _get_live_tasks(self) -> list[asyncio.Task]:
out: list[asyncio.Task] = []
for task_weak_ref in self._tasks:
task = task_weak_ref()
if task is not None and not task.done():
out.append(task)
return out