Use AES-NI instruction set if possible

This commit is contained in:
BruceChen 2022-08-28 10:50:52 +08:00
parent 13d1a9856a
commit 842c968220
2 changed files with 115 additions and 5 deletions

View file

@ -0,0 +1,97 @@
using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
namespace MinecraftClient.Crypto
{
// Using the AES-NI instruction set
// https://gist.github.com/Thealexbarney/9f75883786a9f3100408ff795fb95d85
public class AesContext
{
private Vector128<byte>[] RoundKeys { get; }
public byte[] Iv { get; } = new byte[0x10];
public AesContext(Span<byte> key)
{
RoundKeys = KeyExpansion(key);
}
[MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)]
public void EncryptEcb(ReadOnlySpan<byte> plaintext, Span<byte> destination)
{
Vector128<byte>[] keys = RoundKeys;
ReadOnlySpan<Vector128<byte>> blocks = MemoryMarshal.Cast<byte, Vector128<byte>>(plaintext);
Span<Vector128<byte>> dest = MemoryMarshal.Cast<byte, Vector128<byte>>(destination);
// Makes the JIT remove all the other range checks on keys
_ = keys[10];
for (int i = 0; i < blocks.Length; i++)
{
Vector128<byte> b = blocks[i];
b = Sse2.Xor(b, keys[0]);
b = Aes.Encrypt(b, keys[1]);
b = Aes.Encrypt(b, keys[2]);
b = Aes.Encrypt(b, keys[3]);
b = Aes.Encrypt(b, keys[4]);
b = Aes.Encrypt(b, keys[5]);
b = Aes.Encrypt(b, keys[6]);
b = Aes.Encrypt(b, keys[7]);
b = Aes.Encrypt(b, keys[8]);
b = Aes.Encrypt(b, keys[9]);
b = Aes.EncryptLast(b, keys[10]);
dest[i] = b;
}
}
private static Vector128<byte>[] KeyExpansion(Span<byte> key)
{
var keys = new Vector128<byte>[20];
keys[0] = Unsafe.ReadUnaligned<Vector128<byte>>(ref key[0]);
MakeRoundKey(keys, 1, 0x01);
MakeRoundKey(keys, 2, 0x02);
MakeRoundKey(keys, 3, 0x04);
MakeRoundKey(keys, 4, 0x08);
MakeRoundKey(keys, 5, 0x10);
MakeRoundKey(keys, 6, 0x20);
MakeRoundKey(keys, 7, 0x40);
MakeRoundKey(keys, 8, 0x80);
MakeRoundKey(keys, 9, 0x1b);
MakeRoundKey(keys, 10, 0x36);
for (int i = 1; i < 10; i++)
{
keys[10 + i] = Aes.InverseMixColumns(keys[i]);
}
return keys;
}
private static void MakeRoundKey(Vector128<byte>[] keys, int i, byte rcon)
{
Vector128<byte> s = keys[i - 1];
Vector128<byte> t = keys[i - 1];
t = Aes.KeygenAssist(t, rcon);
t = Sse2.Shuffle(t.AsUInt32(), 0xFF).AsByte();
s = Sse2.Xor(s, Sse2.ShiftLeftLogical128BitLane(s, 4));
s = Sse2.Xor(s, Sse2.ShiftLeftLogical128BitLane(s, 8));
keys[i] = Sse2.Xor(s, t);
}
public void SetIv(Span<byte> iv)
{
iv.Slice(0, 0x10).CopyTo(Iv);
}
}
}

View file

@ -6,6 +6,7 @@ using System.Threading.Tasks;
using System.Security.Cryptography;
using System.IO;
using System.Collections.Concurrent;
using System.Runtime.CompilerServices;
namespace MinecraftClient.Crypto.Streams
{
@ -13,7 +14,9 @@ namespace MinecraftClient.Crypto.Streams
{
public static readonly int blockSize = 16;
private Aes aes;
private readonly Aes? Aes = null;
private readonly AesContext? FastAes = null;
public System.IO.Stream BaseStream { get; set; }
@ -26,7 +29,10 @@ namespace MinecraftClient.Crypto.Streams
{
BaseStream = stream;
aes = GenerateAES(key);
if (System.Runtime.Intrinsics.X86.Sse2.IsSupported && System.Runtime.Intrinsics.X86.Aes.IsSupported)
FastAes = new AesContext(key);
else
Aes = GenerateAES(key);
Array.Copy(key, ReadStreamIV, 16);
Array.Copy(key, WriteStreamIV, 16);
@ -76,6 +82,7 @@ namespace MinecraftClient.Crypto.Streams
return temp[0];
}
[MethodImpl(MethodImplOptions.AggressiveOptimization)]
public override int Read(byte[] buffer, int outOffset, int required)
{
if (this.inStreamEnded)
@ -93,7 +100,7 @@ namespace MinecraftClient.Crypto.Streams
return readed;
}
OrderablePartitioner<Tuple<int, int>> rangePartitioner = (curRead <= 256) ?
OrderablePartitioner<Tuple<int, int>> rangePartitioner = (curRead <= 256) ?
Partitioner.Create(readed, readed + curRead, 32) : Partitioner.Create(readed, readed + curRead);
Parallel.ForEach(rangePartitioner, (range, loopState) =>
{
@ -101,7 +108,10 @@ namespace MinecraftClient.Crypto.Streams
for (int idx = range.Item1; idx < range.Item2; idx++)
{
ReadOnlySpan<byte> blockInput = new(inputBuf, idx, blockSize);
aes.EncryptEcb(blockInput, blockOutput, PaddingMode.None);
if (FastAes != null)
FastAes.EncryptEcb(blockInput, blockOutput);
else
Aes!.EncryptEcb(blockInput, blockOutput, PaddingMode.None);
buffer[outOffset + idx] = (byte)(blockOutput[0] ^ inputBuf[idx + blockSize]);
}
});
@ -136,7 +146,10 @@ namespace MinecraftClient.Crypto.Streams
for (int wirtten = 0; wirtten < required; ++wirtten)
{
ReadOnlySpan<byte> blockInput = new(outputBuf, wirtten, blockSize);
aes.EncryptEcb(blockInput, blockOutput, PaddingMode.None);
if (FastAes != null)
FastAes.EncryptEcb(blockInput, blockOutput);
else
Aes!.EncryptEcb(blockInput, blockOutput, PaddingMode.None);
outputBuf[blockSize + wirtten] = (byte)(blockOutput[0] ^ input[offset + wirtten]);
}
BaseStream.WriteAsync(outputBuf, blockSize, required);