From 99522b748ebf32d33863712781dcda6c9e3aa3f3 Mon Sep 17 00:00:00 2001
From: Alex Barney
Date: Sun, 24 Nov 2019 19:54:29 -0600
Subject: [PATCH] Add optimized functions for decrypting a single AES block

---
 src/LibHac/Crypto/Detail/AesCoreNi.cs | 93 ++++++++++++++++++++++++---
 src/hactoolnet/MultiBenchmark.cs      | 10 +--
 src/hactoolnet/ProcessBench.cs        | 69 +++++++++++++++++++-
 3 files changed, 156 insertions(+), 16 deletions(-)

diff --git a/src/LibHac/Crypto/Detail/AesCoreNi.cs b/src/LibHac/Crypto/Detail/AesCoreNi.cs
index f1b65481..77012b1e 100644
--- a/src/LibHac/Crypto/Detail/AesCoreNi.cs
+++ b/src/LibHac/Crypto/Detail/AesCoreNi.cs
@@ -1,4 +1,4 @@
-#if NETCOREAPP
+#if HAS_INTRINSICS
 using System;
 using System.Diagnostics;
 using System.Runtime.CompilerServices;
@@ -18,6 +18,10 @@ namespace LibHac.Crypto.Detail
 
         private Vector128<byte> _roundKeys;
 
+        // An Initialize method is used instead of a constructor because it prevents the runtime
+        // from zeroing out the structure's memory when creating it.
+        // When processing a single block, doing this can increase performance by 20-40%
+        // depending on the context.
         public void Initialize(ReadOnlySpan<byte> key, bool isDecrypting)
         {
             Debug.Assert(key.Length == Aes.KeySize128);
@@ -183,7 +187,8 @@ namespace LibHac.Crypto.Detail
         // When inlining this function, RyuJIT will almost make the
         // generated code the same as if it were manually inlined
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        public readonly void EncryptBlocks8(Vector128<byte> in0,
+        public readonly void EncryptBlocks8(
+            Vector128<byte> in0,
             Vector128<byte> in1,
             Vector128<byte> in2,
             Vector128<byte> in3,
@@ -198,8 +203,7 @@ namespace LibHac.Crypto.Detail
             out Vector128<byte> out4,
             out Vector128<byte> out5,
             out Vector128<byte> out6,
-            out Vector128<byte> out7
-        )
+            out Vector128<byte> out7)
         {
             ReadOnlySpan<Vector128<byte>> keys = RoundKeys;
 
@@ -331,8 +335,7 @@ namespace LibHac.Crypto.Detail
             out Vector128<byte> out4,
             out Vector128<byte> out5,
             out Vector128<byte> out6,
-            out Vector128<byte> out7
-        )
+            out Vector128<byte> out7)
         {
             ReadOnlySpan<Vector128<byte>> keys = RoundKeys;
 
@@ -447,6 +450,71 @@ namespace LibHac.Crypto.Detail
             out7 = AesNi.DecryptLast(b7, key);
         }
 
+        public static Vector128<byte> EncryptBlock(Vector128<byte> input, Vector128<byte> key)
+        {
+            Vector128<byte> curKey = key;
+            Vector128<byte> b = Sse2.Xor(input, curKey);
+
+            curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x01));
+            b = AesNi.Encrypt(b, curKey);
+
+            curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x02));
+            b = AesNi.Encrypt(b, curKey);
+
+            curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x04));
+            b = AesNi.Encrypt(b, curKey);
+
+            curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x08));
+            b = AesNi.Encrypt(b, curKey);
+
+            curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x10));
+            b = AesNi.Encrypt(b, curKey);
+
+            curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x20));
+            b = AesNi.Encrypt(b, curKey);
+
+            curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x40));
+            b = AesNi.Encrypt(b, curKey);
+
+            curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x80));
+            b = AesNi.Encrypt(b, curKey);
+
+            curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x1b));
+            b = AesNi.Encrypt(b, curKey);
+
+            curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x36));
+            return AesNi.EncryptLast(b, curKey);
+        }
+
+        public static Vector128<byte> DecryptBlock(Vector128<byte> input, Vector128<byte> key)
+        {
+            Vector128<byte> key0 = key;
+            Vector128<byte> key1 = KeyExpansion(key0, AesNi.KeygenAssist(key0, 0x01));
+            Vector128<byte> key2 = KeyExpansion(key1, AesNi.KeygenAssist(key1, 0x02));
+            Vector128<byte> key3 = KeyExpansion(key2, AesNi.KeygenAssist(key2, 0x04));
+            Vector128<byte> key4 = KeyExpansion(key3, AesNi.KeygenAssist(key3, 0x08));
+            Vector128<byte> key5 = KeyExpansion(key4, AesNi.KeygenAssist(key4, 0x10));
+            Vector128<byte> key6 = KeyExpansion(key5, AesNi.KeygenAssist(key5, 0x20));
+            Vector128<byte> key7 = KeyExpansion(key6, AesNi.KeygenAssist(key6, 0x40));
+            Vector128<byte> key8 = KeyExpansion(key7, AesNi.KeygenAssist(key7, 0x80));
+            Vector128<byte> key9 = KeyExpansion(key8, AesNi.KeygenAssist(key8, 0x1b));
+            Vector128<byte> key10 = KeyExpansion(key9, AesNi.KeygenAssist(key9, 0x36));
+
+            Vector128<byte> b = input;
+
+            b = Sse2.Xor(b, key10);
+            b = AesNi.Decrypt(b, AesNi.InverseMixColumns(key9));
+            b = AesNi.Decrypt(b, AesNi.InverseMixColumns(key8));
+            b = AesNi.Decrypt(b, AesNi.InverseMixColumns(key7));
+            b = AesNi.Decrypt(b, AesNi.InverseMixColumns(key6));
+            b = AesNi.Decrypt(b, AesNi.InverseMixColumns(key5));
+            b = AesNi.Decrypt(b, AesNi.InverseMixColumns(key4));
+            b = AesNi.Decrypt(b, AesNi.InverseMixColumns(key3));
+            b = AesNi.Decrypt(b, AesNi.InverseMixColumns(key2));
+            b = AesNi.Decrypt(b, AesNi.InverseMixColumns(key1));
+            return AesNi.DecryptLast(b, key0);
+        }
+
         private void KeyExpansion(ReadOnlySpan<byte> key, bool isDecrypting)
         {
             Span<Vector128<byte>> roundKeys = MemoryMarshal.CreateSpan(ref _roundKeys, RoundKeyCount);
@@ -486,10 +554,15 @@ namespace LibHac.Crypto.Detail
 
             if (isDecrypting)
             {
-                for (int i = 1; i < 10; i++)
-                {
-                    roundKeys[i] = AesNi.InverseMixColumns(roundKeys[i]);
-                }
+                roundKeys[1] = AesNi.InverseMixColumns(roundKeys[1]);
+                roundKeys[2] = AesNi.InverseMixColumns(roundKeys[2]);
+                roundKeys[3] = AesNi.InverseMixColumns(roundKeys[3]);
+                roundKeys[4] = AesNi.InverseMixColumns(roundKeys[4]);
+                roundKeys[5] = AesNi.InverseMixColumns(roundKeys[5]);
+                roundKeys[6] = AesNi.InverseMixColumns(roundKeys[6]);
+                roundKeys[7] = AesNi.InverseMixColumns(roundKeys[7]);
+                roundKeys[8] = AesNi.InverseMixColumns(roundKeys[8]);
+                roundKeys[9] = AesNi.InverseMixColumns(roundKeys[9]);
             }
         }
 
diff --git a/src/hactoolnet/MultiBenchmark.cs b/src/hactoolnet/MultiBenchmark.cs
index afa8aa51..ede8bfe6 100644
--- a/src/hactoolnet/MultiBenchmark.cs
+++ b/src/hactoolnet/MultiBenchmark.cs
@@ -6,18 +6,19 @@ namespace hactoolnet
 {
     internal class MultiBenchmark
     {
-        public int RunsNeeded { get; set; } = 500;
+        public int DefaultRunsNeeded { get; set; } = 500;
 
         private List<BenchmarkItem> Benchmarks { get; } = new List<BenchmarkItem>();
 
-        public void Register(string name, Action setupAction, Action runAction, Func<double, string> resultPrinter)
+        public void Register(string name, Action setupAction, Action runAction, Func<double, string> resultPrinter, int runsNeeded = -1)
         {
             var benchmark = new BenchmarkItem
             {
                 Name = name,
                 Setup = setupAction,
                 Run = runAction,
-                PrintResult = resultPrinter
+                PrintResult = resultPrinter,
+                RunsNeeded = runsNeeded == -1 ? DefaultRunsNeeded : runsNeeded
             };
 
             Benchmarks.Add(benchmark);
@@ -40,7 +41,7 @@ namespace hactoolnet
 
                 int runsSinceLastBest = 0;
 
-                while (runsSinceLastBest < RunsNeeded)
+                while (runsSinceLastBest < item.RunsNeeded)
                 {
                     runsSinceLastBest++;
                     item.Setup();
@@ -64,6 +65,7 @@ namespace hactoolnet
         private class BenchmarkItem
         {
             public string Name { get; set; }
+            public int RunsNeeded { get; set; }
             public double Time { get; set; }
             public string Result { get; set; }
 
diff --git a/src/hactoolnet/ProcessBench.cs b/src/hactoolnet/ProcessBench.cs
index c5cab16d..b62f6dac 100644
--- a/src/hactoolnet/ProcessBench.cs
+++ b/src/hactoolnet/ProcessBench.cs
@@ -6,6 +6,12 @@ using LibHac.Crypto;
 using LibHac.Fs;
 using LibHac.FsSystem;
 
+#if NETCOREAPP
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
+using System.Runtime.Intrinsics;
+#endif
+
 namespace hactoolnet
 {
     internal static class ProcessBench
@@ -16,6 +22,8 @@ namespace hactoolnet
         private const int BlockSizeSeparate = 0x10;
 
         private const int BatchCipherBenchSize = 1024 * 1024;
+        // ReSharper disable once UnusedMember.Local
+        private const int SingleBlockCipherBenchSize = 1024 * 128;
 
         private static void CopyBenchmark(IStorage src, IStorage dst, int iterations, string label, IProgressReport logger)
         {
@@ -173,7 +181,7 @@ namespace hactoolnet
             logger.LogMessage($"{label}{averageRate}/s, fastest run: {fastestRate}/s, slowest run: {slowestRate}/s");
         }
 
-        private static void RegisterAllCipherBenchmarks(MultiBenchmark bench)
+        private static void RegisterAesSequentialBenchmarks(MultiBenchmark bench)
         {
             var input = new byte[BatchCipherBenchSize];
             var output = new byte[BatchCipherBenchSize];
@@ -220,6 +228,62 @@ namespace hactoolnet
             }
         }
 
+        // ReSharper disable once UnusedParameter.Local
+        private static void RegisterAesSingleBlockBenchmarks(MultiBenchmark bench)
+        {
+#if NETCOREAPP
+            var input = new byte[SingleBlockCipherBenchSize];
+            var output = new byte[SingleBlockCipherBenchSize];
+
+            Func<double, string> resultPrinter = time => Util.GetBytesReadable((long)(SingleBlockCipherBenchSize / time)) + "/s";
+
+            // Register() captures the run count when it is called, so it must be passed here.
+            const int runsNeeded = 1000;
+            bench.Register("AES single-block encrypt", () => { }, EncryptBlocks, resultPrinter, runsNeeded);
+            bench.Register("AES single-block decrypt", () => { }, DecryptBlocks, resultPrinter, runsNeeded);
+
+            void EncryptBlocks()
+            {
+                ref byte inBlock = ref MemoryMarshal.GetReference(input.AsSpan());
+                ref byte outBlock = ref MemoryMarshal.GetReference(output.AsSpan());
+
+                Vector128<byte> keyVec = Vector128<byte>.Zero;
+
+                ref byte end = ref Unsafe.Add(ref inBlock, input.Length);
+
+                while (Unsafe.IsAddressLessThan(ref inBlock, ref end))
+                {
+                    var inputVec = Unsafe.ReadUnaligned<Vector128<byte>>(ref inBlock);
+                    Vector128<byte> outputVec = LibHac.Crypto.Detail.AesCoreNi.EncryptBlock(inputVec, keyVec);
+                    Unsafe.WriteUnaligned(ref outBlock, outputVec);
+
+                    inBlock = ref Unsafe.Add(ref inBlock, Aes.BlockSize);
+                    outBlock = ref Unsafe.Add(ref outBlock, Aes.BlockSize);
+                }
+            }
+
+            void DecryptBlocks()
+            {
+                ref byte inBlock = ref MemoryMarshal.GetReference(input.AsSpan());
+                ref byte outBlock = ref MemoryMarshal.GetReference(output.AsSpan());
+
+                Vector128<byte> keyVec = Vector128<byte>.Zero;
+
+                ref byte end = ref Unsafe.Add(ref inBlock, input.Length);
+
+                while (Unsafe.IsAddressLessThan(ref inBlock, ref end))
+                {
+                    var inputVec = Unsafe.ReadUnaligned<Vector128<byte>>(ref inBlock);
+                    Vector128<byte> outputVec = LibHac.Crypto.Detail.AesCoreNi.DecryptBlock(inputVec, keyVec);
+                    Unsafe.WriteUnaligned(ref outBlock, outputVec);
+
+                    inBlock = ref Unsafe.Add(ref inBlock, Aes.BlockSize);
+                    outBlock = ref Unsafe.Add(ref outBlock, Aes.BlockSize);
+                }
+            }
+#endif
+        }
+
         private static void RunCipherBenchmark(Func<ICipher> cipherNet, Func<ICipher> cipherLibHac,
             CipherTaskSeparate function, bool benchBlocked, string label, IProgressReport logger)
         {
@@ -399,7 +463,8 @@ namespace hactoolnet
                 {
                     var bench = new MultiBenchmark();
 
-                    RegisterAllCipherBenchmarks(bench);
+                    RegisterAesSequentialBenchmarks(bench);
+                    RegisterAesSingleBlockBenchmarks(bench);
 
                     bench.Run();
                     break;