From db6269df5c56e8cd5b14aeccf64c86ac1c50c766 Mon Sep 17 00:00:00 2001 From: Alex Barney Date: Wed, 20 Nov 2019 21:33:31 -0500 Subject: [PATCH] Interleave AES instructions to improve performance --- src/LibHac/Crypto/Detail/AesCbcModeNi.cs | 42 ++- src/LibHac/Crypto/Detail/AesCoreNi.cs | 354 ++++++++++++++++++++++- src/LibHac/Crypto/Detail/AesCtrModeNi.cs | 54 +++- src/LibHac/Crypto/Detail/AesEcbModeNi.cs | 4 +- src/LibHac/Crypto/Detail/AesXtsModeNi.cs | 104 ++++++- 5 files changed, 541 insertions(+), 17 deletions(-) diff --git a/src/LibHac/Crypto/Detail/AesCbcModeNi.cs b/src/LibHac/Crypto/Detail/AesCbcModeNi.cs index 8160991d..e57499b2 100644 --- a/src/LibHac/Crypto/Detail/AesCbcModeNi.cs +++ b/src/LibHac/Crypto/Detail/AesCbcModeNi.cs @@ -49,14 +49,51 @@ namespace LibHac.Crypto.Detail public void Decrypt(ReadOnlySpan input, Span output) { - int blockCount = Math.Min(input.Length, output.Length) >> 4; + int remainingBlocks = Math.Min(input.Length, output.Length) >> 4; ref Vector128 inBlock = ref Unsafe.As>(ref MemoryMarshal.GetReference(input)); ref Vector128 outBlock = ref Unsafe.As>(ref MemoryMarshal.GetReference(output)); Vector128 iv = _iv; - for (int i = 0; i < blockCount; i++) + while (remainingBlocks > 7) + { + Vector128 in0 = Unsafe.Add(ref inBlock, 0); + Vector128 in1 = Unsafe.Add(ref inBlock, 1); + Vector128 in2 = Unsafe.Add(ref inBlock, 2); + Vector128 in3 = Unsafe.Add(ref inBlock, 3); + Vector128 in4 = Unsafe.Add(ref inBlock, 4); + Vector128 in5 = Unsafe.Add(ref inBlock, 5); + Vector128 in6 = Unsafe.Add(ref inBlock, 6); + Vector128 in7 = Unsafe.Add(ref inBlock, 7); + + _aesCore.DecryptBlocks8(in0, in1, in2, in3, in4, in5, in6, in7, + out Vector128 b0, + out Vector128 b1, + out Vector128 b2, + out Vector128 b3, + out Vector128 b4, + out Vector128 b5, + out Vector128 b6, + out Vector128 b7); + + Unsafe.Add(ref outBlock, 0) = Sse2.Xor(iv, b0); + Unsafe.Add(ref outBlock, 1) = Sse2.Xor(in0, b1); + Unsafe.Add(ref outBlock, 2) = Sse2.Xor(in1, b2); + Unsafe.Add(ref outBlock, 3) = Sse2.Xor(in2, b3); + Unsafe.Add(ref outBlock, 4) = Sse2.Xor(in3, b4); + Unsafe.Add(ref outBlock, 5) = Sse2.Xor(in4, b5); + Unsafe.Add(ref outBlock, 6) = Sse2.Xor(in5, b6); + Unsafe.Add(ref outBlock, 7) = Sse2.Xor(in6, b7); + + iv = in7; + + inBlock = ref Unsafe.Add(ref inBlock, 8); + outBlock = ref Unsafe.Add(ref outBlock, 8); + remainingBlocks -= 8; + } + + while (remainingBlocks > 0) { Vector128 currentBlock = inBlock; Vector128 decBeforeIv = _aesCore.DecryptBlock(currentBlock); @@ -66,6 +103,7 @@ namespace LibHac.Crypto.Detail inBlock = ref Unsafe.Add(ref inBlock, 1); outBlock = ref Unsafe.Add(ref outBlock, 1); + remainingBlocks -= 1; } _iv = iv; diff --git a/src/LibHac/Crypto/Detail/AesCoreNi.cs b/src/LibHac/Crypto/Detail/AesCoreNi.cs index 3cdb4fef..554d2681 100644 --- a/src/LibHac/Crypto/Detail/AesCoreNi.cs +++ b/src/LibHac/Crypto/Detail/AesCoreNi.cs @@ -28,7 +28,6 @@ namespace LibHac.Crypto.Detail public readonly ReadOnlySpan> RoundKeys => MemoryMarshal.CreateReadOnlySpan(ref Unsafe.AsRef(in _roundKeys), RoundKeyCount); - [MethodImpl(MethodImplOptions.AggressiveOptimization)] public readonly void Encrypt(ReadOnlySpan input, Span output) { int blockCount = Math.Min(input.Length, output.Length) >> 4; @@ -45,7 +44,6 @@ namespace LibHac.Crypto.Detail } } - [MethodImpl(MethodImplOptions.AggressiveOptimization)] public readonly void Decrypt(ReadOnlySpan input, Span output) { int blockCount = Math.Min(input.Length, output.Length) >> 4; @@ -98,7 +96,357 @@ namespace LibHac.Crypto.Detail return AesNi.DecryptLast(b, keys[0]); } - [MethodImpl(MethodImplOptions.AggressiveOptimization)] + public readonly void EncryptInterleaved8(ReadOnlySpan input, Span output) + { + int remainingBlocks = Math.Min(input.Length, output.Length) >> 4; + + ref Vector128 inBlock = ref Unsafe.As>(ref MemoryMarshal.GetReference(input)); + ref Vector128 outBlock = ref Unsafe.As>(ref MemoryMarshal.GetReference(output)); + + while (remainingBlocks > 7) + { + EncryptBlocks8( + Unsafe.Add(ref inBlock, 0), + Unsafe.Add(ref inBlock, 1), + Unsafe.Add(ref inBlock, 2), + Unsafe.Add(ref inBlock, 3), + Unsafe.Add(ref inBlock, 4), + Unsafe.Add(ref inBlock, 5), + Unsafe.Add(ref inBlock, 6), + Unsafe.Add(ref inBlock, 7), + out Unsafe.Add(ref outBlock, 0), + out Unsafe.Add(ref outBlock, 1), + out Unsafe.Add(ref outBlock, 2), + out Unsafe.Add(ref outBlock, 3), + out Unsafe.Add(ref outBlock, 4), + out Unsafe.Add(ref outBlock, 5), + out Unsafe.Add(ref outBlock, 6), + out Unsafe.Add(ref outBlock, 7)); + + inBlock = ref Unsafe.Add(ref inBlock, 8); + outBlock = ref Unsafe.Add(ref outBlock, 8); + remainingBlocks -= 8; + } + + while (remainingBlocks > 0) + { + outBlock = EncryptBlock(inBlock); + + inBlock = ref Unsafe.Add(ref inBlock, 1); + outBlock = ref Unsafe.Add(ref outBlock, 1); + remainingBlocks -= 1; + } + } + + public readonly void DecryptInterleaved8(ReadOnlySpan input, Span output) + { + int remainingBlocks = Math.Min(input.Length, output.Length) >> 4; + + ref Vector128 inBlock = ref Unsafe.As>(ref MemoryMarshal.GetReference(input)); + ref Vector128 outBlock = ref Unsafe.As>(ref MemoryMarshal.GetReference(output)); + + while (remainingBlocks > 7) + { + DecryptBlocks8( + Unsafe.Add(ref inBlock, 0), + Unsafe.Add(ref inBlock, 1), + Unsafe.Add(ref inBlock, 2), + Unsafe.Add(ref inBlock, 3), + Unsafe.Add(ref inBlock, 4), + Unsafe.Add(ref inBlock, 5), + Unsafe.Add(ref inBlock, 6), + Unsafe.Add(ref inBlock, 7), + out Unsafe.Add(ref outBlock, 0), + out Unsafe.Add(ref outBlock, 1), + out Unsafe.Add(ref outBlock, 2), + out Unsafe.Add(ref outBlock, 3), + out Unsafe.Add(ref outBlock, 4), + out Unsafe.Add(ref outBlock, 5), + out Unsafe.Add(ref outBlock, 6), + out Unsafe.Add(ref outBlock, 7)); + + inBlock = ref Unsafe.Add(ref inBlock, 8); + outBlock = ref Unsafe.Add(ref outBlock, 8); + remainingBlocks -= 8; + } + + while (remainingBlocks > 0) + { + outBlock = DecryptBlock(inBlock); + + inBlock = ref Unsafe.Add(ref inBlock, 1); + outBlock = ref Unsafe.Add(ref outBlock, 1); + remainingBlocks -= 1; + } + } + + // When inlining this function, RyuJIT will almost make the + // generated code the same as if it were manually inlined + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public readonly void EncryptBlocks8(Vector128 in0, + Vector128 in1, + Vector128 in2, + Vector128 in3, + Vector128 in4, + Vector128 in5, + Vector128 in6, + Vector128 in7, + out Vector128 out0, + out Vector128 out1, + out Vector128 out2, + out Vector128 out3, + out Vector128 out4, + out Vector128 out5, + out Vector128 out6, + out Vector128 out7 + ) + { + ReadOnlySpan> keys = RoundKeys; + + Vector128 key = keys[0]; + Vector128 b0 = Sse2.Xor(in0, key); + Vector128 b1 = Sse2.Xor(in1, key); + Vector128 b2 = Sse2.Xor(in2, key); + Vector128 b3 = Sse2.Xor(in3, key); + Vector128 b4 = Sse2.Xor(in4, key); + Vector128 b5 = Sse2.Xor(in5, key); + Vector128 b6 = Sse2.Xor(in6, key); + Vector128 b7 = Sse2.Xor(in7, key); + + key = keys[1]; + b0 = AesNi.Encrypt(b0, key); + b1 = AesNi.Encrypt(b1, key); + b2 = AesNi.Encrypt(b2, key); + b3 = AesNi.Encrypt(b3, key); + b4 = AesNi.Encrypt(b4, key); + b5 = AesNi.Encrypt(b5, key); + b6 = AesNi.Encrypt(b6, key); + b7 = AesNi.Encrypt(b7, key); + + key = keys[2]; + b0 = AesNi.Encrypt(b0, key); + b1 = AesNi.Encrypt(b1, key); + b2 = AesNi.Encrypt(b2, key); + b3 = AesNi.Encrypt(b3, key); + b4 = AesNi.Encrypt(b4, key); + b5 = AesNi.Encrypt(b5, key); + b6 = AesNi.Encrypt(b6, key); + b7 = AesNi.Encrypt(b7, key); + + key = keys[3]; + b0 = AesNi.Encrypt(b0, key); + b1 = AesNi.Encrypt(b1, key); + b2 = AesNi.Encrypt(b2, key); + b3 = AesNi.Encrypt(b3, key); + b4 = AesNi.Encrypt(b4, key); + b5 = AesNi.Encrypt(b5, key); + b6 = AesNi.Encrypt(b6, key); + b7 = AesNi.Encrypt(b7, key); + + key = keys[4]; + b0 = AesNi.Encrypt(b0, key); + b1 = AesNi.Encrypt(b1, key); + b2 = AesNi.Encrypt(b2, key); + b3 = AesNi.Encrypt(b3, key); + b4 = AesNi.Encrypt(b4, key); + b5 = AesNi.Encrypt(b5, key); + b6 = AesNi.Encrypt(b6, key); + b7 = AesNi.Encrypt(b7, key); + + key = keys[5]; + b0 = AesNi.Encrypt(b0, key); + b1 = AesNi.Encrypt(b1, key); + b2 = AesNi.Encrypt(b2, key); + b3 = AesNi.Encrypt(b3, key); + b4 = AesNi.Encrypt(b4, key); + b5 = AesNi.Encrypt(b5, key); + b6 = AesNi.Encrypt(b6, key); + b7 = AesNi.Encrypt(b7, key); + + key = keys[6]; + b0 = AesNi.Encrypt(b0, key); + b1 = AesNi.Encrypt(b1, key); + b2 = AesNi.Encrypt(b2, key); + b3 = AesNi.Encrypt(b3, key); + b4 = AesNi.Encrypt(b4, key); + b5 = AesNi.Encrypt(b5, key); + b6 = AesNi.Encrypt(b6, key); + b7 = AesNi.Encrypt(b7, key); + + key = keys[7]; + b0 = AesNi.Encrypt(b0, key); + b1 = AesNi.Encrypt(b1, key); + b2 = AesNi.Encrypt(b2, key); + b3 = AesNi.Encrypt(b3, key); + b4 = AesNi.Encrypt(b4, key); + b5 = AesNi.Encrypt(b5, key); + b6 = AesNi.Encrypt(b6, key); + b7 = AesNi.Encrypt(b7, key); + + key = keys[8]; + b0 = AesNi.Encrypt(b0, key); + b1 = AesNi.Encrypt(b1, key); + b2 = AesNi.Encrypt(b2, key); + b3 = AesNi.Encrypt(b3, key); + b4 = AesNi.Encrypt(b4, key); + b5 = AesNi.Encrypt(b5, key); + b6 = AesNi.Encrypt(b6, key); + b7 = AesNi.Encrypt(b7, key); + + key = keys[9]; + b0 = AesNi.Encrypt(b0, key); + b1 = AesNi.Encrypt(b1, key); + b2 = AesNi.Encrypt(b2, key); + b3 = AesNi.Encrypt(b3, key); + b4 = AesNi.Encrypt(b4, key); + b5 = AesNi.Encrypt(b5, key); + b6 = AesNi.Encrypt(b6, key); + b7 = AesNi.Encrypt(b7, key); + + key = keys[10]; + out0 = AesNi.EncryptLast(b0, key); + out1 = AesNi.EncryptLast(b1, key); + out2 = AesNi.EncryptLast(b2, key); + out3 = AesNi.EncryptLast(b3, key); + out4 = AesNi.EncryptLast(b4, key); + out5 = AesNi.EncryptLast(b5, key); + out6 = AesNi.EncryptLast(b6, key); + out7 = AesNi.EncryptLast(b7, key); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public readonly void DecryptBlocks8( + Vector128 in0, + Vector128 in1, + Vector128 in2, + Vector128 in3, + Vector128 in4, + Vector128 in5, + Vector128 in6, + Vector128 in7, + out Vector128 out0, + out Vector128 out1, + out Vector128 out2, + out Vector128 out3, + out Vector128 out4, + out Vector128 out5, + out Vector128 out6, + out Vector128 out7 + ) + { + ReadOnlySpan> keys = RoundKeys; + + Vector128 key = keys[10]; + Vector128 b0 = Sse2.Xor(in0, key); + Vector128 b1 = Sse2.Xor(in1, key); + Vector128 b2 = Sse2.Xor(in2, key); + Vector128 b3 = Sse2.Xor(in3, key); + Vector128 b4 = Sse2.Xor(in4, key); + Vector128 b5 = Sse2.Xor(in5, key); + Vector128 b6 = Sse2.Xor(in6, key); + Vector128 b7 = Sse2.Xor(in7, key); + + key = keys[9]; + b0 = AesNi.Decrypt(b0, key); + b1 = AesNi.Decrypt(b1, key); + b2 = AesNi.Decrypt(b2, key); + b3 = AesNi.Decrypt(b3, key); + b4 = AesNi.Decrypt(b4, key); + b5 = AesNi.Decrypt(b5, key); + b6 = AesNi.Decrypt(b6, key); + b7 = AesNi.Decrypt(b7, key); + + key = keys[8]; + b0 = AesNi.Decrypt(b0, key); + b1 = AesNi.Decrypt(b1, key); + b2 = AesNi.Decrypt(b2, key); + b3 = AesNi.Decrypt(b3, key); + b4 = AesNi.Decrypt(b4, key); + b5 = AesNi.Decrypt(b5, key); + b6 = AesNi.Decrypt(b6, key); + b7 = AesNi.Decrypt(b7, key); + + key = keys[7]; + b0 = AesNi.Decrypt(b0, key); + b1 = AesNi.Decrypt(b1, key); + b2 = AesNi.Decrypt(b2, key); + b3 = AesNi.Decrypt(b3, key); + b4 = AesNi.Decrypt(b4, key); + b5 = AesNi.Decrypt(b5, key); + b6 = AesNi.Decrypt(b6, key); + b7 = AesNi.Decrypt(b7, key); + + key = keys[6]; + b0 = AesNi.Decrypt(b0, key); + b1 = AesNi.Decrypt(b1, key); + b2 = AesNi.Decrypt(b2, key); + b3 = AesNi.Decrypt(b3, key); + b4 = AesNi.Decrypt(b4, key); + b5 = AesNi.Decrypt(b5, key); + b6 = AesNi.Decrypt(b6, key); + b7 = AesNi.Decrypt(b7, key); + + key = keys[5]; + b0 = AesNi.Decrypt(b0, key); + b1 = AesNi.Decrypt(b1, key); + b2 = AesNi.Decrypt(b2, key); + b3 = AesNi.Decrypt(b3, key); + b4 = AesNi.Decrypt(b4, key); + b5 = AesNi.Decrypt(b5, key); + b6 = AesNi.Decrypt(b6, key); + b7 = AesNi.Decrypt(b7, key); + + key = keys[4]; + b0 = AesNi.Decrypt(b0, key); + b1 = AesNi.Decrypt(b1, key); + b2 = AesNi.Decrypt(b2, key); + b3 = AesNi.Decrypt(b3, key); + b4 = AesNi.Decrypt(b4, key); + b5 = AesNi.Decrypt(b5, key); + b6 = AesNi.Decrypt(b6, key); + b7 = AesNi.Decrypt(b7, key); + + key = keys[3]; + b0 = AesNi.Decrypt(b0, key); + b1 = AesNi.Decrypt(b1, key); + b2 = AesNi.Decrypt(b2, key); + b3 = AesNi.Decrypt(b3, key); + b4 = AesNi.Decrypt(b4, key); + b5 = AesNi.Decrypt(b5, key); + b6 = AesNi.Decrypt(b6, key); + b7 = AesNi.Decrypt(b7, key); + + key = keys[2]; + b0 = AesNi.Decrypt(b0, key); + b1 = AesNi.Decrypt(b1, key); + b2 = AesNi.Decrypt(b2, key); + b3 = AesNi.Decrypt(b3, key); + b4 = AesNi.Decrypt(b4, key); + b5 = AesNi.Decrypt(b5, key); + b6 = AesNi.Decrypt(b6, key); + b7 = AesNi.Decrypt(b7, key); + + key = keys[1]; + b0 = AesNi.Decrypt(b0, key); + b1 = AesNi.Decrypt(b1, key); + b2 = AesNi.Decrypt(b2, key); + b3 = AesNi.Decrypt(b3, key); + b4 = AesNi.Decrypt(b4, key); + b5 = AesNi.Decrypt(b5, key); + b6 = AesNi.Decrypt(b6, key); + b7 = AesNi.Decrypt(b7, key); + + key = keys[0]; + out0 = AesNi.DecryptLast(b0, key); + out1 = AesNi.DecryptLast(b1, key); + out2 = AesNi.DecryptLast(b2, key); + out3 = AesNi.DecryptLast(b3, key); + out4 = AesNi.DecryptLast(b4, key); + out5 = AesNi.DecryptLast(b5, key); + out6 = AesNi.DecryptLast(b6, key); + out7 = AesNi.DecryptLast(b7, key); + } + private static void KeyExpansion(ReadOnlySpan key, Span> roundKeys, bool isDecrypting) { var curKey = Unsafe.ReadUnaligned>(ref MemoryMarshal.GetReference(key)); diff --git a/src/LibHac/Crypto/Detail/AesCtrModeNi.cs b/src/LibHac/Crypto/Detail/AesCtrModeNi.cs index 244ca672..db5802bf 100644 --- a/src/LibHac/Crypto/Detail/AesCtrModeNi.cs +++ b/src/LibHac/Crypto/Detail/AesCtrModeNi.cs @@ -27,7 +27,8 @@ namespace LibHac.Crypto.Detail public void Transform(ReadOnlySpan input, Span output) { - int blockCount = Math.Min(input.Length, output.Length) >> 4; + int remaining = Math.Min(input.Length, output.Length); + int blockCount = remaining >> 4; ref Vector128 inBlock = ref Unsafe.As>(ref MemoryMarshal.GetReference(input)); ref Vector128 outBlock = ref Unsafe.As>(ref MemoryMarshal.GetReference(output)); @@ -38,7 +39,53 @@ namespace LibHac.Crypto.Detail Vector128 iv = _iv; Vector128 bSwappedIv = Ssse3.Shuffle(iv, byteSwapMask).AsUInt64(); - for (int i = 0; i < blockCount; i++) + while (remaining >= 8 * Aes.BlockSize) + { + Vector128 b0 = iv; + + bSwappedIv = Sse2.Add(bSwappedIv, inc); + Vector128 b1 = Ssse3.Shuffle(bSwappedIv.AsByte(), byteSwapMask); + + bSwappedIv = Sse2.Add(bSwappedIv, inc); + Vector128 b2 = Ssse3.Shuffle(bSwappedIv.AsByte(), byteSwapMask); + + bSwappedIv = Sse2.Add(bSwappedIv, inc); + Vector128 b3 = Ssse3.Shuffle(bSwappedIv.AsByte(), byteSwapMask); + + bSwappedIv = Sse2.Add(bSwappedIv, inc); + Vector128 b4 = Ssse3.Shuffle(bSwappedIv.AsByte(), byteSwapMask); + + bSwappedIv = Sse2.Add(bSwappedIv, inc); + Vector128 b5 = Ssse3.Shuffle(bSwappedIv.AsByte(), byteSwapMask); + + bSwappedIv = Sse2.Add(bSwappedIv, inc); + Vector128 b6 = Ssse3.Shuffle(bSwappedIv.AsByte(), byteSwapMask); + + bSwappedIv = Sse2.Add(bSwappedIv, inc); + Vector128 b7 = Ssse3.Shuffle(bSwappedIv.AsByte(), byteSwapMask); + + _aesCore.EncryptBlocks8(b0, b1, b2, b3, b4, b5, b6, b7, + out b0, out b1, out b2, out b3, out b4, out b5, out b6, out b7); + + Unsafe.Add(ref outBlock, 0) = Sse2.Xor(Unsafe.Add(ref inBlock, 0), b0); + Unsafe.Add(ref outBlock, 1) = Sse2.Xor(Unsafe.Add(ref inBlock, 1), b1); + Unsafe.Add(ref outBlock, 2) = Sse2.Xor(Unsafe.Add(ref inBlock, 2), b2); + Unsafe.Add(ref outBlock, 3) = Sse2.Xor(Unsafe.Add(ref inBlock, 3), b3); + Unsafe.Add(ref outBlock, 4) = Sse2.Xor(Unsafe.Add(ref inBlock, 4), b4); + Unsafe.Add(ref outBlock, 5) = Sse2.Xor(Unsafe.Add(ref inBlock, 5), b5); + Unsafe.Add(ref outBlock, 6) = Sse2.Xor(Unsafe.Add(ref inBlock, 6), b6); + Unsafe.Add(ref outBlock, 7) = Sse2.Xor(Unsafe.Add(ref inBlock, 7), b7); + + // Increase the counter + bSwappedIv = Sse2.Add(bSwappedIv, inc); + iv = Ssse3.Shuffle(bSwappedIv.AsByte(), byteSwapMask); + + inBlock = ref Unsafe.Add(ref inBlock, 8); + outBlock = ref Unsafe.Add(ref outBlock, 8); + remaining -= 8 * Aes.BlockSize; + } + + while (remaining >= Aes.BlockSize) { Vector128 encIv = _aesCore.EncryptBlock(iv); outBlock = Sse2.Xor(inBlock, encIv); @@ -49,11 +96,12 @@ namespace LibHac.Crypto.Detail inBlock = ref Unsafe.Add(ref inBlock, 1); outBlock = ref Unsafe.Add(ref outBlock, 1); + remaining -= Aes.BlockSize; } _iv = iv; - if ((input.Length & 0xF) != 0) + if (remaining != 0) { EncryptCtrPartialBlock(input.Slice(blockCount * 0x10), output.Slice(blockCount * 0x10)); } diff --git a/src/LibHac/Crypto/Detail/AesEcbModeNi.cs b/src/LibHac/Crypto/Detail/AesEcbModeNi.cs index 17c24f78..c1073a09 100644 --- a/src/LibHac/Crypto/Detail/AesEcbModeNi.cs +++ b/src/LibHac/Crypto/Detail/AesEcbModeNi.cs @@ -16,12 +16,12 @@ namespace LibHac.Crypto.Detail public void Encrypt(ReadOnlySpan input, Span output) { - _aesCore.Encrypt(input, output); + _aesCore.EncryptInterleaved8(input, output); } public void Decrypt(ReadOnlySpan input, Span output) { - _aesCore.Decrypt(input, output); + _aesCore.DecryptInterleaved8(input, output); } } } diff --git a/src/LibHac/Crypto/Detail/AesXtsModeNi.cs b/src/LibHac/Crypto/Detail/AesXtsModeNi.cs index fcfca434..897ba171 100644 --- a/src/LibHac/Crypto/Detail/AesXtsModeNi.cs +++ b/src/LibHac/Crypto/Detail/AesXtsModeNi.cs @@ -31,10 +31,10 @@ namespace LibHac.Crypto.Detail public void Encrypt(ReadOnlySpan input, Span output) { int length = Math.Min(input.Length, output.Length); - int blockCount = length >> 4; + int remainingBlocks = length >> 4; int leftover = length & 0xF; - Debug.Assert(blockCount > 0); + Debug.Assert(remainingBlocks > 0); ref Vector128 inBlock = ref Unsafe.As>(ref MemoryMarshal.GetReference(input)); ref Vector128 outBlock = ref Unsafe.As>(ref MemoryMarshal.GetReference(output)); @@ -43,7 +43,51 @@ namespace LibHac.Crypto.Detail Vector128 tweak = _tweakAesCore.EncryptBlock(_iv); - for (int i = 0; i < blockCount; i++) + while (remainingBlocks > 7) + { + Vector128 b0 = Sse2.Xor(tweak, Unsafe.Add(ref inBlock, 0)); + + Vector128 tweak1 = Gf128Mul(tweak, mask); + Vector128 b1 = Sse2.Xor(tweak1, Unsafe.Add(ref inBlock, 1)); + + Vector128 tweak2 = Gf128Mul(tweak1, mask); + Vector128 b2 = Sse2.Xor(tweak2, Unsafe.Add(ref inBlock, 2)); + + Vector128 tweak3 = Gf128Mul(tweak2, mask); + Vector128 b3 = Sse2.Xor(tweak3, Unsafe.Add(ref inBlock, 3)); + + Vector128 tweak4 = Gf128Mul(tweak3, mask); + Vector128 b4 = Sse2.Xor(tweak4, Unsafe.Add(ref inBlock, 4)); + + Vector128 tweak5 = Gf128Mul(tweak4, mask); + Vector128 b5 = Sse2.Xor(tweak5, Unsafe.Add(ref inBlock, 5)); + + Vector128 tweak6 = Gf128Mul(tweak5, mask); + Vector128 b6 = Sse2.Xor(tweak6, Unsafe.Add(ref inBlock, 6)); + + Vector128 tweak7 = Gf128Mul(tweak6, mask); + Vector128 b7 = Sse2.Xor(tweak7, Unsafe.Add(ref inBlock, 7)); + + _dataAesCore.EncryptBlocks8(b0, b1, b2, b3, b4, b5, b6, b7, + out b0, out b1, out b2, out b3, out b4, out b5, out b6, out b7); + + Unsafe.Add(ref outBlock, 0) = Sse2.Xor(tweak, b0); + Unsafe.Add(ref outBlock, 1) = Sse2.Xor(tweak1, b1); + Unsafe.Add(ref outBlock, 2) = Sse2.Xor(tweak2, b2); + Unsafe.Add(ref outBlock, 3) = Sse2.Xor(tweak3, b3); + Unsafe.Add(ref outBlock, 4) = Sse2.Xor(tweak4, b4); + Unsafe.Add(ref outBlock, 5) = Sse2.Xor(tweak5, b5); + Unsafe.Add(ref outBlock, 6) = Sse2.Xor(tweak6, b6); + Unsafe.Add(ref outBlock, 7) = Sse2.Xor(tweak7, b7); + + tweak = Gf128Mul(tweak7, mask); + + inBlock = ref Unsafe.Add(ref inBlock, 8); + outBlock = ref Unsafe.Add(ref outBlock, 8); + remainingBlocks -= 8; + } + + while (remainingBlocks > 0) { Vector128 tmp = Sse2.Xor(inBlock, tweak); tmp = _dataAesCore.EncryptBlock(tmp); @@ -53,6 +97,7 @@ namespace LibHac.Crypto.Detail inBlock = ref Unsafe.Add(ref inBlock, 1); outBlock = ref Unsafe.Add(ref outBlock, 1); + remainingBlocks--; } if (leftover != 0) @@ -64,12 +109,12 @@ namespace LibHac.Crypto.Detail public void Decrypt(ReadOnlySpan input, Span output) { int length = Math.Min(input.Length, output.Length); - int blockCount = length >> 4; + int remainingBlocks = length >> 4; int leftover = length & 0xF; - Debug.Assert(blockCount > 0); + Debug.Assert(remainingBlocks > 0); - if (leftover != 0) blockCount--; + if (leftover != 0) remainingBlocks--; ref Vector128 inBlock = ref Unsafe.As>(ref MemoryMarshal.GetReference(input)); ref Vector128 outBlock = ref Unsafe.As>(ref MemoryMarshal.GetReference(output)); @@ -78,7 +123,51 @@ namespace LibHac.Crypto.Detail Vector128 tweak = _tweakAesCore.EncryptBlock(_iv); - for (int i = 0; i < blockCount; i++) + while (remainingBlocks > 7) + { + Vector128 b0 = Sse2.Xor(tweak, Unsafe.Add(ref inBlock, 0)); + + Vector128 tweak1 = Gf128Mul(tweak, mask); + Vector128 b1 = Sse2.Xor(tweak1, Unsafe.Add(ref inBlock, 1)); + + Vector128 tweak2 = Gf128Mul(tweak1, mask); + Vector128 b2 = Sse2.Xor(tweak2, Unsafe.Add(ref inBlock, 2)); + + Vector128 tweak3 = Gf128Mul(tweak2, mask); + Vector128 b3 = Sse2.Xor(tweak3, Unsafe.Add(ref inBlock, 3)); + + Vector128 tweak4 = Gf128Mul(tweak3, mask); + Vector128 b4 = Sse2.Xor(tweak4, Unsafe.Add(ref inBlock, 4)); + + Vector128 tweak5 = Gf128Mul(tweak4, mask); + Vector128 b5 = Sse2.Xor(tweak5, Unsafe.Add(ref inBlock, 5)); + + Vector128 tweak6 = Gf128Mul(tweak5, mask); + Vector128 b6 = Sse2.Xor(tweak6, Unsafe.Add(ref inBlock, 6)); + + Vector128 tweak7 = Gf128Mul(tweak6, mask); + Vector128 b7 = Sse2.Xor(tweak7, Unsafe.Add(ref inBlock, 7)); + + _dataAesCore.DecryptBlocks8(b0, b1, b2, b3, b4, b5, b6, b7, + out b0, out b1, out b2, out b3, out b4, out b5, out b6, out b7); + + Unsafe.Add(ref outBlock, 0) = Sse2.Xor(tweak, b0); + Unsafe.Add(ref outBlock, 1) = Sse2.Xor(tweak1, b1); + Unsafe.Add(ref outBlock, 2) = Sse2.Xor(tweak2, b2); + Unsafe.Add(ref outBlock, 3) = Sse2.Xor(tweak3, b3); + Unsafe.Add(ref outBlock, 4) = Sse2.Xor(tweak4, b4); + Unsafe.Add(ref outBlock, 5) = Sse2.Xor(tweak5, b5); + Unsafe.Add(ref outBlock, 6) = Sse2.Xor(tweak6, b6); + Unsafe.Add(ref outBlock, 7) = Sse2.Xor(tweak7, b7); + + tweak = Gf128Mul(tweak7, mask); + + inBlock = ref Unsafe.Add(ref inBlock, 8); + outBlock = ref Unsafe.Add(ref outBlock, 8); + remainingBlocks -= 8; + } + + while (remainingBlocks > 0) { Vector128 tmp = Sse2.Xor(inBlock, tweak); tmp = _dataAesCore.DecryptBlock(tmp); @@ -88,6 +177,7 @@ namespace LibHac.Crypto.Detail inBlock = ref Unsafe.Add(ref inBlock, 1); outBlock = ref Unsafe.Add(ref outBlock, 1); + remainingBlocks--; } if (leftover != 0)