Refactoring to asynchronous. (partially completed)

This commit is contained in:
BruceChen 2022-12-20 22:41:14 +08:00
parent 7ee08092d4
commit 096ea0c70c
72 changed files with 6033 additions and 5080 deletions

View file

@ -0,0 +1,203 @@
using System;
using System.IO;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
using MinecraftClient.Crypto;
using MinecraftClient.Crypto.AesHandler;
using static ConsoleInteractive.ConsoleReader;
namespace MinecraftClient.Protocol.PacketPipeline
{
public class AesStream : Stream
{
public const int BlockSize = 16;
private const int BufferSize = 1024;
public Socket Client;
private bool inStreamEnded = false;
private readonly IAesHandler Aes;
private int InputBufPos = 0, OutputBufPos = 0;
private readonly Memory<byte> InputBuf, OutputBuf;
private readonly Memory<byte> AesBufRead, AesBufSend;
public override bool CanRead => true;
public override bool CanSeek => false;
public override bool CanWrite => false;
public override long Length => throw new NotSupportedException();
public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); }
public AesStream(Socket socket, byte[] key)
{
Client = socket;
InputBuf = new byte[BufferSize + BlockSize];
OutputBuf = new byte[BufferSize + BlockSize];
AesBufRead = new byte[BlockSize];
AesBufSend = new byte[BlockSize];
if (FasterAesX86.IsSupported())
Aes = new FasterAesX86(key);
else if (FasterAesArm.IsSupported())
Aes = new FasterAesArm(key);
else
Aes = new BasicAes(key);
key.CopyTo(InputBuf.Slice(0, BlockSize));
key.CopyTo(OutputBuf.Slice(0, BlockSize));
}
public override void Flush()
{
throw new NotSupportedException();
}
public override int Read(byte[] buffer, int offset, int count)
{
var task = ReadAsync(buffer.AsMemory(offset, count)).AsTask();
task.Wait();
return task.Result;
}
public override int ReadByte()
{
if (inStreamEnded)
return -1;
var task = Client.ReceiveAsync(InputBuf.Slice(InputBufPos + BlockSize, 1)).AsTask();
task.Wait();
if (task.Result == 0)
{
inStreamEnded = true;
return -1;
}
Aes.EncryptEcb(InputBuf.Slice(InputBufPos, BlockSize).Span, AesBufRead.Span);
byte result = (byte)(AesBufRead.Span[0] ^ InputBuf.Span[InputBufPos + BlockSize]);
InputBufPos++;
if (InputBufPos == BufferSize)
{
InputBuf.Slice(BufferSize, BlockSize).CopyTo(InputBuf[..BlockSize]);
InputBufPos = 0;
}
return result;
}
public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
if (inStreamEnded)
return 0;
int readLimit = Math.Min(buffer.Length, BufferSize - InputBufPos);
int curRead = await Client.ReceiveAsync(InputBuf.Slice(InputBufPos + BlockSize, readLimit), cancellationToken);
if (curRead == 0 || cancellationToken.IsCancellationRequested)
{
if (curRead == 0)
inStreamEnded = true;
return curRead;
}
for (int idx = 0; idx < curRead; idx++)
{
Aes.EncryptEcb(InputBuf.Slice(InputBufPos + idx, BlockSize).Span, AesBufRead.Span);
buffer.Span[idx] = (byte)(AesBufRead.Span[0] ^ InputBuf.Span[InputBufPos + BlockSize + idx]);
}
InputBufPos += curRead;
if (InputBufPos == BufferSize)
{
InputBuf.Slice(BufferSize, BlockSize).CopyTo(InputBuf[..BlockSize]);
InputBufPos = 0;
}
return curRead;
}
public new async ValueTask ReadExactlyAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
if (inStreamEnded)
return;
for (int readed = 0, curRead; readed < buffer.Length; readed += curRead)
{
int readLimit = Math.Min(buffer.Length - readed, BufferSize - InputBufPos);
curRead = await Client.ReceiveAsync(InputBuf.Slice(InputBufPos + BlockSize, readLimit), cancellationToken);
if (curRead == 0 || cancellationToken.IsCancellationRequested)
{
if (curRead == 0)
inStreamEnded = true;
return;
}
for (int idx = 0; idx < curRead; idx++)
{
Aes.EncryptEcb(InputBuf.Slice(InputBufPos + idx, BlockSize).Span, AesBufRead.Span);
buffer.Span[readed + idx] = (byte)(AesBufRead.Span[0] ^ InputBuf.Span[InputBufPos + BlockSize + idx]);
}
InputBufPos += curRead;
if (InputBufPos == BufferSize)
{
InputBuf.Slice(BufferSize, BlockSize).CopyTo(InputBuf.Slice(0, BlockSize));
InputBufPos = 0;
}
}
}
public async ValueTask<int> ReadRawAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
return await Client.ReceiveAsync(buffer, cancellationToken);
}
public override long Seek(long offset, SeekOrigin origin)
{
throw new NotSupportedException();
}
public override void SetLength(long value)
{
throw new NotSupportedException();
}
public override void Write(byte[] buffer, int offset, int count)
{
WriteAsync(buffer.AsMemory(offset, count)).AsTask().Wait();
}
public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
int outputStartPos = OutputBufPos;
for (int wirtten = 0; wirtten < buffer.Length; ++wirtten)
{
if (cancellationToken.IsCancellationRequested)
return;
Aes.EncryptEcb(OutputBuf.Slice(OutputBufPos, BlockSize).Span, AesBufSend.Span);
OutputBuf.Span[OutputBufPos + BlockSize] = (byte)(AesBufSend.Span[0] ^ buffer.Span[wirtten]);
if (++OutputBufPos == BufferSize)
{
await Client.SendAsync(OutputBuf.Slice(outputStartPos + BlockSize, BufferSize - outputStartPos), cancellationToken);
OutputBuf.Slice(BufferSize, BlockSize).CopyTo(OutputBuf.Slice(0, BlockSize));
OutputBufPos = outputStartPos = 0;
}
}
if (OutputBufPos > outputStartPos)
await Client.SendAsync(OutputBuf.Slice(outputStartPos + BlockSize, OutputBufPos - outputStartPos), cancellationToken);
return;
}
}
}

View file

@ -0,0 +1,189 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Compression;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using static ConsoleInteractive.ConsoleReader;
namespace MinecraftClient.Protocol.PacketPipeline
{
internal class PacketStream : Stream
{
public override bool CanRead => true;
public override bool CanSeek => false;
public override bool CanWrite => false;
public override long Length => throw new NotSupportedException();
public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); }
readonly CancellationToken CancelToken;
private readonly Stream baseStream;
private readonly AesStream? aesStream;
private ZLibStream? zlibStream;
private int packetSize, packetReaded;
internal const int DropBufSize = 1024;
internal static readonly Memory<byte> DropBuf = new byte[DropBufSize];
private static readonly byte[] SingleByteBuf = new byte[1];
public PacketStream(ZLibStream zlibStream, int packetSize, CancellationToken cancellationToken = default)
{
CancelToken = cancellationToken;
this.aesStream = null;
this.zlibStream = zlibStream;
this.baseStream = zlibStream;
this.packetReaded = 0;
this.packetSize = packetSize;
}
public PacketStream(AesStream aesStream, int packetSize, CancellationToken cancellationToken = default)
{
CancelToken = cancellationToken;
this.aesStream = aesStream;
this.zlibStream = null;
this.baseStream = aesStream;
this.packetReaded = 0;
this.packetSize = packetSize;
}
public PacketStream(Stream baseStream, int packetSize, CancellationToken cancellationToken = default)
{
CancelToken = cancellationToken;
this.aesStream = null;
this.zlibStream = null;
this.baseStream = baseStream;
this.packetReaded = 0;
this.packetSize = packetSize;
}
public override void Flush()
{
throw new NotSupportedException();
}
public new byte ReadByte()
{
++packetReaded;
if (packetReaded > packetSize)
throw new OverflowException("Reach the end of the packet!");
baseStream.Read(SingleByteBuf, 0, 1);
return SingleByteBuf[0];
}
public async Task<byte> ReadByteAsync()
{
++packetReaded;
if (packetReaded > packetSize)
throw new OverflowException("Reach the end of the packet!");
await baseStream.ReadExactlyAsync(SingleByteBuf, CancelToken);
return SingleByteBuf[0];
}
public override int Read(byte[] buffer, int offset, int count)
{
if (packetReaded + buffer.Length > packetSize)
throw new OverflowException("Reach the end of the packet!");
int readed = baseStream.Read(buffer, offset, count);
packetReaded += readed;
return readed;
}
public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
if (packetReaded + buffer.Length > packetSize)
throw new OverflowException("Reach the end of the packet!");
int readed = await baseStream.ReadAsync(buffer, CancelToken);
packetReaded += readed;
return readed;
}
public new async ValueTask ReadExactlyAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
if (packetReaded + buffer.Length > packetSize)
throw new OverflowException("Reach the end of the packet!");
await baseStream.ReadExactlyAsync(buffer, CancelToken);
packetReaded += buffer.Length;
}
public async Task<byte[]> ReadFullPacket()
{
byte[] buffer = new byte[packetSize - packetReaded];
await ReadExactlyAsync(buffer);
packetReaded = packetSize;
return buffer;
}
public override long Seek(long offset, SeekOrigin origin)
{
throw new NotSupportedException();
}
public override void SetLength(long value)
{
throw new NotSupportedException();
}
public override void Write(byte[] buffer, int offset, int count)
{
throw new NotSupportedException();
}
public async Task Skip(int length)
{
if (zlibStream != null)
{
for (int readed = 0, curRead; readed < length; readed += curRead)
curRead = await zlibStream.ReadAsync(DropBuf[..Math.Min(DropBufSize, length - readed)]);
}
else if (aesStream != null)
{
int skipRaw = length - AesStream.BlockSize;
for (int readed = 0, curRead; readed < skipRaw; readed += curRead)
curRead = await aesStream.ReadRawAsync(DropBuf[..Math.Min(DropBufSize, skipRaw - readed)]);
await aesStream.ReadAsync(DropBuf[..Math.Min(length, AesStream.BlockSize)]);
}
else
{
for (int readed = 0, curRead; readed < length; readed += curRead)
curRead = await baseStream.ReadAsync(DropBuf[..Math.Min(DropBufSize, length - readed)]);
}
packetReaded += length;
}
public override async ValueTask DisposeAsync()
{
if (CancelToken.IsCancellationRequested)
return;
if (zlibStream != null)
{
await zlibStream.DisposeAsync();
zlibStream = null;
packetReaded = packetSize;
}
else
{
if (packetSize - packetReaded > 0)
{
// ConsoleIO.WriteLine("Plain readed " + packetReaded + ", last " + (packetSize - packetReaded));
await Skip(packetSize - packetReaded);
}
}
}
}
}

View file

@ -0,0 +1,157 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Compression;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
using MinecraftClient.Crypto;
namespace MinecraftClient.Protocol.PacketPipeline
{
/// <summary>
/// Wrapper for handling unencrypted & encrypted socket
/// </summary>
class SocketWrapper
{
private TcpClient tcpClient;
private AesStream? AesStream;
private PacketStream? packetStream = null;
private Stream ReadStream, WriteStream;
private bool Encrypted = false;
public int CompressionThreshold { get; set; } = 0;
private SemaphoreSlim SendSemaphore = new SemaphoreSlim(1, 1);
private Task LastSendTask = Task.CompletedTask;
/// <summary>
/// Initialize a new SocketWrapper
/// </summary>
/// <param name="client">TcpClient connected to the server</param>
public SocketWrapper(TcpClient client)
{
tcpClient = client;
ReadStream = WriteStream = client.GetStream();
}
/// <summary>
/// Check if the socket is still connected
/// </summary>
/// <returns>TRUE if still connected</returns>
/// <remarks>Silently dropped connection can only be detected by attempting to read/write data</remarks>
public bool IsConnected()
{
return tcpClient.Client != null && tcpClient.Connected;
}
/// <summary>
/// Check if the socket has data available to read
/// </summary>
/// <returns>TRUE if data is available to read</returns>
public bool HasDataAvailable()
{
return tcpClient.Client.Available > 0;
}
/// <summary>
/// Switch network reading/writing to an encrypted stream
/// </summary>
/// <param name="secretKey">AES secret key</param>
public void SwitchToEncrypted(byte[] secretKey)
{
if (Encrypted)
throw new InvalidOperationException("Stream is already encrypted!?");
Encrypted = true;
ReadStream = WriteStream = AesStream = new AesStream(tcpClient.Client, secretKey);
}
/// <summary>
/// Send raw data to the server.
/// </summary>
/// <param name="buffer">data to send</param>
public async Task SendAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
await SendSemaphore.WaitAsync();
await LastSendTask;
LastSendTask = WriteStream.WriteAsync(buffer, cancellationToken).AsTask();
SendSemaphore.Release();
}
public async Task<Tuple<int, PacketStream>> GetNextPacket(bool handleCompress, CancellationToken cancellationToken = default)
{
// ConsoleIO.WriteLine("GetNextPacket");
if (packetStream != null)
{
await packetStream.DisposeAsync();
packetStream = null;
}
int readed = 0;
(int packetSize, _) = await ReceiveVarIntRaw(ReadStream, cancellationToken);
int packetID;
if (handleCompress && CompressionThreshold > 0)
{
(int sizeUncompressed, readed) = await ReceiveVarIntRaw(ReadStream, cancellationToken);
if (sizeUncompressed != 0)
{
ZlibBaseStream zlibBaseStream = new(AesStream ?? ReadStream, packetSize: packetSize - readed);
ZLibStream zlibStream = new(zlibBaseStream, CompressionMode.Decompress, leaveOpen: false);
zlibBaseStream.BufferSize = 16;
(packetID, readed) = await ReceiveVarIntRaw(zlibStream, cancellationToken);
zlibBaseStream.BufferSize = 512;
// ConsoleIO.WriteLine("packetID = " + packetID + ", readed = " + zlibBaseStream.packetReaded + ", size = " + packetSize + " -> " + sizeUncompressed);
packetStream = new(zlibStream, sizeUncompressed - readed, cancellationToken);
return new(packetID, packetStream);
}
}
(packetID, int readed2) = await ReceiveVarIntRaw(ReadStream, cancellationToken);
packetStream = new(AesStream ?? ReadStream, packetSize - readed - readed2, cancellationToken);
return new(packetID, packetStream);
}
private async Task<Tuple<int, int>> ReceiveVarIntRaw(Stream stream, CancellationToken cancellationToken = default)
{
int i = 0;
int j = 0;
byte[] b = new byte[1];
while (true)
{
await stream.ReadAsync(b);
i |= (b[0] & 0x7F) << j++ * 7;
if (j > 5) throw new OverflowException("VarInt too big");
if ((b[0] & 0x80) != 128) break;
}
return new(i, j);
}
/// <summary>
/// Disconnect from the server
/// </summary>
public void Disconnect()
{
try
{
tcpClient.Close();
}
catch (SocketException) { }
catch (IOException) { }
catch (NullReferenceException) { }
catch (ObjectDisposedException) { }
}
}
}

View file

@ -0,0 +1,119 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using MinecraftClient.Crypto;
using static ConsoleInteractive.ConsoleReader;
namespace MinecraftClient.Protocol.PacketPipeline
{
internal class ZlibBaseStream : Stream
{
public override bool CanRead => true;
public override bool CanSeek => false;
public override bool CanWrite => false;
public override long Length => throw new NotSupportedException();
public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); }
public int BufferSize { get; set; } = 16;
public int packetSize = 0, packetReaded = 0;
private Stream baseStream;
private AesStream? aesStream;
public ZlibBaseStream(Stream baseStream, int packetSize)
{
packetReaded = 0;
this.packetSize = packetSize;
this.baseStream = baseStream;
aesStream = null;
}
public ZlibBaseStream(AesStream aesStream, int packetSize)
{
packetReaded = 0;
this.packetSize = packetSize;
baseStream = this.aesStream = aesStream;
}
public override void Flush()
{
throw new NotSupportedException();
}
public override int Read(byte[] buffer, int offset, int count)
{
if (packetReaded == packetSize)
return 0;
int readed = baseStream.Read(buffer, offset, Math.Min(BufferSize, Math.Min(count, packetSize - packetReaded)));
packetReaded += readed;
return readed;
}
public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
int readLen = Math.Min(BufferSize, Math.Min(buffer.Length, packetSize - packetReaded));
if (packetReaded + readLen > packetSize)
throw new OverflowException("Reach the end of the packet!");
await baseStream.ReadExactlyAsync(buffer[..readLen], cancellationToken);
packetReaded += readLen;
return readLen;
}
public new async ValueTask ReadExactlyAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
if (packetReaded + buffer.Length > packetSize)
throw new OverflowException("Reach the end of the packet!");
await baseStream.ReadExactlyAsync(buffer, cancellationToken);
packetReaded += buffer.Length;
}
public override long Seek(long offset, SeekOrigin origin)
{
throw new NotSupportedException();
}
public override void SetLength(long value)
{
throw new NotSupportedException();
}
public override void Write(byte[] buffer, int offset, int count)
{
throw new NotSupportedException();
}
public async Task Skip(int length)
{
if (aesStream != null)
{
int skipRaw = length - AesStream.BlockSize;
for (int readed = 0, curRead; readed < skipRaw; readed += curRead)
curRead = await aesStream.ReadRawAsync(PacketStream.DropBuf[..Math.Min(PacketStream.DropBufSize, skipRaw - readed)]);
await aesStream.ReadAsync(PacketStream.DropBuf[..Math.Min(length, AesStream.BlockSize)]);
}
else
{
for (int readed = 0, curRead; readed < length; readed += curRead)
curRead = await baseStream.ReadAsync(PacketStream.DropBuf[..Math.Min(PacketStream.DropBufSize, length - readed)]);
}
packetReaded += length;
}
public override async ValueTask DisposeAsync()
{
if (packetSize - packetReaded > 0)
{
// ConsoleIO.WriteLine("Zlib readed " + packetReaded + ", last " + (packetSize - packetReaded));
await Skip(packetSize - packetReaded);
}
}
}
}