sync changes with master

This commit is contained in:
Ayush Saini 2022-06-30 00:31:52 +05:30
parent 03034e4aa0
commit de0199ad50
178 changed files with 2191 additions and 1481 deletions

View file

@ -19,7 +19,7 @@ from efro.dataclassio import (dataclass_to_json, dataclass_from_json,
ioprepped, IOAttrs)
if TYPE_CHECKING:
from typing import Literal, Awaitable, Callable, Optional
from typing import Literal, Awaitable, Callable
# Terminology:
# Packet: A chunk of data consisting of a type and some type-dependent
@ -33,6 +33,8 @@ class _PacketType(Enum):
KEEPALIVE = 1
MESSAGE = 2
RESPONSE = 3
MESSAGE_BIG = 4
RESPONSE_BIG = 5
_BYTE_ORDER: Literal['big'] = 'big'
@ -49,14 +51,20 @@ class _PeerInfo:
keepalive_interval: Annotated[float, IOAttrs('k')]
OUR_PROTOCOL = 1
# Note: we are expected to be forward and backward compatible; we can
# increment protocol freely and expect everyone else to still talk to us.
# Likewise we should retain logic to communicate with older protocols.
# Protocol history:
# 1 - initial release
# 2 - gained big (32-bit len val) package/response packets
OUR_PROTOCOL = 2
class _InFlightMessage:
"""Represents a message that is out on the wire."""
def __init__(self) -> None:
self._response: Optional[bytes] = None
self._response: bytes | None = None
self._got_response = asyncio.Event()
self.wait_task = asyncio.create_task(self._wait())
@ -126,7 +134,7 @@ class RPCEndpoint:
self._out_packets: list[bytes] = []
self._have_out_packets = asyncio.Event()
self._run_called = False
self._peer_info: Optional[_PeerInfo] = None
self._peer_info: _PeerInfo | None = None
self._keepalive_interval = keepalive_interval
self._keepalive_timeout = keepalive_timeout
@ -135,7 +143,7 @@ class RPCEndpoint:
self._tasks: list[weakref.ref[asyncio.Task]] = []
# When we last got a keepalive or equivalent (time.monotonic value)
self._last_keepalive_receive_time: Optional[float] = None
self._last_keepalive_receive_time: float | None = None
# (Start near the end to make sure our looping logic is sound).
self._next_message_id = 65530
@ -193,7 +201,7 @@ class RPCEndpoint:
async def send_message(self,
message: bytes,
timeout: Optional[float] = None) -> bytes:
timeout: float | None = None) -> bytes:
"""Send a message to the peer and return a response.
If timeout is not provided, the default will be used.
@ -201,21 +209,38 @@ class RPCEndpoint:
for any reason.
"""
self._check_env()
if len(message) > 65535:
raise RuntimeError('Message cannot be larger than 65535 bytes')
if self._closing:
raise CommunicationError('Endpoint is closed')
# Go with 16 bit looping value for message_id.
# We need to know their protocol, so if we haven't gotten a handshake
# from them yet, just wait.
while self._peer_info is None:
await asyncio.sleep(0.01)
assert self._peer_info is not None
if self._peer_info.protocol == 1:
if len(message) > 65535:
raise RuntimeError('Message cannot be larger than 65535 bytes')
# message_id is a 16 bit looping value.
message_id = self._next_message_id
self._next_message_id = (self._next_message_id + 1) % 65536
# Payload consists of type (1b), message_id (2b), len (2b), and data.
self._enqueue_outgoing_packet(
_PacketType.MESSAGE.value.to_bytes(1, _BYTE_ORDER) +
message_id.to_bytes(2, _BYTE_ORDER) +
len(message).to_bytes(2, _BYTE_ORDER) + message)
if len(message) > 65535:
# Payload consists of type (1b), message_id (2b),
# len (4b), and data.
self._enqueue_outgoing_packet(
_PacketType.MESSAGE_BIG.value.to_bytes(1, _BYTE_ORDER) +
message_id.to_bytes(2, _BYTE_ORDER) +
len(message).to_bytes(4, _BYTE_ORDER) + message)
else:
# Payload consists of type (1b), message_id (2b),
# len (2b), and data.
self._enqueue_outgoing_packet(
_PacketType.MESSAGE.value.to_bytes(1, _BYTE_ORDER) +
message_id.to_bytes(2, _BYTE_ORDER) +
len(message).to_bytes(2, _BYTE_ORDER) + message)
# Make an entry so we know this message is out there.
assert message_id not in self._in_flight_messages
@ -381,17 +406,27 @@ class RPCEndpoint:
self._last_keepalive_receive_time = time.monotonic()
elif mtype is _PacketType.MESSAGE:
await self._handle_message_packet()
await self._handle_message_packet(big=False)
elif mtype is _PacketType.MESSAGE_BIG:
await self._handle_message_packet(big=True)
elif mtype is _PacketType.RESPONSE:
await self._handle_response_packet()
await self._handle_response_packet(big=False)
elif mtype is _PacketType.RESPONSE_BIG:
await self._handle_response_packet(big=True)
else:
assert_never(mtype)
async def _handle_message_packet(self) -> None:
async def _handle_message_packet(self, big: bool) -> None:
assert self._peer_info is not None
msgid = await self._read_int_16()
msglen = await self._read_int_16()
if big:
msglen = await self._read_int_32()
else:
msglen = await self._read_int_16()
msg = await self._reader.readexactly(msglen)
if self._debug_print_io:
self._debug_print_call(f'{self._label}: received message {msgid}'
@ -408,9 +443,14 @@ class RPCEndpoint:
self._debug_print_call(
f'{self._label}: done handling message at {self._tm()}.')
async def _handle_response_packet(self) -> None:
async def _handle_response_packet(self, big: bool) -> None:
assert self._peer_info is not None
msgid = await self._read_int_16()
rsplen = await self._read_int_16()
# Protocol 2 gained 32 bit data lengths.
if big:
rsplen = await self._read_int_32()
else:
rsplen = await self._read_int_16()
if self._debug_print_io:
self._debug_print_call(f'{self._label}: received response {msgid}'
f' of size {rsplen} at {self._tm()}.')
@ -520,12 +560,25 @@ class RPCEndpoint:
logging.exception('Error handling message')
return
assert self._peer_info is not None
if self._peer_info.protocol == 1:
if len(response) > 65535:
raise RuntimeError(
'Response cannot be larger than 65535 bytes')
# Now send back our response.
# Payload consists of type (1b), msgid (2b), len (2b), and data.
self._enqueue_outgoing_packet(
_PacketType.RESPONSE.value.to_bytes(1, _BYTE_ORDER) +
message_id.to_bytes(2, _BYTE_ORDER) +
len(response).to_bytes(2, _BYTE_ORDER) + response)
if len(response) > 65535:
self._enqueue_outgoing_packet(
_PacketType.RESPONSE_BIG.value.to_bytes(1, _BYTE_ORDER) +
message_id.to_bytes(2, _BYTE_ORDER) +
len(response).to_bytes(4, _BYTE_ORDER) + response)
else:
self._enqueue_outgoing_packet(
_PacketType.RESPONSE.value.to_bytes(1, _BYTE_ORDER) +
message_id.to_bytes(2, _BYTE_ORDER) +
len(response).to_bytes(2, _BYTE_ORDER) + response)
async def _read_int_8(self) -> int:
return int.from_bytes(await self._reader.readexactly(1), _BYTE_ORDER)