From 842c968220757eadb686cf02d7e2c961dadf9132 Mon Sep 17 00:00:00 2001 From: BruceChen Date: Sun, 28 Aug 2022 10:50:52 +0800 Subject: [PATCH] Use AES-NI instruction set if possible --- MinecraftClient/Crypto/AesContext.cs | 97 +++++++++++++++++++ .../Crypto/Streams/AesCfb8Stream.cs | 23 ++++- 2 files changed, 115 insertions(+), 5 deletions(-) create mode 100644 MinecraftClient/Crypto/AesContext.cs diff --git a/MinecraftClient/Crypto/AesContext.cs b/MinecraftClient/Crypto/AesContext.cs new file mode 100644 index 00000000..bbaab330 --- /dev/null +++ b/MinecraftClient/Crypto/AesContext.cs @@ -0,0 +1,97 @@ +using System; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.X86; + +namespace MinecraftClient.Crypto +{ + // Using the AES-NI instruction set + // https://gist.github.com/Thealexbarney/9f75883786a9f3100408ff795fb95d85 + public class AesContext + { + private Vector128[] RoundKeys { get; } + + public byte[] Iv { get; } = new byte[0x10]; + + public AesContext(Span key) + { + RoundKeys = KeyExpansion(key); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)] + public void EncryptEcb(ReadOnlySpan plaintext, Span destination) + { + Vector128[] keys = RoundKeys; + + ReadOnlySpan> blocks = MemoryMarshal.Cast>(plaintext); + Span> dest = MemoryMarshal.Cast>(destination); + + // Makes the JIT remove all the other range checks on keys + _ = keys[10]; + + for (int i = 0; i < blocks.Length; i++) + { + Vector128 b = blocks[i]; + + b = Sse2.Xor(b, keys[0]); + b = Aes.Encrypt(b, keys[1]); + b = Aes.Encrypt(b, keys[2]); + b = Aes.Encrypt(b, keys[3]); + b = Aes.Encrypt(b, keys[4]); + b = Aes.Encrypt(b, keys[5]); + b = Aes.Encrypt(b, keys[6]); + b = Aes.Encrypt(b, keys[7]); + b = Aes.Encrypt(b, keys[8]); + b = Aes.Encrypt(b, keys[9]); + b = Aes.EncryptLast(b, keys[10]); + + dest[i] = b; + } + } + + private static Vector128[] KeyExpansion(Span key) + { + var keys = new Vector128[20]; + + keys[0] = Unsafe.ReadUnaligned>(ref key[0]); + + MakeRoundKey(keys, 1, 0x01); + MakeRoundKey(keys, 2, 0x02); + MakeRoundKey(keys, 3, 0x04); + MakeRoundKey(keys, 4, 0x08); + MakeRoundKey(keys, 5, 0x10); + MakeRoundKey(keys, 6, 0x20); + MakeRoundKey(keys, 7, 0x40); + MakeRoundKey(keys, 8, 0x80); + MakeRoundKey(keys, 9, 0x1b); + MakeRoundKey(keys, 10, 0x36); + + for (int i = 1; i < 10; i++) + { + keys[10 + i] = Aes.InverseMixColumns(keys[i]); + } + + return keys; + } + + private static void MakeRoundKey(Vector128[] keys, int i, byte rcon) + { + Vector128 s = keys[i - 1]; + Vector128 t = keys[i - 1]; + + t = Aes.KeygenAssist(t, rcon); + t = Sse2.Shuffle(t.AsUInt32(), 0xFF).AsByte(); + + s = Sse2.Xor(s, Sse2.ShiftLeftLogical128BitLane(s, 4)); + s = Sse2.Xor(s, Sse2.ShiftLeftLogical128BitLane(s, 8)); + + keys[i] = Sse2.Xor(s, t); + } + + public void SetIv(Span iv) + { + iv.Slice(0, 0x10).CopyTo(Iv); + } + } +} \ No newline at end of file diff --git a/MinecraftClient/Crypto/Streams/AesCfb8Stream.cs b/MinecraftClient/Crypto/Streams/AesCfb8Stream.cs index 038a1c12..1b1a80a7 100644 --- a/MinecraftClient/Crypto/Streams/AesCfb8Stream.cs +++ b/MinecraftClient/Crypto/Streams/AesCfb8Stream.cs @@ -6,6 +6,7 @@ using System.Threading.Tasks; using System.Security.Cryptography; using System.IO; using System.Collections.Concurrent; +using System.Runtime.CompilerServices; namespace MinecraftClient.Crypto.Streams { @@ -13,7 +14,9 @@ namespace MinecraftClient.Crypto.Streams { public static readonly int blockSize = 16; - private Aes aes; + private readonly Aes? Aes = null; + + private readonly AesContext? FastAes = null; public System.IO.Stream BaseStream { get; set; } @@ -26,7 +29,10 @@ namespace MinecraftClient.Crypto.Streams { BaseStream = stream; - aes = GenerateAES(key); + if (System.Runtime.Intrinsics.X86.Sse2.IsSupported && System.Runtime.Intrinsics.X86.Aes.IsSupported) + FastAes = new AesContext(key); + else + Aes = GenerateAES(key); Array.Copy(key, ReadStreamIV, 16); Array.Copy(key, WriteStreamIV, 16); @@ -76,6 +82,7 @@ namespace MinecraftClient.Crypto.Streams return temp[0]; } + [MethodImpl(MethodImplOptions.AggressiveOptimization)] public override int Read(byte[] buffer, int outOffset, int required) { if (this.inStreamEnded) @@ -93,7 +100,7 @@ namespace MinecraftClient.Crypto.Streams return readed; } - OrderablePartitioner> rangePartitioner = (curRead <= 256) ? + OrderablePartitioner> rangePartitioner = (curRead <= 256) ? Partitioner.Create(readed, readed + curRead, 32) : Partitioner.Create(readed, readed + curRead); Parallel.ForEach(rangePartitioner, (range, loopState) => { @@ -101,7 +108,10 @@ namespace MinecraftClient.Crypto.Streams for (int idx = range.Item1; idx < range.Item2; idx++) { ReadOnlySpan blockInput = new(inputBuf, idx, blockSize); - aes.EncryptEcb(blockInput, blockOutput, PaddingMode.None); + if (FastAes != null) + FastAes.EncryptEcb(blockInput, blockOutput); + else + Aes!.EncryptEcb(blockInput, blockOutput, PaddingMode.None); buffer[outOffset + idx] = (byte)(blockOutput[0] ^ inputBuf[idx + blockSize]); } }); @@ -136,7 +146,10 @@ namespace MinecraftClient.Crypto.Streams for (int wirtten = 0; wirtten < required; ++wirtten) { ReadOnlySpan blockInput = new(outputBuf, wirtten, blockSize); - aes.EncryptEcb(blockInput, blockOutput, PaddingMode.None); + if (FastAes != null) + FastAes.EncryptEcb(blockInput, blockOutput); + else + Aes!.EncryptEcb(blockInput, blockOutput, PaddingMode.None); outputBuf[blockSize + wirtten] = (byte)(blockOutput[0] ^ input[offset + wirtten]); } BaseStream.WriteAsync(outputBuf, blockSize, required);