mirror of
https://github.com/imayushsaini/Bombsquad-Ballistica-Modded-Server.git
synced 2025-10-20 00:00:39 +00:00
767 lines
30 KiB
Python
# Released under the MIT License. See LICENSE for details.
|
|
#
|
|
"""Remote procedure call related functionality."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import time
|
|
import asyncio
|
|
import logging
|
|
import weakref
|
|
from enum import Enum
|
|
from dataclasses import dataclass
|
|
from threading import current_thread
|
|
from typing import TYPE_CHECKING, Annotated
|
|
|
|
from efro.util import assert_never
|
|
from efro.error import (CommunicationError,
|
|
is_asyncio_streams_communication_error)
|
|
from efro.dataclassio import (dataclass_to_json, dataclass_from_json,
|
|
ioprepped, IOAttrs)
|
|
|
|
if TYPE_CHECKING:
|
|
from typing import Literal, Awaitable, Callable
|
|
|
|
# Terminology:
|
|
# Packet: A chunk of data consisting of a type and some type-dependent
|
|
# payload. Even though we use streams we organize our transmission
|
|
# into 'packets'.
|
|
# Message: User data which we transmit using one or more packets.
|
|
|
|
|
|
class _PacketType(Enum):
    """Type tag written as the first byte of each packet after the handshake.

    Packet layouts (every integer uses _BYTE_ORDER):
      KEEPALIVE:    type(1b)
      MESSAGE:      type(1b) message_id(2b) len(2b) data
      RESPONSE:     type(1b) message_id(2b) len(2b) data
      MESSAGE_BIG:  type(1b) message_id(2b) len(4b) data  (protocol 2+)
      RESPONSE_BIG: type(1b) message_id(2b) len(4b) data  (protocol 2+)
    """
    # Not sent with a type byte by RPCEndpoint: the initial handshake is a
    # bare 32-bit-length-prefixed JSON blob. Seeing this type afterwards
    # is treated as an error by the read loop.
    HANDSHAKE = 0
    KEEPALIVE = 1
    MESSAGE = 2
    RESPONSE = 3
    # Used for payloads larger than 65535 bytes.
    MESSAGE_BIG = 4
    RESPONSE_BIG = 5
|
|
|
|
|
|
# Byte order for every multi-byte integer we write to or read from the wire.
_BYTE_ORDER: Literal['big'] = 'big'
|
|
|
|
|
|
@ioprepped
@dataclass
class _PeerInfo:
    """Handshake info that each endpoint sends its peer first thing.

    Serialized to JSON, encoded, and written as a 32-bit length prefix
    followed by the encoded bytes.
    """

    # So we can gracefully evolve how we communicate in the future.
    protocol: Annotated[int, IOAttrs('p')]

    # How often we'll be sending out keepalives (in seconds).
    keepalive_interval: Annotated[float, IOAttrs('k')]
|
|
|
|
|
|
# Note: we are expected to be forward and backward compatible; we can
# increment protocol freely and expect everyone else to still talk to us.
# Likewise we should retain logic to communicate with older protocols.
# Protocol history:
# 1 - initial release
# 2 - gained big (32-bit len val) message/response packets
#     (MESSAGE_BIG / RESPONSE_BIG). When talking to a protocol-1 peer,
#     messages and responses remain limited to 65535 bytes.
OUR_PROTOCOL = 2
|
|
|
|
|
|
def ssl_stream_writer_underlying_transport_info(
        writer: asyncio.StreamWriter) -> str:
    """Describe the raw transport beneath an SSL stream writer (debug aid).

    Walks writer._transport._ssl_protocol._transport and returns its str().
    Returns '(not found)' if any link in that chain is missing or None.
    """
    # These are private asyncio internals, so we only ever hand back a
    # description string, never the objects themselves; that limits the
    # damage if the internals change.
    node: object = writer
    for attr_name in ('_transport', '_ssl_protocol', '_transport'):
        node = getattr(node, attr_name, None)
        if node is None:
            return '(not found)'
    return str(node)
|
|
|
|
|
|
def ssl_stream_writer_force_close_check(writer: asyncio.StreamWriter) -> None:
    """Spawn a watchdog that aborts the writer's raw transport if it lingers.

    A hacky workaround for an odd hang. If the raw transport under
    writer._transport._ssl_protocol._transport can be found, a daemon
    thread is started that runs _do_writer_force_close_check on a weak
    reference to it. Otherwise this does nothing.
    """
    from efro.call import tpartial
    from threading import Thread

    # Hopefully can remove this in Python 3.11?...
    node: object = writer
    for attr_name in ('_transport', '_ssl_protocol', '_transport'):
        node = getattr(node, attr_name, None)
        if node is None:
            return

    # Only a weak reference is handed over, so the watchdog does not keep
    # the transport alive.
    watchdog = Thread(
        target=tpartial(_do_writer_force_close_check, weakref.ref(node)),
        daemon=True,
    )
    watchdog.start()
|
|
|
|
|
|
def _do_writer_force_close_check(transport_weak: weakref.ref) -> None:
|
|
try:
|
|
# Attempt to bail as soon as the obj dies.
|
|
# If it hasn't done so by our timeout, force-kill it.
|
|
starttime = time.monotonic()
|
|
while time.monotonic() - starttime < 10.0:
|
|
time.sleep(0.1)
|
|
if transport_weak() is None:
|
|
return
|
|
transport = transport_weak()
|
|
if transport is not None:
|
|
logging.info('Forcing abort on stuck transport %s.', transport)
|
|
transport.abort()
|
|
except Exception:
|
|
logging.warning('Error in writer-force-close-check', exc_info=True)
|
|
|
|
|
|
class _InFlightMessage:
|
|
"""Represents a message that is out on the wire."""
|
|
|
|
def __init__(self) -> None:
|
|
self._response: bytes | None = None
|
|
self._got_response = asyncio.Event()
|
|
self.wait_task = asyncio.create_task(self._wait())
|
|
|
|
async def _wait(self) -> bytes:
|
|
await self._got_response.wait()
|
|
assert self._response is not None
|
|
return self._response
|
|
|
|
def set_response(self, data: bytes) -> None:
|
|
"""Set response data."""
|
|
assert self._response is None
|
|
self._response = data
|
|
self._got_response.set()
|
|
|
|
|
|
class _KeepaliveTimeoutError(Exception):
    """Raised if we time out due to not receiving keepalives.

    RPCEndpoint's keepalive task raises this when nothing (handshake or
    keepalive) has been received from the peer for longer than the keepalive
    timeout. It is classified as an expected connection error, so it ends
    the endpoint without an exception being logged.
    """
|
|
|
|
|
|
class RPCEndpoint:
    """Facilitates asynchronous multiplexed remote procedure calls.

    Be aware that, while multiple calls can be in flight in either direction
    simultaneously, packets are still sent serially in a single
    stream. So excessively long messages/responses will delay all other
    communication. If/when this becomes an issue we can look into breaking up
    long messages into multiple packets.

    Usage: construct inside a running event loop, then await run(), which
    drives the connection until it is lost or closed. Use send_message()
    from the same thread/loop to issue calls. Messages are tagged with
    16-bit ids so several can be outstanding at once.
    """

    # Set to True on an instance to test keepalive failures.
    test_suppress_keepalives: bool = False

    # How long we should wait before giving up on a message by default.
    # Note this includes processing time on the other end.
    DEFAULT_MESSAGE_TIMEOUT = 60.0

    # How often we send out keepalive packets by default.
    DEFAULT_KEEPALIVE_INTERVAL = 10.73  # (avoid too regular of values)

    # How long we can go without receiving a keepalive packet before we
    # disconnect.
    DEFAULT_KEEPALIVE_TIMEOUT = 30.0

    def __init__(self,
                 handle_raw_message_call: Callable[[bytes], Awaitable[bytes]],
                 reader: asyncio.StreamReader,
                 writer: asyncio.StreamWriter,
                 label: str,
                 debug_print: bool = False,
                 debug_print_io: bool = False,
                 debug_print_call: Callable[[str], None] | None = None,
                 keepalive_interval: float = DEFAULT_KEEPALIVE_INTERVAL,
                 keepalive_timeout: float = DEFAULT_KEEPALIVE_TIMEOUT) -> None:
        """Wrap an already-connected reader/writer pair.

        Must be called with an asyncio event loop running. The endpoint is
        bound to that loop and to the calling thread.

        Args:
            handle_raw_message_call: Awaited with each incoming message's
                bytes. Its return value is sent back as the response.
            reader: Stream we read the peer's packets from.
            writer: Stream we write our packets to. It is closed by close().
            label: Name used in log and debug output.
            debug_print: Print lifecycle debug info.
            debug_print_io: Print per-packet debug info.
            debug_print_call: Debug output sink (defaults to print).
            keepalive_interval: Seconds between our keepalive packets.
            keepalive_timeout: Seconds without a keepalive from the peer
                before we drop the connection.
        """
        self._handle_raw_message_call = handle_raw_message_call
        self._reader = reader
        self._writer = writer
        self._debug_print = debug_print
        self._debug_print_io = debug_print_io
        if debug_print_call is None:
            debug_print_call = print
        self._debug_print_call: Callable[[str], None] = debug_print_call
        self._label = label
        self._thread = current_thread()
        self._closing = False
        self._did_wait_closed = False
        self._event_loop = asyncio.get_running_loop()
        self._out_packets: list[bytes] = []
        self._have_out_packets = asyncio.Event()
        self._run_called = False
        self._peer_info: _PeerInfo | None = None
        self._keepalive_interval = keepalive_interval
        self._keepalive_timeout = keepalive_timeout
        self._did_close_writer = False
        self._did_wait_closed_writer = False
        self._did_out_packets_buildup_warning = False

        # Need to hold weak-refs to these otherwise it creates dep-loops
        # which keeps us alive.
        self._tasks: list[weakref.ref[asyncio.Task]] = []

        # When we last got a keepalive or equivalent (time.monotonic value)
        self._last_keepalive_receive_time: float | None = None

        # (Start near the end to make sure our looping logic is sound).
        self._next_message_id = 65530

        # Messages we've sent and are still awaiting responses for,
        # keyed by message id.
        self._in_flight_messages: dict[int, _InFlightMessage] = {}

        if self._debug_print:
            peername = self._writer.get_extra_info('peername')
            self._debug_print_call(
                f'{self._label}: connected to {peername} at {self._tm()}.')

    def __del__(self) -> None:
        # getattr() because __init__ may have raised before this was set
        # (asyncio.get_running_loop() raises when no loop is running).
        if getattr(self, '_run_called', False):
            if not self._did_close_writer:
                logging.warning(
                    'RPCEndpoint %d dying with run'
                    ' called but writer not closed (transport=%s).', id(self),
                    ssl_stream_writer_underlying_transport_info(self._writer))
            elif not self._did_wait_closed_writer:
                logging.warning(
                    'RPCEndpoint %d dying with run called'
                    ' but writer not wait-closed (transport=%s).', id(self),
                    ssl_stream_writer_underlying_transport_info(self._writer))

        # Currently seeing rare issue where sockets don't go down;
        # let's add a timer to force the issue until we can figure it out.
        ssl_stream_writer_force_close_check(self._writer)

    async def run(self) -> None:
        """Run the endpoint until the connection is lost or closed.

        Handles closing the provided reader/writer on close.
        """
        try:
            await self._do_run()
        except asyncio.CancelledError:
            # We aren't really designed to be cancelled so let's warn
            # if it happens.
            logging.warning('RPCEndpoint.run got CancelledError;'
                            ' want to try and avoid this.')
            raise

    async def _do_run(self) -> None:
        """Run the core tasks until they all end, then shut down.

        Raises RuntimeError if called more than once.
        """
        self._check_env()

        if self._run_called:
            raise RuntimeError('Run can be called only once per endpoint.')
        self._run_called = True

        # Any one of these exiting calls close(), which cancels the rest.
        core_tasks = [
            asyncio.create_task(
                self._run_core_task('keepalive', self._run_keepalive_task())),
            asyncio.create_task(
                self._run_core_task('read', self._run_read_task())),
            asyncio.create_task(
                self._run_core_task('write', self._run_write_task()))
        ]
        self._tasks += [weakref.ref(t) for t in core_tasks]

        # Run our core tasks until they all complete.
        results = await asyncio.gather(*core_tasks, return_exceptions=True)

        # Core tasks should handle their own errors; the only ones
        # we expect to bubble up are CancelledError.
        for result in results:
            # We want to know if any errors happened aside from CancelledError
            # (which are BaseExceptions, not Exception).
            if isinstance(result, Exception):
                logging.warning('Got unexpected error from %s core task: %s',
                                self._label, result)

        if not all(task.done() for task in core_tasks):
            logging.warning(
                'RPCEndpoint %d: not all core tasks marked done after gather.',
                id(self))

        # Shut ourself down.
        try:
            self.close()
            await self.wait_closed()
        except Exception:
            logging.exception('Error closing %s.', self._label)

        if self._debug_print:
            self._debug_print_call(f'{self._label}: finished.')

    async def send_message(self,
                           message: bytes,
                           timeout: float | None = None) -> bytes:
        """Send a message to the peer and return a response.

        If timeout is not provided, the default will be used.
        Raises a CommunicationError if the round trip is not completed
        for any reason, including the endpoint being closed before or
        while we wait for the peer's handshake.
        Raises RuntimeError if the peer speaks protocol 1 and the message
        is larger than 65535 bytes.
        """
        self._check_env()

        if self._closing:
            raise CommunicationError('Endpoint is closed')

        # We need to know their protocol, so if we haven't gotten a handshake
        # from them yet, just wait. If we get closed meanwhile, the handshake
        # can never arrive, so bail instead of spinning forever.
        while self._peer_info is None:
            await asyncio.sleep(0.01)
            if self._closing:
                raise CommunicationError('Endpoint is closed')
        assert self._peer_info is not None

        if self._peer_info.protocol == 1:
            if len(message) > 65535:
                raise RuntimeError('Message cannot be larger than 65535 bytes')

        # message_id is a 16 bit looping value.
        message_id = self._next_message_id
        self._next_message_id = (self._next_message_id + 1) % 65536

        # FIXME - should handle backpressure (waiting here if there are
        # enough packets already enqueued).

        if len(message) > 65535:
            # Payload consists of type (1b), message_id (2b),
            # len (4b), and data.
            self._enqueue_outgoing_packet(
                _PacketType.MESSAGE_BIG.value.to_bytes(1, _BYTE_ORDER) +
                message_id.to_bytes(2, _BYTE_ORDER) +
                len(message).to_bytes(4, _BYTE_ORDER) + message)
        else:
            # Payload consists of type (1b), message_id (2b),
            # len (2b), and data.
            self._enqueue_outgoing_packet(
                _PacketType.MESSAGE.value.to_bytes(1, _BYTE_ORDER) +
                message_id.to_bytes(2, _BYTE_ORDER) +
                len(message).to_bytes(2, _BYTE_ORDER) + message)

        # Make an entry so we know this message is out there.
        # (No await since enqueueing, so the response can't beat us here.)
        assert message_id not in self._in_flight_messages
        msgobj = self._in_flight_messages[message_id] = _InFlightMessage()

        # Also add its task to our list so we properly cancel it if we die.
        self._prune_tasks()  # Keep our list from filling with dead tasks.
        self._tasks.append(weakref.ref(msgobj.wait_task))

        # Note: we always want to incorporate a timeout. Individual
        # messages may hang or error on the other end and this ensures
        # we won't build up lots of zombie tasks waiting around for
        # responses that will never arrive.
        if timeout is None:
            timeout = self.DEFAULT_MESSAGE_TIMEOUT
        assert timeout is not None
        try:
            return await asyncio.wait_for(msgobj.wait_task, timeout=timeout)
        except asyncio.CancelledError as exc:
            # Question: we assume this means the above wait_for() was
            # cancelled; what happens if a task running *us* is cancelled
            # though?
            if self._debug_print:
                self._debug_print_call(
                    f'{self._label}: message {message_id} was cancelled.')
            raise CommunicationError() from exc
        except asyncio.TimeoutError as exc:
            if self._debug_print:
                self._debug_print_call(
                    f'{self._label}: message {message_id} timed out.')

            # Stop waiting on the response.
            msgobj.wait_task.cancel()

            # Let the user know something went wrong.
            raise CommunicationError() from exc
        finally:
            # Always drop our record of this message once we stop waiting
            # on it (response, timeout, or cancel). Previously completed
            # entries were never removed, so they piled up (holding their
            # response data) until the 16-bit id wrapped, at which point
            # the 'not in' assertion above would fail. pop() with a default
            # since the response handler may already have removed it.
            self._in_flight_messages.pop(message_id, None)

    def close(self) -> None:
        """Begin shutting the endpoint down.

        Cancels all of our live tasks and closes the writer. Only the first
        call has any effect. Follow up with wait_closed() to finish.
        """
        self._check_env()

        if self._closing:
            return

        if self._debug_print:
            self._debug_print_call(f'{self._label}: closing...')

        self._closing = True

        # Kill all of our in-flight tasks.
        if self._debug_print:
            self._debug_print_call(f'{self._label}: cancelling tasks...')
        for task in self._get_live_tasks():
            task.cancel()

        # Close our writer.
        assert not self._did_close_writer
        if self._debug_print:
            self._debug_print_call(f'{self._label}: closing writer...')
        self._writer.close()
        self._did_close_writer = True

        # We don't need this anymore and it is likely to be creating a
        # dependency loop.
        del self._handle_raw_message_call

    def is_closing(self) -> bool:
        """Have we begun the process of closing?"""
        return self._closing

    async def wait_closed(self) -> None:
        """Wait for a close() to finish taking everything down.

        Waits for our cancelled tasks to wrap up, then for the writer to
        close, with a 30 second timeout. Only the first call does anything.
        Raises RuntimeError if close() has not been called.
        """
        # pylint: disable=too-many-branches
        self._check_env()

        # Make sure we only *enter* this call once.
        if self._did_wait_closed:
            return
        self._did_wait_closed = True

        if not self._closing:
            raise RuntimeError('Must be called after close()')

        if not self._did_close_writer:
            logging.warning('RPCEndpoint wait_closed() called but never'
                            ' explicitly closed writer.')

        live_tasks = self._get_live_tasks()
        if self._debug_print:
            self._debug_print_call(
                f'{self._label}: waiting for tasks to finish: '
                f' ({live_tasks=})...')

        # Wait for all of our in-flight tasks to wrap up.
        results = await asyncio.gather(*live_tasks, return_exceptions=True)
        for result in results:
            # We want to know if any errors happened aside from CancelledError
            # (which are BaseExceptions, not Exception).
            if isinstance(result, Exception):
                logging.warning('Got unexpected error cleaning up %s task: %s',
                                self._label, result)

        if not all(task.done() for task in live_tasks):
            logging.warning(
                'RPCEndpoint %d: not all live tasks marked done after gather.',
                id(self))

        if self._debug_print:
            self._debug_print_call(
                f'{self._label}: tasks finished; waiting for writer close...')

        # Now wait for our writer to finish going down.
        # When we close our writer it generally triggers errors
        # in our current blocked read/writes. However that same
        # error is also sometimes returned from _writer.wait_closed().
        # See connection_lost() in asyncio/streams.py to see why.
        # So let's silently ignore it when that happens.
        assert self._writer.is_closing()
        try:
            # It seems that as of Python 3.9.x it is possible for this to hang
            # indefinitely. See https://github.com/python/cpython/issues/83939
            # It sounds like this should be fixed in 3.11 but for now just
            # forcing the issue with a timeout here.
            await asyncio.wait_for(self._writer.wait_closed(), timeout=30.0)
        except asyncio.TimeoutError:
            logging.info(
                'Timeout on _writer.wait_closed() for %s rpc (transport=%s).',
                self._label,
                ssl_stream_writer_underlying_transport_info(self._writer))
            if self._debug_print:
                self._debug_print_call(
                    f'{self._label}: got timeout in _writer.wait_closed();'
                    ' This should be fixed in future Python versions.')
        except Exception as exc:
            if not self._is_expected_connection_error(exc):
                logging.exception('Error closing _writer for %s.', self._label)
            else:
                if self._debug_print:
                    self._debug_print_call(
                        f'{self._label}: silently ignoring error in'
                        f' _writer.wait_closed(): {exc}.')
        except asyncio.CancelledError:
            logging.warning('RPCEndpoint.wait_closed()'
                            ' got asyncio.CancelledError; not expected.')
            raise
        assert not self._did_wait_closed_writer
        self._did_wait_closed_writer = True

    def _tm(self) -> str:
        """Simple readable time value for debugging."""
        tval = time.time() % 100.0
        return f'{tval:.2f}'

    async def _run_read_task(self) -> None:
        """Read from the peer.

        Reads the peer's handshake first, then loops forever dispatching
        incoming packets by their type byte. Ends only by raising (for
        example a stream error when the connection drops).
        """
        self._check_env()
        assert self._peer_info is None

        # The first thing they should send us is their handshake; then
        # we'll know if/how we can talk to them.
        mlen = await self._read_int_32()
        message = (await self._reader.readexactly(mlen))
        self._peer_info = dataclass_from_json(_PeerInfo, message.decode())
        self._last_keepalive_receive_time = time.monotonic()
        if self._debug_print:
            self._debug_print_call(
                f'{self._label}: received handshake at {self._tm()}.')

        # Now just sit and handle stuff as it comes in.
        while True:
            assert not self._closing

            # Read message type.
            mtype = _PacketType(await self._read_int_8())
            if mtype is _PacketType.HANDSHAKE:
                raise RuntimeError('Got multiple handshakes')

            if mtype is _PacketType.KEEPALIVE:
                if self._debug_print_io:
                    self._debug_print_call(f'{self._label}: received keepalive'
                                           f' at {self._tm()}.')
                self._last_keepalive_receive_time = time.monotonic()

            elif mtype is _PacketType.MESSAGE:
                await self._handle_message_packet(big=False)

            elif mtype is _PacketType.MESSAGE_BIG:
                await self._handle_message_packet(big=True)

            elif mtype is _PacketType.RESPONSE:
                await self._handle_response_packet(big=False)

            elif mtype is _PacketType.RESPONSE_BIG:
                await self._handle_response_packet(big=True)

            else:
                assert_never(mtype)

    async def _handle_message_packet(self, big: bool) -> None:
        """Read an incoming message and spawn a task to answer it.

        Args:
            big: True for MESSAGE_BIG (32-bit length), else 16-bit length.
        """
        assert self._peer_info is not None
        msgid = await self._read_int_16()
        if big:
            msglen = await self._read_int_32()
        else:
            msglen = await self._read_int_16()
        msg = await self._reader.readexactly(msglen)
        if self._debug_print_io:
            self._debug_print_call(f'{self._label}: received message {msgid}'
                                   f' of size {msglen} at {self._tm()}.')

        # Create a message-task to handle this message and return
        # a response (we don't want to block while that happens).
        assert not self._closing
        self._prune_tasks()  # Keep from filling with dead tasks.
        self._tasks.append(
            weakref.ref(
                asyncio.create_task(
                    self._handle_raw_message(message_id=msgid, message=msg))))
        if self._debug_print:
            self._debug_print_call(
                f'{self._label}: done handling message at {self._tm()}.')

    async def _handle_response_packet(self, big: bool) -> None:
        """Read a response and deliver it to the matching in-flight message.

        Args:
            big: True for RESPONSE_BIG (32-bit length), else 16-bit length.
        """
        assert self._peer_info is not None
        msgid = await self._read_int_16()
        # Protocol 2 gained 32 bit data lengths.
        if big:
            rsplen = await self._read_int_32()
        else:
            rsplen = await self._read_int_16()
        if self._debug_print_io:
            self._debug_print_call(f'{self._label}: received response {msgid}'
                                   f' of size {rsplen} at {self._tm()}.')
        rsp = await self._reader.readexactly(rsplen)

        # Pop rather than just look up. Once answered, the record is no
        # longer needed, and a duplicate response for the same id is then
        # treated like a stale one instead of tripping set_response()'s
        # assertion.
        msgobj = self._in_flight_messages.pop(msgid, None)
        if msgobj is None:
            # It's possible for us to get a response to a message
            # that has timed out (or been cancelled). In this case we will
            # have no local record of it.
            if self._debug_print:
                self._debug_print_call(
                    f'{self._label}: got response for nonexistent'
                    f' message id {msgid}; perhaps it timed out?')
        else:
            msgobj.set_response(rsp)

    async def _run_write_task(self) -> None:
        """Write to the peer.

        Sends our handshake (32-bit length prefix plus JSON), then loops
        forever writing queued packets in order and draining the writer
        after each one.
        """

        self._check_env()

        # Introduce ourself so our peer knows how it can talk to us.
        data = dataclass_to_json(
            _PeerInfo(protocol=OUR_PROTOCOL,
                      keepalive_interval=self._keepalive_interval)).encode()
        self._writer.write(len(data).to_bytes(4, _BYTE_ORDER) + data)

        # Now just write out-messages as they come in.
        while True:

            # Wait until some data comes in.
            await self._have_out_packets.wait()

            assert self._out_packets
            data = self._out_packets.pop(0)

            # Important: only clear this once all packets are sent.
            if not self._out_packets:
                self._have_out_packets.clear()

            self._writer.write(data)

            # This should keep our writer from buffering huge amounts
            # of outgoing data. We must remember though that we also
            # need to prevent _out_packets from growing too large and
            # that part's on us.
            await self._writer.drain()

            # For now we're not applying backpressure, but let's make
            # noise if this gets out of hand.
            if len(self._out_packets) > 200:
                if not self._did_out_packets_buildup_warning:
                    logging.warning(
                        '_out_packets building up too'
                        ' much on RPCEndpoint %s.', id(self))
                    self._did_out_packets_buildup_warning = True

    async def _run_keepalive_task(self) -> None:
        """Send periodic keepalive packets.

        Raises _KeepaliveTimeoutError once the peer has been silent longer
        than our keepalive timeout (checked once per interval).
        """
        self._check_env()

        # We explicitly send our own keepalive packets so we can stay
        # more on top of the connection state and possibly decide to
        # kill it when contact is lost more quickly than the OS would
        # do itself (or at least keep the user informed that the
        # connection is lagging). It sounds like we could have the TCP
        # layer do this sort of thing itself but that might be
        # OS-specific so gonna go this way for now.
        while True:
            assert not self._closing
            await asyncio.sleep(self._keepalive_interval)
            if not self.test_suppress_keepalives:
                self._enqueue_outgoing_packet(
                    _PacketType.KEEPALIVE.value.to_bytes(1, _BYTE_ORDER))

            # Also go ahead and handle dropping the connection if we
            # haven't heard from the peer in a while.
            # NOTE: perhaps we want to do something more exact than
            # this which only checks once per keepalive-interval?..
            now = time.monotonic()
            if (self._last_keepalive_receive_time is not None
                    and now - self._last_keepalive_receive_time >
                    self._keepalive_timeout):
                if self._debug_print:
                    since = now - self._last_keepalive_receive_time
                    self._debug_print_call(
                        f'{self._label}: reached keepalive time-out'
                        f' ({since:.1f}s).')
                raise _KeepaliveTimeoutError()

    async def _run_core_task(self, tasklabel: str, call: Awaitable) -> None:
        """Await a core task, log unexpected errors, then close().

        Args:
            tasklabel: Short name for the task, used in log output.
            call: The awaitable to run.
        """
        try:
            await call
        except Exception as exc:
            # We expect connection errors to put us here, but make noise
            # if something else does.
            if not self._is_expected_connection_error(exc):
                logging.exception('Unexpected error in rpc %s %s task.',
                                  self._label, tasklabel)
            else:
                if self._debug_print:
                    self._debug_print_call(
                        f'{self._label}: {tasklabel} task will exit cleanly'
                        f' due to {exc!r}.')
        finally:
            # Any core task exiting triggers shutdown.
            if self._debug_print:
                self._debug_print_call(
                    f'{self._label}: {tasklabel} task exiting...')
            self.close()

    async def _handle_raw_message(self, message_id: int,
                                  message: bytes) -> None:
        """Run the user handler on one incoming message and queue the reply.

        Args:
            message_id: The peer's id for the message, echoed back in the
                response packet.
            message: The raw message payload.
        """
        try:
            response = await self._handle_raw_message_call(message)
        except Exception:
            # We expect local message handler to always succeed.
            # If that doesn't happen, make a fuss so we know to fix it.
            # The other end will simply never get a response to this
            # message.
            logging.exception('Error handling raw rpc message')
            return

        assert self._peer_info is not None

        if self._peer_info.protocol == 1:
            if len(response) > 65535:
                # Protocol 1 peers can't receive big responses. Log it, as
                # with handler errors above, rather than raising from this
                # task that nobody awaits. The other end will simply never
                # get a response to this message.
                logging.error(
                    'Response cannot be larger than 65535 bytes'
                    ' (%s rpc, message %d).', self._label, message_id)
                return

        # Now send back our response.
        if len(response) > 65535:
            # Payload consists of type (1b), msgid (2b), len (4b), and data.
            self._enqueue_outgoing_packet(
                _PacketType.RESPONSE_BIG.value.to_bytes(1, _BYTE_ORDER) +
                message_id.to_bytes(2, _BYTE_ORDER) +
                len(response).to_bytes(4, _BYTE_ORDER) + response)
        else:
            # Payload consists of type (1b), msgid (2b), len (2b), and data.
            self._enqueue_outgoing_packet(
                _PacketType.RESPONSE.value.to_bytes(1, _BYTE_ORDER) +
                message_id.to_bytes(2, _BYTE_ORDER) +
                len(response).to_bytes(2, _BYTE_ORDER) + response)

    async def _read_int_8(self) -> int:
        """Read a 1-byte unsigned int from the peer."""
        return int.from_bytes(await self._reader.readexactly(1), _BYTE_ORDER)

    async def _read_int_16(self) -> int:
        """Read a 2-byte unsigned int from the peer."""
        return int.from_bytes(await self._reader.readexactly(2), _BYTE_ORDER)

    async def _read_int_32(self) -> int:
        """Read a 4-byte unsigned int from the peer."""
        return int.from_bytes(await self._reader.readexactly(4), _BYTE_ORDER)

    @classmethod
    def _is_expected_connection_error(cls, exc: Exception) -> bool:
        """Stuff we expect to end our connection in normal circumstances."""

        if isinstance(exc, _KeepaliveTimeoutError):
            return True

        return is_asyncio_streams_communication_error(exc)

    def _check_env(self) -> None:
        """Raise RuntimeError if called from a thread other than ours."""
        # I was seeing that asyncio stuff wasn't working as expected if
        # created in one thread and used in another, so let's enforce
        # a single thread for all use of an instance.
        if current_thread() is not self._thread:
            raise RuntimeError('This must be called from the same thread'
                               ' that the endpoint was created in.')

        # This should always be the case if thread is the same.
        assert asyncio.get_running_loop() is self._event_loop

    def _enqueue_outgoing_packet(self, data: bytes) -> None:
        """Enqueue a raw packet to be sent. Must be called from our loop."""
        self._check_env()

        if self._debug_print_io:
            self._debug_print_call(f'{self._label}: enqueueing outgoing packet'
                                   f' {data[:50]!r} at {self._tm()}.')

        # Add the data and let our write task know about it.
        self._out_packets.append(data)
        self._have_out_packets.set()

    def _prune_tasks(self) -> None:
        """Drop refs to tasks that are finished or garbage collected."""
        out: list[weakref.ref[asyncio.Task]] = []
        for task_weak_ref in self._tasks:
            task = task_weak_ref()
            if task is not None and not task.done():
                out.append(task_weak_ref)
        self._tasks = out

    def _get_live_tasks(self) -> list[asyncio.Task]:
        """Return our tracked tasks that still exist and are not done."""
        out: list[asyncio.Task] = []
        for task_weak_ref in self._tasks:
            task = task_weak_ref()
            if task is not None and not task.done():
                out.append(task)
        return out
|