Add optimized functions for decrypting a single AES block

This commit is contained in:
Alex Barney 2019-11-24 19:54:29 -06:00
parent abce62dd4f
commit 99522b748e
3 changed files with 156 additions and 16 deletions

View file

@ -1,4 +1,4 @@
#if NETCOREAPP
#if HAS_INTRINSICS
using System;
using System.Diagnostics;
using System.Runtime.CompilerServices;
@ -18,6 +18,10 @@ namespace LibHac.Crypto.Detail
private Vector128<byte> _roundKeys;
// An Initialize method is used instead of a constructor because it prevents the runtime
// from zeroing out the structure's memory when creating it.
// When processing a single block, doing this can increase performance by 20-40%
// depending on the context.
public void Initialize(ReadOnlySpan<byte> key, bool isDecrypting)
{
Debug.Assert(key.Length == Aes.KeySize128);
@ -183,7 +187,8 @@ namespace LibHac.Crypto.Detail
// When inlining this function, RyuJIT will almost make the
// generated code the same as if it were manually inlined
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public readonly void EncryptBlocks8(Vector128<byte> in0,
public readonly void EncryptBlocks8(
Vector128<byte> in0,
Vector128<byte> in1,
Vector128<byte> in2,
Vector128<byte> in3,
@ -198,8 +203,7 @@ namespace LibHac.Crypto.Detail
out Vector128<byte> out4,
out Vector128<byte> out5,
out Vector128<byte> out6,
out Vector128<byte> out7
)
out Vector128<byte> out7)
{
ReadOnlySpan<Vector128<byte>> keys = RoundKeys;
@ -331,8 +335,7 @@ namespace LibHac.Crypto.Detail
out Vector128<byte> out4,
out Vector128<byte> out5,
out Vector128<byte> out6,
out Vector128<byte> out7
)
out Vector128<byte> out7)
{
ReadOnlySpan<Vector128<byte>> keys = RoundKeys;
@ -447,6 +450,71 @@ namespace LibHac.Crypto.Detail
out7 = AesNi.DecryptLast(b7, key);
}
public static Vector128<byte> EncryptBlock(Vector128<byte> input, Vector128<byte> key)
{
Vector128<byte> curKey = key;
Vector128<byte> b = Sse2.Xor(input, curKey);
curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x01));
b = AesNi.Encrypt(b, curKey);
curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x02));
b = AesNi.Encrypt(b, curKey);
curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x04));
b = AesNi.Encrypt(b, curKey);
curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x08));
b = AesNi.Encrypt(b, curKey);
curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x10));
b = AesNi.Encrypt(b, curKey);
curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x20));
b = AesNi.Encrypt(b, curKey);
curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x40));
b = AesNi.Encrypt(b, curKey);
curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x80));
b = AesNi.Encrypt(b, curKey);
curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x1b));
b = AesNi.Encrypt(b, curKey);
curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x36));
return AesNi.EncryptLast(b, curKey);
}
public static Vector128<byte> DecryptBlock(Vector128<byte> input, Vector128<byte> key)
{
Vector128<byte> key0 = key;
Vector128<byte> key1 = KeyExpansion(key0, AesNi.KeygenAssist(key0, 0x01));
Vector128<byte> key2 = KeyExpansion(key1, AesNi.KeygenAssist(key1, 0x02));
Vector128<byte> key3 = KeyExpansion(key2, AesNi.KeygenAssist(key2, 0x04));
Vector128<byte> key4 = KeyExpansion(key3, AesNi.KeygenAssist(key3, 0x08));
Vector128<byte> key5 = KeyExpansion(key4, AesNi.KeygenAssist(key4, 0x10));
Vector128<byte> key6 = KeyExpansion(key5, AesNi.KeygenAssist(key5, 0x20));
Vector128<byte> key7 = KeyExpansion(key6, AesNi.KeygenAssist(key6, 0x40));
Vector128<byte> key8 = KeyExpansion(key7, AesNi.KeygenAssist(key7, 0x80));
Vector128<byte> key9 = KeyExpansion(key8, AesNi.KeygenAssist(key8, 0x1b));
Vector128<byte> key10 = KeyExpansion(key9, AesNi.KeygenAssist(key9, 0x36));
Vector128<byte> b = input;
b = Sse2.Xor(b, key10);
b = AesNi.Decrypt(b, AesNi.InverseMixColumns(key9));
b = AesNi.Decrypt(b, AesNi.InverseMixColumns(key8));
b = AesNi.Decrypt(b, AesNi.InverseMixColumns(key7));
b = AesNi.Decrypt(b, AesNi.InverseMixColumns(key6));
b = AesNi.Decrypt(b, AesNi.InverseMixColumns(key5));
b = AesNi.Decrypt(b, AesNi.InverseMixColumns(key4));
b = AesNi.Decrypt(b, AesNi.InverseMixColumns(key3));
b = AesNi.Decrypt(b, AesNi.InverseMixColumns(key2));
b = AesNi.Decrypt(b, AesNi.InverseMixColumns(key1));
return AesNi.DecryptLast(b, key0);
}
private void KeyExpansion(ReadOnlySpan<byte> key, bool isDecrypting)
{
Span<Vector128<byte>> roundKeys = MemoryMarshal.CreateSpan(ref _roundKeys, RoundKeyCount);
@ -486,10 +554,15 @@ namespace LibHac.Crypto.Detail
if (isDecrypting)
{
for (int i = 1; i < 10; i++)
{
roundKeys[i] = AesNi.InverseMixColumns(roundKeys[i]);
}
roundKeys[1] = AesNi.InverseMixColumns(roundKeys[1]);
roundKeys[2] = AesNi.InverseMixColumns(roundKeys[2]);
roundKeys[3] = AesNi.InverseMixColumns(roundKeys[3]);
roundKeys[4] = AesNi.InverseMixColumns(roundKeys[4]);
roundKeys[5] = AesNi.InverseMixColumns(roundKeys[5]);
roundKeys[6] = AesNi.InverseMixColumns(roundKeys[6]);
roundKeys[7] = AesNi.InverseMixColumns(roundKeys[7]);
roundKeys[8] = AesNi.InverseMixColumns(roundKeys[8]);
roundKeys[9] = AesNi.InverseMixColumns(roundKeys[9]);
}
}

View file

@ -6,18 +6,19 @@ namespace hactoolnet
{
internal class MultiBenchmark
{
public int RunsNeeded { get; set; } = 500;
public int DefaultRunsNeeded { get; set; } = 500;
private List<BenchmarkItem> Benchmarks { get; } = new List<BenchmarkItem>();
public void Register(string name, Action setupAction, Action runAction, Func<double, string> resultPrinter)
public void Register(string name, Action setupAction, Action runAction, Func<double, string> resultPrinter, int runsNeeded = -1)
{
var benchmark = new BenchmarkItem
{
Name = name,
Setup = setupAction,
Run = runAction,
PrintResult = resultPrinter
PrintResult = resultPrinter,
RunsNeeded = runsNeeded == -1 ? DefaultRunsNeeded : runsNeeded
};
Benchmarks.Add(benchmark);
@ -40,7 +41,7 @@ namespace hactoolnet
int runsSinceLastBest = 0;
while (runsSinceLastBest < RunsNeeded)
while (runsSinceLastBest < item.RunsNeeded)
{
runsSinceLastBest++;
item.Setup();
@ -64,6 +65,7 @@ namespace hactoolnet
private class BenchmarkItem
{
public string Name { get; set; }
public int RunsNeeded { get; set; }
public double Time { get; set; }
public string Result { get; set; }

View file

@ -6,6 +6,12 @@ using LibHac.Crypto;
using LibHac.Fs;
using LibHac.FsSystem;
#if NETCOREAPP
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
#endif
namespace hactoolnet
{
internal static class ProcessBench
@ -16,6 +22,8 @@ namespace hactoolnet
private const int BlockSizeSeparate = 0x10;
private const int BatchCipherBenchSize = 1024 * 1024;
// ReSharper disable once UnusedMember.Local
private const int SingleBlockCipherBenchSize = 1024 * 128;
private static void CopyBenchmark(IStorage src, IStorage dst, int iterations, string label, IProgressReport logger)
{
@ -173,7 +181,7 @@ namespace hactoolnet
logger.LogMessage($"{label}{averageRate}/s, fastest run: {fastestRate}/s, slowest run: {slowestRate}/s");
}
private static void RegisterAllCipherBenchmarks(MultiBenchmark bench)
private static void RegisterAesSequentialBenchmarks(MultiBenchmark bench)
{
var input = new byte[BatchCipherBenchSize];
var output = new byte[BatchCipherBenchSize];
@ -220,6 +228,62 @@ namespace hactoolnet
}
}
// ReSharper disable once UnusedParameter.Local
private static void RegisterAesSingleBlockBenchmarks(MultiBenchmark bench)
{
#if NETCOREAPP
var input = new byte[SingleBlockCipherBenchSize];
var output = new byte[SingleBlockCipherBenchSize];
Func<double, string> resultPrinter = time => Util.GetBytesReadable((long)(SingleBlockCipherBenchSize / time)) + "/s";
bench.Register("AES single-block encrypt", () => { }, EncryptBlocks, resultPrinter);
bench.Register("AES single-block decrypt", () => { }, DecryptBlocks, resultPrinter);
bench.DefaultRunsNeeded = 1000;
void EncryptBlocks()
{
ref byte inBlock = ref MemoryMarshal.GetReference(input.AsSpan());
ref byte outBlock = ref MemoryMarshal.GetReference(output.AsSpan());
Vector128<byte> keyVec = Vector128<byte>.Zero;
ref byte end = ref Unsafe.Add(ref inBlock, input.Length);
while (Unsafe.IsAddressLessThan(ref inBlock, ref end))
{
var inputVec = Unsafe.ReadUnaligned<Vector128<byte>>(ref inBlock);
Vector128<byte> outputVec = LibHac.Crypto.Detail.AesCoreNi.EncryptBlock(inputVec, keyVec);
Unsafe.WriteUnaligned(ref outBlock, outputVec);
inBlock = ref Unsafe.Add(ref inBlock, Aes.BlockSize);
outBlock = ref Unsafe.Add(ref outBlock, Aes.BlockSize);
}
}
void DecryptBlocks()
{
ref byte inBlock = ref MemoryMarshal.GetReference(input.AsSpan());
ref byte outBlock = ref MemoryMarshal.GetReference(output.AsSpan());
Vector128<byte> keyVec = Vector128<byte>.Zero;
ref byte end = ref Unsafe.Add(ref inBlock, input.Length);
while (Unsafe.IsAddressLessThan(ref inBlock, ref end))
{
var inputVec = Unsafe.ReadUnaligned<Vector128<byte>>(ref inBlock);
Vector128<byte> outputVec = LibHac.Crypto.Detail.AesCoreNi.DecryptBlock(inputVec, keyVec);
Unsafe.WriteUnaligned(ref outBlock, outputVec);
inBlock = ref Unsafe.Add(ref inBlock, Aes.BlockSize);
outBlock = ref Unsafe.Add(ref outBlock, Aes.BlockSize);
}
}
#endif
}
private static void RunCipherBenchmark(Func<ICipher> cipherNet, Func<ICipher> cipherLibHac,
CipherTaskSeparate function, bool benchBlocked, string label, IProgressReport logger)
{
@ -399,7 +463,8 @@ namespace hactoolnet
{
var bench = new MultiBenchmark();
RegisterAllCipherBenchmarks(bench);
RegisterAesSequentialBenchmarks(bench);
RegisterAesSingleBlockBenchmarks(bench);
bench.Run();
break;