From df646fb50395528baf88b15816a4c84da384421d Mon Sep 17 00:00:00 2001 From: Alex Barney Date: Sat, 16 Nov 2019 14:32:01 -0700 Subject: [PATCH] Add functions to encrypt/decrypt entire buffers --- src/LibHac/Crypto2/Aes.cs | 64 +++++++++++++++++ src/LibHac/Crypto2/AesCoreNi.cs | 52 +++++++++----- src/hactoolnet/ProcessBench.cs | 119 ++++++++++++++++++++++++++++---- 3 files changed, 202 insertions(+), 33 deletions(-) diff --git a/src/LibHac/Crypto2/Aes.cs b/src/LibHac/Crypto2/Aes.cs index e433f901..47129638 100644 --- a/src/LibHac/Crypto2/Aes.cs +++ b/src/LibHac/Crypto2/Aes.cs @@ -104,5 +104,69 @@ namespace LibHac.Crypto2 #endif return new AesXtsCipher(key1, key2, iv, false); } + + public static void EncryptEcb128(ReadOnlySpan input, Span output, ReadOnlySpan key, + bool preferDotNetCrypto = false) + { + ICipher cipher = CreateEcbEncryptor(key, preferDotNetCrypto); + + cipher.Transform(input, output); + } + + public static void DecryptEcb128(ReadOnlySpan input, Span output, ReadOnlySpan key, + bool preferDotNetCrypto = false) + { + ICipher cipher = CreateEcbDecryptor(key, preferDotNetCrypto); + + cipher.Transform(input, output); + } + + public static void EncryptCbc128(ReadOnlySpan input, Span output, ReadOnlySpan key, + ReadOnlySpan iv, bool preferDotNetCrypto = false) + { + ICipher cipher = CreateCbcEncryptor(key, iv, preferDotNetCrypto); + + cipher.Transform(input, output); + } + + public static void DecryptCbc128(ReadOnlySpan input, Span output, ReadOnlySpan key, + ReadOnlySpan iv, bool preferDotNetCrypto = false) + { + ICipher cipher = CreateCbcDecryptor(key, iv, preferDotNetCrypto); + + cipher.Transform(input, output); + } + + public static void EncryptCtr128(ReadOnlySpan input, Span output, ReadOnlySpan key, + ReadOnlySpan iv, bool preferDotNetCrypto = false) + { + ICipher cipher = CreateCtrEncryptor(key, iv, preferDotNetCrypto); + + cipher.Transform(input, output); + } + + public static void DecryptCtr128(ReadOnlySpan input, Span output, ReadOnlySpan key, + ReadOnlySpan iv, bool preferDotNetCrypto = false) + { + ICipher cipher = CreateCtrDecryptor(key, iv, preferDotNetCrypto); + + cipher.Transform(input, output); + } + + public static void EncryptXts128(ReadOnlySpan input, Span output, ReadOnlySpan key1, + ReadOnlySpan key2, ReadOnlySpan iv, bool preferDotNetCrypto = false) + { + ICipher cipher = CreateXtsEncryptor(key1, key2, iv, preferDotNetCrypto); + + cipher.Transform(input, output); + } + + public static void DecryptXts128(ReadOnlySpan input, Span output, ReadOnlySpan key1, + ReadOnlySpan key2, ReadOnlySpan iv, bool preferDotNetCrypto = false) + { + ICipher cipher = CreateXtsDecryptor(key1, key2, iv, preferDotNetCrypto); + + cipher.Transform(input, output); + } } } diff --git a/src/LibHac/Crypto2/AesCoreNi.cs b/src/LibHac/Crypto2/AesCoreNi.cs index 57aabc7d..04b95f1b 100644 --- a/src/LibHac/Crypto2/AesCoreNi.cs +++ b/src/LibHac/Crypto2/AesCoreNi.cs @@ -104,18 +104,38 @@ namespace LibHac.Crypto2 [MethodImpl(MethodImplOptions.AggressiveOptimization)] private static void KeyExpansion(ReadOnlySpan key, Span> roundKeys, bool isDecrypting) { - roundKeys[0] = Unsafe.ReadUnaligned>(ref MemoryMarshal.GetReference(key)); + var curKey = Unsafe.ReadUnaligned>(ref MemoryMarshal.GetReference(key)); + roundKeys[0] = curKey; - MakeRoundKey(roundKeys, 1, 0x01); - MakeRoundKey(roundKeys, 2, 0x02); - MakeRoundKey(roundKeys, 3, 0x04); - MakeRoundKey(roundKeys, 4, 0x08); - MakeRoundKey(roundKeys, 5, 0x10); - MakeRoundKey(roundKeys, 6, 0x20); - MakeRoundKey(roundKeys, 7, 0x40); - MakeRoundKey(roundKeys, 8, 0x80); - MakeRoundKey(roundKeys, 9, 0x1b); - MakeRoundKey(roundKeys, 10, 0x36); + curKey = KeyExpansion(curKey, Aes.KeygenAssist(curKey, 0x01)); + roundKeys[1] = curKey; + + curKey = KeyExpansion(curKey, Aes.KeygenAssist(curKey, 0x02)); + roundKeys[2] = curKey; + + curKey = KeyExpansion(curKey, Aes.KeygenAssist(curKey, 0x04)); + roundKeys[3] = curKey; + + curKey = KeyExpansion(curKey, Aes.KeygenAssist(curKey, 0x08)); + roundKeys[4] = curKey; + + curKey = KeyExpansion(curKey, Aes.KeygenAssist(curKey, 0x10)); + roundKeys[5] = curKey; + + curKey = KeyExpansion(curKey, Aes.KeygenAssist(curKey, 0x20)); + roundKeys[6] = curKey; + + curKey = KeyExpansion(curKey, Aes.KeygenAssist(curKey, 0x40)); + roundKeys[7] = curKey; + + curKey = KeyExpansion(curKey, Aes.KeygenAssist(curKey, 0x80)); + roundKeys[8] = curKey; + + curKey = KeyExpansion(curKey, Aes.KeygenAssist(curKey, 0x1b)); + roundKeys[9] = curKey; + + curKey = KeyExpansion(curKey, Aes.KeygenAssist(curKey, 0x36)); + roundKeys[10] = curKey; if (isDecrypting) { @@ -126,19 +146,15 @@ namespace LibHac.Crypto2 } } - [MethodImpl(MethodImplOptions.AggressiveOptimization)] - private static void MakeRoundKey(Span> keys, int i, byte rcon) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector128 KeyExpansion(Vector128 s, Vector128 t) { - Vector128 s = keys[i - 1]; - Vector128 t = keys[i - 1]; - - t = Aes.KeygenAssist(t, rcon); t = Sse2.Shuffle(t.AsUInt32(), 0xFF).AsByte(); s = Sse2.Xor(s, Sse2.ShiftLeftLogical128BitLane(s, 4)); s = Sse2.Xor(s, Sse2.ShiftLeftLogical128BitLane(s, 8)); - keys[i] = Sse2.Xor(s, t); + return Sse2.Xor(s, t); } } } diff --git a/src/hactoolnet/ProcessBench.cs b/src/hactoolnet/ProcessBench.cs index 36fb78f8..db64776b 100644 --- a/src/hactoolnet/ProcessBench.cs +++ b/src/hactoolnet/ProcessBench.cs @@ -12,6 +12,8 @@ namespace hactoolnet { private const int Size = 1024 * 1024 * 10; private const int Iterations = 100; + private const int BlockSizeBlocked = 0x10; + private const int BlockSizeSeparate = 0x10; private static void CopyBenchmark(IStorage src, IStorage dst, int iterations, string label, IProgressReport logger) { @@ -82,7 +84,7 @@ namespace hactoolnet logger.SetTotal(iterations); - int blockCount = src.Length / 0x10; + int blockCount = src.Length / BlockSizeBlocked; for (int i = 0; i < iterations; i++) { @@ -92,7 +94,8 @@ namespace hactoolnet for (int b = 0; b < blockCount; b++) { - cipher.Transform(src.Slice(b * 0x10, 0x10), dst.Slice(b * 0x10, 0x10)); + cipher.Transform(src.Slice(b * BlockSizeBlocked, BlockSizeBlocked), + dst.Slice(b * BlockSizeBlocked, BlockSizeBlocked)); } watch.Stop(); @@ -116,8 +119,60 @@ namespace hactoolnet logger.LogMessage($"{label}{averageRate}/s, fastest run: {fastestRate}/s, slowest run: {slowestRate}/s"); } - private static void RunCipherBenchmark(Func cipherNet, Func cipherLibHac, bool benchBlocked, - string label, IProgressReport logger) + private delegate void CipherTaskSeparate(ReadOnlySpan input, Span output, ReadOnlySpan key1, + ReadOnlySpan key2, ReadOnlySpan iv, bool preferDotNetCrypto = false); + + // Benchmarks encrypting each block separately, initializing a new cipher object for each one + private static void CipherBenchmarkSeparate(ReadOnlySpan src, Span dst, CipherTaskSeparate function, + int iterations, string label, bool dotNetCrypto, IProgressReport logger) + { + Debug.Assert(src.Length == dst.Length); + + var watch = new Stopwatch(); + var runTimes = new double[iterations]; + + ReadOnlySpan key1 = stackalloc byte[0x10]; + ReadOnlySpan key2 = stackalloc byte[0x10]; + ReadOnlySpan iv = stackalloc byte[0x10]; + + logger.SetTotal(iterations); + + const int blockSize = BlockSizeSeparate; + int blockCount = src.Length / blockSize; + + for (int i = 0; i < iterations; i++) + { + watch.Restart(); + + for (int b = 0; b < blockCount; b++) + { + function(src.Slice(b * blockSize, blockSize), dst.Slice(b * blockSize, blockSize), + key1, key2, iv, dotNetCrypto); + } + + watch.Stop(); + + logger.ReportAdd(1); + runTimes[i] = watch.Elapsed.TotalSeconds; + } + + logger.SetTotal(0); + + long srcSize = src.Length; + + double fastestRun = runTimes.Min(); + double averageRun = runTimes.Average(); + double slowestRun = runTimes.Max(); + + string fastestRate = Util.GetBytesReadable((long)(srcSize / fastestRun)); + string averageRate = Util.GetBytesReadable((long)(srcSize / averageRun)); + string slowestRate = Util.GetBytesReadable((long)(srcSize / slowestRun)); + + logger.LogMessage($"{label}{averageRate}/s, fastest run: {fastestRate}/s, slowest run: {slowestRate}/s"); + } + + private static void RunCipherBenchmark(Func cipherNet, Func cipherLibHac, + CipherTaskSeparate function, bool benchBlocked, string label, IProgressReport logger) { var srcData = new byte[Size]; @@ -125,19 +180,34 @@ namespace hactoolnet var dstDataNet = new byte[Size]; var dstDataBlockedLh = new byte[Size]; var dstDataBlockedNet = new byte[Size]; + var dstDataSeparateLh = new byte[Size]; + var dstDataSeparateNet = new byte[Size]; logger.LogMessage(string.Empty); logger.LogMessage(label); - if (AesCrypto.IsAesNiSupported()) CipherBenchmark(srcData, dstDataLh, cipherLibHac, Iterations, "LibHac impl: ", logger); + if (AesCrypto.IsAesNiSupported()) + CipherBenchmark(srcData, dstDataLh, cipherLibHac, Iterations, "LibHac impl: ", logger); CipherBenchmark(srcData, dstDataNet, cipherNet, Iterations, ".NET impl: ", logger); if (benchBlocked) { if (AesCrypto.IsAesNiSupported()) - CipherBenchmarkBlocked(srcData, dstDataBlockedLh, cipherLibHac, Iterations / 5, "LibHac impl (blocked): ", logger); + CipherBenchmarkBlocked(srcData, dstDataBlockedLh, cipherLibHac, Iterations / 5, + "LibHac impl (blocked): ", logger); - CipherBenchmarkBlocked(srcData, dstDataBlockedNet, cipherNet, Iterations / 5, ".NET impl (blocked): ", logger); + CipherBenchmarkBlocked(srcData, dstDataBlockedNet, cipherNet, Iterations / 5, ".NET impl (blocked): ", + logger); + } + + if (function != null) + { + if (AesCrypto.IsAesNiSupported()) + CipherBenchmarkSeparate(srcData, dstDataSeparateLh, function, Iterations / 5, + "LibHac impl (separate): ", false, logger); + + CipherBenchmarkSeparate(srcData, dstDataSeparateNet, function, Iterations / 20, + ".NET impl (separate): ", true, logger); } if (AesCrypto.IsAesNiSupported()) @@ -149,6 +219,12 @@ namespace hactoolnet logger.LogMessage($"{dstDataLh.SequenceEqual(dstDataBlockedLh)}"); logger.LogMessage($"{dstDataLh.SequenceEqual(dstDataBlockedNet)}"); } + + if (function != null) + { + logger.LogMessage($"{dstDataLh.SequenceEqual(dstDataSeparateLh)}"); + logger.LogMessage($"{dstDataLh.SequenceEqual(dstDataSeparateNet)}"); + } } } @@ -158,7 +234,6 @@ namespace hactoolnet { case "aesctr": { - IStorage decStorage = new MemoryStorage(new byte[Size]); IStorage encStorage = new Aes128CtrStorage(new MemoryStorage(new byte[Size]), new byte[0x10], new byte[0x10], true); @@ -206,13 +281,17 @@ namespace hactoolnet { Func encryptorNet = () => AesCrypto.CreateEcbEncryptor(new byte[0x10], true); Func encryptorLh = () => AesCrypto.CreateEcbEncryptor(new byte[0x10]); + CipherTaskSeparate encrypt = (input, output, key1, key2, iv, crypto) => + AesCrypto.EncryptEcb128(input, output, key1, crypto); - RunCipherBenchmark(encryptorNet, encryptorLh, true, "AES-ECB encrypt", ctx.Logger); + RunCipherBenchmark(encryptorNet, encryptorLh, encrypt, true, "AES-ECB encrypt", ctx.Logger); Func decryptorNet = () => AesCrypto.CreateEcbDecryptor(new byte[0x10], true); Func decryptorLh = () => AesCrypto.CreateEcbDecryptor(new byte[0x10]); + CipherTaskSeparate decrypt = (input, output, key1, key2, iv, crypto) => + AesCrypto.DecryptEcb128(input, output, key1, crypto); - RunCipherBenchmark(decryptorNet, decryptorLh, true, "AES-ECB decrypt", ctx.Logger); + RunCipherBenchmark(decryptorNet, decryptorLh, decrypt, true, "AES-ECB decrypt", ctx.Logger); break; } @@ -220,13 +299,17 @@ namespace hactoolnet { Func encryptorNet = () => AesCrypto.CreateCbcEncryptor(new byte[0x10], new byte[0x10], true); Func encryptorLh = () => AesCrypto.CreateCbcEncryptor(new byte[0x10], new byte[0x10]); + CipherTaskSeparate encrypt = (input, output, key1, key2, iv, crypto) => + AesCrypto.EncryptCbc128(input, output, key1, iv, crypto); - RunCipherBenchmark(encryptorNet, encryptorLh, true, "AES-CBC encrypt", ctx.Logger); + RunCipherBenchmark(encryptorNet, encryptorLh, encrypt, true, "AES-CBC encrypt", ctx.Logger); Func decryptorNet = () => AesCrypto.CreateCbcDecryptor(new byte[0x10], new byte[0x10], true); Func decryptorLh = () => AesCrypto.CreateCbcDecryptor(new byte[0x10], new byte[0x10]); + CipherTaskSeparate decrypt = (input, output, key1, key2, iv, crypto) => + AesCrypto.DecryptCbc128(input, output, key1, iv, crypto); - RunCipherBenchmark(decryptorNet, decryptorLh, true, "AES-CBC decrypt", ctx.Logger); + RunCipherBenchmark(decryptorNet, decryptorLh, decrypt, true, "AES-CBC decrypt", ctx.Logger); break; } @@ -235,8 +318,10 @@ namespace hactoolnet { Func encryptorNet = () => AesCrypto.CreateCtrEncryptor(new byte[0x10], new byte[0x10], true); Func encryptorLh = () => AesCrypto.CreateCtrEncryptor(new byte[0x10], new byte[0x10]); + CipherTaskSeparate encrypt = (input, output, key1, key2, iv, crypto) => + AesCrypto.EncryptCtr128(input, output, key1, iv, crypto); - RunCipherBenchmark(encryptorNet, encryptorLh, true, "AES-CTR", ctx.Logger); + RunCipherBenchmark(encryptorNet, encryptorLh, encrypt, true, "AES-CTR", ctx.Logger); break; } @@ -244,13 +329,17 @@ namespace hactoolnet { Func encryptorNet = () => AesCrypto.CreateXtsEncryptor(new byte[0x10], new byte[0x10], new byte[0x10], true); Func encryptorLh = () => AesCrypto.CreateXtsEncryptor(new byte[0x10], new byte[0x10], new byte[0x10]); + CipherTaskSeparate encrypt = (input, output, key1, key2, iv, crypto) => + AesCrypto.EncryptXts128(input, output, key1, key2, iv, crypto); - RunCipherBenchmark(encryptorNet, encryptorLh, false, "AES-XTS encrypt", ctx.Logger); + RunCipherBenchmark(encryptorNet, encryptorLh, encrypt, false, "AES-XTS encrypt", ctx.Logger); Func decryptorNet = () => AesCrypto.CreateXtsDecryptor(new byte[0x10], new byte[0x10], new byte[0x10], true); Func decryptorLh = () => AesCrypto.CreateXtsDecryptor(new byte[0x10], new byte[0x10], new byte[0x10]); + CipherTaskSeparate decrypt = (input, output, key1, key2, iv, crypto) => + AesCrypto.DecryptXts128(input, output, key1, key2, iv, crypto); - RunCipherBenchmark(decryptorNet, decryptorLh, false, "AES-XTS decrypt", ctx.Logger); + RunCipherBenchmark(decryptorNet, decryptorLh, decrypt, false, "AES-XTS decrypt", ctx.Logger); break; }