Interleave AES instructions to improve performance

This commit is contained in:
Alex Barney 2019-11-20 21:33:31 -05:00
parent 191b3d41f6
commit db6269df5c
5 changed files with 541 additions and 17 deletions

View file

@ -49,14 +49,51 @@ namespace LibHac.Crypto.Detail
public void Decrypt(ReadOnlySpan<byte> input, Span<byte> output) public void Decrypt(ReadOnlySpan<byte> input, Span<byte> output)
{ {
int blockCount = Math.Min(input.Length, output.Length) >> 4; int remainingBlocks = Math.Min(input.Length, output.Length) >> 4;
ref Vector128<byte> inBlock = ref Unsafe.As<byte, Vector128<byte>>(ref MemoryMarshal.GetReference(input)); ref Vector128<byte> inBlock = ref Unsafe.As<byte, Vector128<byte>>(ref MemoryMarshal.GetReference(input));
ref Vector128<byte> outBlock = ref Unsafe.As<byte, Vector128<byte>>(ref MemoryMarshal.GetReference(output)); ref Vector128<byte> outBlock = ref Unsafe.As<byte, Vector128<byte>>(ref MemoryMarshal.GetReference(output));
Vector128<byte> iv = _iv; Vector128<byte> iv = _iv;
for (int i = 0; i < blockCount; i++) while (remainingBlocks > 7)
{
Vector128<byte> in0 = Unsafe.Add(ref inBlock, 0);
Vector128<byte> in1 = Unsafe.Add(ref inBlock, 1);
Vector128<byte> in2 = Unsafe.Add(ref inBlock, 2);
Vector128<byte> in3 = Unsafe.Add(ref inBlock, 3);
Vector128<byte> in4 = Unsafe.Add(ref inBlock, 4);
Vector128<byte> in5 = Unsafe.Add(ref inBlock, 5);
Vector128<byte> in6 = Unsafe.Add(ref inBlock, 6);
Vector128<byte> in7 = Unsafe.Add(ref inBlock, 7);
_aesCore.DecryptBlocks8(in0, in1, in2, in3, in4, in5, in6, in7,
out Vector128<byte> b0,
out Vector128<byte> b1,
out Vector128<byte> b2,
out Vector128<byte> b3,
out Vector128<byte> b4,
out Vector128<byte> b5,
out Vector128<byte> b6,
out Vector128<byte> b7);
Unsafe.Add(ref outBlock, 0) = Sse2.Xor(iv, b0);
Unsafe.Add(ref outBlock, 1) = Sse2.Xor(in0, b1);
Unsafe.Add(ref outBlock, 2) = Sse2.Xor(in1, b2);
Unsafe.Add(ref outBlock, 3) = Sse2.Xor(in2, b3);
Unsafe.Add(ref outBlock, 4) = Sse2.Xor(in3, b4);
Unsafe.Add(ref outBlock, 5) = Sse2.Xor(in4, b5);
Unsafe.Add(ref outBlock, 6) = Sse2.Xor(in5, b6);
Unsafe.Add(ref outBlock, 7) = Sse2.Xor(in6, b7);
iv = in7;
inBlock = ref Unsafe.Add(ref inBlock, 8);
outBlock = ref Unsafe.Add(ref outBlock, 8);
remainingBlocks -= 8;
}
while (remainingBlocks > 0)
{ {
Vector128<byte> currentBlock = inBlock; Vector128<byte> currentBlock = inBlock;
Vector128<byte> decBeforeIv = _aesCore.DecryptBlock(currentBlock); Vector128<byte> decBeforeIv = _aesCore.DecryptBlock(currentBlock);
@ -66,6 +103,7 @@ namespace LibHac.Crypto.Detail
inBlock = ref Unsafe.Add(ref inBlock, 1); inBlock = ref Unsafe.Add(ref inBlock, 1);
outBlock = ref Unsafe.Add(ref outBlock, 1); outBlock = ref Unsafe.Add(ref outBlock, 1);
remainingBlocks -= 1;
} }
_iv = iv; _iv = iv;

View file

@ -28,7 +28,6 @@ namespace LibHac.Crypto.Detail
public readonly ReadOnlySpan<Vector128<byte>> RoundKeys => public readonly ReadOnlySpan<Vector128<byte>> RoundKeys =>
MemoryMarshal.CreateReadOnlySpan(ref Unsafe.AsRef(in _roundKeys), RoundKeyCount); MemoryMarshal.CreateReadOnlySpan(ref Unsafe.AsRef(in _roundKeys), RoundKeyCount);
[MethodImpl(MethodImplOptions.AggressiveOptimization)]
public readonly void Encrypt(ReadOnlySpan<byte> input, Span<byte> output) public readonly void Encrypt(ReadOnlySpan<byte> input, Span<byte> output)
{ {
int blockCount = Math.Min(input.Length, output.Length) >> 4; int blockCount = Math.Min(input.Length, output.Length) >> 4;
@ -45,7 +44,6 @@ namespace LibHac.Crypto.Detail
} }
} }
[MethodImpl(MethodImplOptions.AggressiveOptimization)]
public readonly void Decrypt(ReadOnlySpan<byte> input, Span<byte> output) public readonly void Decrypt(ReadOnlySpan<byte> input, Span<byte> output)
{ {
int blockCount = Math.Min(input.Length, output.Length) >> 4; int blockCount = Math.Min(input.Length, output.Length) >> 4;
@ -98,7 +96,357 @@ namespace LibHac.Crypto.Detail
return AesNi.DecryptLast(b, keys[0]); return AesNi.DecryptLast(b, keys[0]);
} }
[MethodImpl(MethodImplOptions.AggressiveOptimization)] public readonly void EncryptInterleaved8(ReadOnlySpan<byte> input, Span<byte> output)
{
int remainingBlocks = Math.Min(input.Length, output.Length) >> 4;
ref Vector128<byte> inBlock = ref Unsafe.As<byte, Vector128<byte>>(ref MemoryMarshal.GetReference(input));
ref Vector128<byte> outBlock = ref Unsafe.As<byte, Vector128<byte>>(ref MemoryMarshal.GetReference(output));
while (remainingBlocks > 7)
{
EncryptBlocks8(
Unsafe.Add(ref inBlock, 0),
Unsafe.Add(ref inBlock, 1),
Unsafe.Add(ref inBlock, 2),
Unsafe.Add(ref inBlock, 3),
Unsafe.Add(ref inBlock, 4),
Unsafe.Add(ref inBlock, 5),
Unsafe.Add(ref inBlock, 6),
Unsafe.Add(ref inBlock, 7),
out Unsafe.Add(ref outBlock, 0),
out Unsafe.Add(ref outBlock, 1),
out Unsafe.Add(ref outBlock, 2),
out Unsafe.Add(ref outBlock, 3),
out Unsafe.Add(ref outBlock, 4),
out Unsafe.Add(ref outBlock, 5),
out Unsafe.Add(ref outBlock, 6),
out Unsafe.Add(ref outBlock, 7));
inBlock = ref Unsafe.Add(ref inBlock, 8);
outBlock = ref Unsafe.Add(ref outBlock, 8);
remainingBlocks -= 8;
}
while (remainingBlocks > 0)
{
outBlock = EncryptBlock(inBlock);
inBlock = ref Unsafe.Add(ref inBlock, 1);
outBlock = ref Unsafe.Add(ref outBlock, 1);
remainingBlocks -= 1;
}
}
public readonly void DecryptInterleaved8(ReadOnlySpan<byte> input, Span<byte> output)
{
int remainingBlocks = Math.Min(input.Length, output.Length) >> 4;
ref Vector128<byte> inBlock = ref Unsafe.As<byte, Vector128<byte>>(ref MemoryMarshal.GetReference(input));
ref Vector128<byte> outBlock = ref Unsafe.As<byte, Vector128<byte>>(ref MemoryMarshal.GetReference(output));
while (remainingBlocks > 7)
{
DecryptBlocks8(
Unsafe.Add(ref inBlock, 0),
Unsafe.Add(ref inBlock, 1),
Unsafe.Add(ref inBlock, 2),
Unsafe.Add(ref inBlock, 3),
Unsafe.Add(ref inBlock, 4),
Unsafe.Add(ref inBlock, 5),
Unsafe.Add(ref inBlock, 6),
Unsafe.Add(ref inBlock, 7),
out Unsafe.Add(ref outBlock, 0),
out Unsafe.Add(ref outBlock, 1),
out Unsafe.Add(ref outBlock, 2),
out Unsafe.Add(ref outBlock, 3),
out Unsafe.Add(ref outBlock, 4),
out Unsafe.Add(ref outBlock, 5),
out Unsafe.Add(ref outBlock, 6),
out Unsafe.Add(ref outBlock, 7));
inBlock = ref Unsafe.Add(ref inBlock, 8);
outBlock = ref Unsafe.Add(ref outBlock, 8);
remainingBlocks -= 8;
}
while (remainingBlocks > 0)
{
outBlock = DecryptBlock(inBlock);
inBlock = ref Unsafe.Add(ref inBlock, 1);
outBlock = ref Unsafe.Add(ref outBlock, 1);
remainingBlocks -= 1;
}
}
// When inlining this function, RyuJIT will almost make the
// generated code the same as if it were manually inlined
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public readonly void EncryptBlocks8(Vector128<byte> in0,
Vector128<byte> in1,
Vector128<byte> in2,
Vector128<byte> in3,
Vector128<byte> in4,
Vector128<byte> in5,
Vector128<byte> in6,
Vector128<byte> in7,
out Vector128<byte> out0,
out Vector128<byte> out1,
out Vector128<byte> out2,
out Vector128<byte> out3,
out Vector128<byte> out4,
out Vector128<byte> out5,
out Vector128<byte> out6,
out Vector128<byte> out7
)
{
ReadOnlySpan<Vector128<byte>> keys = RoundKeys;
Vector128<byte> key = keys[0];
Vector128<byte> b0 = Sse2.Xor(in0, key);
Vector128<byte> b1 = Sse2.Xor(in1, key);
Vector128<byte> b2 = Sse2.Xor(in2, key);
Vector128<byte> b3 = Sse2.Xor(in3, key);
Vector128<byte> b4 = Sse2.Xor(in4, key);
Vector128<byte> b5 = Sse2.Xor(in5, key);
Vector128<byte> b6 = Sse2.Xor(in6, key);
Vector128<byte> b7 = Sse2.Xor(in7, key);
key = keys[1];
b0 = AesNi.Encrypt(b0, key);
b1 = AesNi.Encrypt(b1, key);
b2 = AesNi.Encrypt(b2, key);
b3 = AesNi.Encrypt(b3, key);
b4 = AesNi.Encrypt(b4, key);
b5 = AesNi.Encrypt(b5, key);
b6 = AesNi.Encrypt(b6, key);
b7 = AesNi.Encrypt(b7, key);
key = keys[2];
b0 = AesNi.Encrypt(b0, key);
b1 = AesNi.Encrypt(b1, key);
b2 = AesNi.Encrypt(b2, key);
b3 = AesNi.Encrypt(b3, key);
b4 = AesNi.Encrypt(b4, key);
b5 = AesNi.Encrypt(b5, key);
b6 = AesNi.Encrypt(b6, key);
b7 = AesNi.Encrypt(b7, key);
key = keys[3];
b0 = AesNi.Encrypt(b0, key);
b1 = AesNi.Encrypt(b1, key);
b2 = AesNi.Encrypt(b2, key);
b3 = AesNi.Encrypt(b3, key);
b4 = AesNi.Encrypt(b4, key);
b5 = AesNi.Encrypt(b5, key);
b6 = AesNi.Encrypt(b6, key);
b7 = AesNi.Encrypt(b7, key);
key = keys[4];
b0 = AesNi.Encrypt(b0, key);
b1 = AesNi.Encrypt(b1, key);
b2 = AesNi.Encrypt(b2, key);
b3 = AesNi.Encrypt(b3, key);
b4 = AesNi.Encrypt(b4, key);
b5 = AesNi.Encrypt(b5, key);
b6 = AesNi.Encrypt(b6, key);
b7 = AesNi.Encrypt(b7, key);
key = keys[5];
b0 = AesNi.Encrypt(b0, key);
b1 = AesNi.Encrypt(b1, key);
b2 = AesNi.Encrypt(b2, key);
b3 = AesNi.Encrypt(b3, key);
b4 = AesNi.Encrypt(b4, key);
b5 = AesNi.Encrypt(b5, key);
b6 = AesNi.Encrypt(b6, key);
b7 = AesNi.Encrypt(b7, key);
key = keys[6];
b0 = AesNi.Encrypt(b0, key);
b1 = AesNi.Encrypt(b1, key);
b2 = AesNi.Encrypt(b2, key);
b3 = AesNi.Encrypt(b3, key);
b4 = AesNi.Encrypt(b4, key);
b5 = AesNi.Encrypt(b5, key);
b6 = AesNi.Encrypt(b6, key);
b7 = AesNi.Encrypt(b7, key);
key = keys[7];
b0 = AesNi.Encrypt(b0, key);
b1 = AesNi.Encrypt(b1, key);
b2 = AesNi.Encrypt(b2, key);
b3 = AesNi.Encrypt(b3, key);
b4 = AesNi.Encrypt(b4, key);
b5 = AesNi.Encrypt(b5, key);
b6 = AesNi.Encrypt(b6, key);
b7 = AesNi.Encrypt(b7, key);
key = keys[8];
b0 = AesNi.Encrypt(b0, key);
b1 = AesNi.Encrypt(b1, key);
b2 = AesNi.Encrypt(b2, key);
b3 = AesNi.Encrypt(b3, key);
b4 = AesNi.Encrypt(b4, key);
b5 = AesNi.Encrypt(b5, key);
b6 = AesNi.Encrypt(b6, key);
b7 = AesNi.Encrypt(b7, key);
key = keys[9];
b0 = AesNi.Encrypt(b0, key);
b1 = AesNi.Encrypt(b1, key);
b2 = AesNi.Encrypt(b2, key);
b3 = AesNi.Encrypt(b3, key);
b4 = AesNi.Encrypt(b4, key);
b5 = AesNi.Encrypt(b5, key);
b6 = AesNi.Encrypt(b6, key);
b7 = AesNi.Encrypt(b7, key);
key = keys[10];
out0 = AesNi.EncryptLast(b0, key);
out1 = AesNi.EncryptLast(b1, key);
out2 = AesNi.EncryptLast(b2, key);
out3 = AesNi.EncryptLast(b3, key);
out4 = AesNi.EncryptLast(b4, key);
out5 = AesNi.EncryptLast(b5, key);
out6 = AesNi.EncryptLast(b6, key);
out7 = AesNi.EncryptLast(b7, key);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public readonly void DecryptBlocks8(
Vector128<byte> in0,
Vector128<byte> in1,
Vector128<byte> in2,
Vector128<byte> in3,
Vector128<byte> in4,
Vector128<byte> in5,
Vector128<byte> in6,
Vector128<byte> in7,
out Vector128<byte> out0,
out Vector128<byte> out1,
out Vector128<byte> out2,
out Vector128<byte> out3,
out Vector128<byte> out4,
out Vector128<byte> out5,
out Vector128<byte> out6,
out Vector128<byte> out7
)
{
ReadOnlySpan<Vector128<byte>> keys = RoundKeys;
Vector128<byte> key = keys[10];
Vector128<byte> b0 = Sse2.Xor(in0, key);
Vector128<byte> b1 = Sse2.Xor(in1, key);
Vector128<byte> b2 = Sse2.Xor(in2, key);
Vector128<byte> b3 = Sse2.Xor(in3, key);
Vector128<byte> b4 = Sse2.Xor(in4, key);
Vector128<byte> b5 = Sse2.Xor(in5, key);
Vector128<byte> b6 = Sse2.Xor(in6, key);
Vector128<byte> b7 = Sse2.Xor(in7, key);
key = keys[9];
b0 = AesNi.Decrypt(b0, key);
b1 = AesNi.Decrypt(b1, key);
b2 = AesNi.Decrypt(b2, key);
b3 = AesNi.Decrypt(b3, key);
b4 = AesNi.Decrypt(b4, key);
b5 = AesNi.Decrypt(b5, key);
b6 = AesNi.Decrypt(b6, key);
b7 = AesNi.Decrypt(b7, key);
key = keys[8];
b0 = AesNi.Decrypt(b0, key);
b1 = AesNi.Decrypt(b1, key);
b2 = AesNi.Decrypt(b2, key);
b3 = AesNi.Decrypt(b3, key);
b4 = AesNi.Decrypt(b4, key);
b5 = AesNi.Decrypt(b5, key);
b6 = AesNi.Decrypt(b6, key);
b7 = AesNi.Decrypt(b7, key);
key = keys[7];
b0 = AesNi.Decrypt(b0, key);
b1 = AesNi.Decrypt(b1, key);
b2 = AesNi.Decrypt(b2, key);
b3 = AesNi.Decrypt(b3, key);
b4 = AesNi.Decrypt(b4, key);
b5 = AesNi.Decrypt(b5, key);
b6 = AesNi.Decrypt(b6, key);
b7 = AesNi.Decrypt(b7, key);
key = keys[6];
b0 = AesNi.Decrypt(b0, key);
b1 = AesNi.Decrypt(b1, key);
b2 = AesNi.Decrypt(b2, key);
b3 = AesNi.Decrypt(b3, key);
b4 = AesNi.Decrypt(b4, key);
b5 = AesNi.Decrypt(b5, key);
b6 = AesNi.Decrypt(b6, key);
b7 = AesNi.Decrypt(b7, key);
key = keys[5];
b0 = AesNi.Decrypt(b0, key);
b1 = AesNi.Decrypt(b1, key);
b2 = AesNi.Decrypt(b2, key);
b3 = AesNi.Decrypt(b3, key);
b4 = AesNi.Decrypt(b4, key);
b5 = AesNi.Decrypt(b5, key);
b6 = AesNi.Decrypt(b6, key);
b7 = AesNi.Decrypt(b7, key);
key = keys[4];
b0 = AesNi.Decrypt(b0, key);
b1 = AesNi.Decrypt(b1, key);
b2 = AesNi.Decrypt(b2, key);
b3 = AesNi.Decrypt(b3, key);
b4 = AesNi.Decrypt(b4, key);
b5 = AesNi.Decrypt(b5, key);
b6 = AesNi.Decrypt(b6, key);
b7 = AesNi.Decrypt(b7, key);
key = keys[3];
b0 = AesNi.Decrypt(b0, key);
b1 = AesNi.Decrypt(b1, key);
b2 = AesNi.Decrypt(b2, key);
b3 = AesNi.Decrypt(b3, key);
b4 = AesNi.Decrypt(b4, key);
b5 = AesNi.Decrypt(b5, key);
b6 = AesNi.Decrypt(b6, key);
b7 = AesNi.Decrypt(b7, key);
key = keys[2];
b0 = AesNi.Decrypt(b0, key);
b1 = AesNi.Decrypt(b1, key);
b2 = AesNi.Decrypt(b2, key);
b3 = AesNi.Decrypt(b3, key);
b4 = AesNi.Decrypt(b4, key);
b5 = AesNi.Decrypt(b5, key);
b6 = AesNi.Decrypt(b6, key);
b7 = AesNi.Decrypt(b7, key);
key = keys[1];
b0 = AesNi.Decrypt(b0, key);
b1 = AesNi.Decrypt(b1, key);
b2 = AesNi.Decrypt(b2, key);
b3 = AesNi.Decrypt(b3, key);
b4 = AesNi.Decrypt(b4, key);
b5 = AesNi.Decrypt(b5, key);
b6 = AesNi.Decrypt(b6, key);
b7 = AesNi.Decrypt(b7, key);
key = keys[0];
out0 = AesNi.DecryptLast(b0, key);
out1 = AesNi.DecryptLast(b1, key);
out2 = AesNi.DecryptLast(b2, key);
out3 = AesNi.DecryptLast(b3, key);
out4 = AesNi.DecryptLast(b4, key);
out5 = AesNi.DecryptLast(b5, key);
out6 = AesNi.DecryptLast(b6, key);
out7 = AesNi.DecryptLast(b7, key);
}
private static void KeyExpansion(ReadOnlySpan<byte> key, Span<Vector128<byte>> roundKeys, bool isDecrypting) private static void KeyExpansion(ReadOnlySpan<byte> key, Span<Vector128<byte>> roundKeys, bool isDecrypting)
{ {
var curKey = Unsafe.ReadUnaligned<Vector128<byte>>(ref MemoryMarshal.GetReference(key)); var curKey = Unsafe.ReadUnaligned<Vector128<byte>>(ref MemoryMarshal.GetReference(key));

View file

@ -27,7 +27,8 @@ namespace LibHac.Crypto.Detail
public void Transform(ReadOnlySpan<byte> input, Span<byte> output) public void Transform(ReadOnlySpan<byte> input, Span<byte> output)
{ {
int blockCount = Math.Min(input.Length, output.Length) >> 4; int remaining = Math.Min(input.Length, output.Length);
int blockCount = remaining >> 4;
ref Vector128<byte> inBlock = ref Unsafe.As<byte, Vector128<byte>>(ref MemoryMarshal.GetReference(input)); ref Vector128<byte> inBlock = ref Unsafe.As<byte, Vector128<byte>>(ref MemoryMarshal.GetReference(input));
ref Vector128<byte> outBlock = ref Unsafe.As<byte, Vector128<byte>>(ref MemoryMarshal.GetReference(output)); ref Vector128<byte> outBlock = ref Unsafe.As<byte, Vector128<byte>>(ref MemoryMarshal.GetReference(output));
@ -38,7 +39,53 @@ namespace LibHac.Crypto.Detail
Vector128<byte> iv = _iv; Vector128<byte> iv = _iv;
Vector128<ulong> bSwappedIv = Ssse3.Shuffle(iv, byteSwapMask).AsUInt64(); Vector128<ulong> bSwappedIv = Ssse3.Shuffle(iv, byteSwapMask).AsUInt64();
for (int i = 0; i < blockCount; i++) while (remaining >= 8 * Aes.BlockSize)
{
Vector128<byte> b0 = iv;
bSwappedIv = Sse2.Add(bSwappedIv, inc);
Vector128<byte> b1 = Ssse3.Shuffle(bSwappedIv.AsByte(), byteSwapMask);
bSwappedIv = Sse2.Add(bSwappedIv, inc);
Vector128<byte> b2 = Ssse3.Shuffle(bSwappedIv.AsByte(), byteSwapMask);
bSwappedIv = Sse2.Add(bSwappedIv, inc);
Vector128<byte> b3 = Ssse3.Shuffle(bSwappedIv.AsByte(), byteSwapMask);
bSwappedIv = Sse2.Add(bSwappedIv, inc);
Vector128<byte> b4 = Ssse3.Shuffle(bSwappedIv.AsByte(), byteSwapMask);
bSwappedIv = Sse2.Add(bSwappedIv, inc);
Vector128<byte> b5 = Ssse3.Shuffle(bSwappedIv.AsByte(), byteSwapMask);
bSwappedIv = Sse2.Add(bSwappedIv, inc);
Vector128<byte> b6 = Ssse3.Shuffle(bSwappedIv.AsByte(), byteSwapMask);
bSwappedIv = Sse2.Add(bSwappedIv, inc);
Vector128<byte> b7 = Ssse3.Shuffle(bSwappedIv.AsByte(), byteSwapMask);
_aesCore.EncryptBlocks8(b0, b1, b2, b3, b4, b5, b6, b7,
out b0, out b1, out b2, out b3, out b4, out b5, out b6, out b7);
Unsafe.Add(ref outBlock, 0) = Sse2.Xor(Unsafe.Add(ref inBlock, 0), b0);
Unsafe.Add(ref outBlock, 1) = Sse2.Xor(Unsafe.Add(ref inBlock, 1), b1);
Unsafe.Add(ref outBlock, 2) = Sse2.Xor(Unsafe.Add(ref inBlock, 2), b2);
Unsafe.Add(ref outBlock, 3) = Sse2.Xor(Unsafe.Add(ref inBlock, 3), b3);
Unsafe.Add(ref outBlock, 4) = Sse2.Xor(Unsafe.Add(ref inBlock, 4), b4);
Unsafe.Add(ref outBlock, 5) = Sse2.Xor(Unsafe.Add(ref inBlock, 5), b5);
Unsafe.Add(ref outBlock, 6) = Sse2.Xor(Unsafe.Add(ref inBlock, 6), b6);
Unsafe.Add(ref outBlock, 7) = Sse2.Xor(Unsafe.Add(ref inBlock, 7), b7);
// Increase the counter
bSwappedIv = Sse2.Add(bSwappedIv, inc);
iv = Ssse3.Shuffle(bSwappedIv.AsByte(), byteSwapMask);
inBlock = ref Unsafe.Add(ref inBlock, 8);
outBlock = ref Unsafe.Add(ref outBlock, 8);
remaining -= 8 * Aes.BlockSize;
}
while (remaining >= Aes.BlockSize)
{ {
Vector128<byte> encIv = _aesCore.EncryptBlock(iv); Vector128<byte> encIv = _aesCore.EncryptBlock(iv);
outBlock = Sse2.Xor(inBlock, encIv); outBlock = Sse2.Xor(inBlock, encIv);
@ -49,11 +96,12 @@ namespace LibHac.Crypto.Detail
inBlock = ref Unsafe.Add(ref inBlock, 1); inBlock = ref Unsafe.Add(ref inBlock, 1);
outBlock = ref Unsafe.Add(ref outBlock, 1); outBlock = ref Unsafe.Add(ref outBlock, 1);
remaining -= Aes.BlockSize;
} }
_iv = iv; _iv = iv;
if ((input.Length & 0xF) != 0) if (remaining != 0)
{ {
EncryptCtrPartialBlock(input.Slice(blockCount * 0x10), output.Slice(blockCount * 0x10)); EncryptCtrPartialBlock(input.Slice(blockCount * 0x10), output.Slice(blockCount * 0x10));
} }

View file

@ -16,12 +16,12 @@ namespace LibHac.Crypto.Detail
public void Encrypt(ReadOnlySpan<byte> input, Span<byte> output) public void Encrypt(ReadOnlySpan<byte> input, Span<byte> output)
{ {
_aesCore.Encrypt(input, output); _aesCore.EncryptInterleaved8(input, output);
} }
public void Decrypt(ReadOnlySpan<byte> input, Span<byte> output) public void Decrypt(ReadOnlySpan<byte> input, Span<byte> output)
{ {
_aesCore.Decrypt(input, output); _aesCore.DecryptInterleaved8(input, output);
} }
} }
} }

View file

@ -31,10 +31,10 @@ namespace LibHac.Crypto.Detail
public void Encrypt(ReadOnlySpan<byte> input, Span<byte> output) public void Encrypt(ReadOnlySpan<byte> input, Span<byte> output)
{ {
int length = Math.Min(input.Length, output.Length); int length = Math.Min(input.Length, output.Length);
int blockCount = length >> 4; int remainingBlocks = length >> 4;
int leftover = length & 0xF; int leftover = length & 0xF;
Debug.Assert(blockCount > 0); Debug.Assert(remainingBlocks > 0);
ref Vector128<byte> inBlock = ref Unsafe.As<byte, Vector128<byte>>(ref MemoryMarshal.GetReference(input)); ref Vector128<byte> inBlock = ref Unsafe.As<byte, Vector128<byte>>(ref MemoryMarshal.GetReference(input));
ref Vector128<byte> outBlock = ref Unsafe.As<byte, Vector128<byte>>(ref MemoryMarshal.GetReference(output)); ref Vector128<byte> outBlock = ref Unsafe.As<byte, Vector128<byte>>(ref MemoryMarshal.GetReference(output));
@ -43,7 +43,51 @@ namespace LibHac.Crypto.Detail
Vector128<byte> tweak = _tweakAesCore.EncryptBlock(_iv); Vector128<byte> tweak = _tweakAesCore.EncryptBlock(_iv);
for (int i = 0; i < blockCount; i++) while (remainingBlocks > 7)
{
Vector128<byte> b0 = Sse2.Xor(tweak, Unsafe.Add(ref inBlock, 0));
Vector128<byte> tweak1 = Gf128Mul(tweak, mask);
Vector128<byte> b1 = Sse2.Xor(tweak1, Unsafe.Add(ref inBlock, 1));
Vector128<byte> tweak2 = Gf128Mul(tweak1, mask);
Vector128<byte> b2 = Sse2.Xor(tweak2, Unsafe.Add(ref inBlock, 2));
Vector128<byte> tweak3 = Gf128Mul(tweak2, mask);
Vector128<byte> b3 = Sse2.Xor(tweak3, Unsafe.Add(ref inBlock, 3));
Vector128<byte> tweak4 = Gf128Mul(tweak3, mask);
Vector128<byte> b4 = Sse2.Xor(tweak4, Unsafe.Add(ref inBlock, 4));
Vector128<byte> tweak5 = Gf128Mul(tweak4, mask);
Vector128<byte> b5 = Sse2.Xor(tweak5, Unsafe.Add(ref inBlock, 5));
Vector128<byte> tweak6 = Gf128Mul(tweak5, mask);
Vector128<byte> b6 = Sse2.Xor(tweak6, Unsafe.Add(ref inBlock, 6));
Vector128<byte> tweak7 = Gf128Mul(tweak6, mask);
Vector128<byte> b7 = Sse2.Xor(tweak7, Unsafe.Add(ref inBlock, 7));
_dataAesCore.EncryptBlocks8(b0, b1, b2, b3, b4, b5, b6, b7,
out b0, out b1, out b2, out b3, out b4, out b5, out b6, out b7);
Unsafe.Add(ref outBlock, 0) = Sse2.Xor(tweak, b0);
Unsafe.Add(ref outBlock, 1) = Sse2.Xor(tweak1, b1);
Unsafe.Add(ref outBlock, 2) = Sse2.Xor(tweak2, b2);
Unsafe.Add(ref outBlock, 3) = Sse2.Xor(tweak3, b3);
Unsafe.Add(ref outBlock, 4) = Sse2.Xor(tweak4, b4);
Unsafe.Add(ref outBlock, 5) = Sse2.Xor(tweak5, b5);
Unsafe.Add(ref outBlock, 6) = Sse2.Xor(tweak6, b6);
Unsafe.Add(ref outBlock, 7) = Sse2.Xor(tweak7, b7);
tweak = Gf128Mul(tweak7, mask);
inBlock = ref Unsafe.Add(ref inBlock, 8);
outBlock = ref Unsafe.Add(ref outBlock, 8);
remainingBlocks -= 8;
}
while (remainingBlocks > 0)
{ {
Vector128<byte> tmp = Sse2.Xor(inBlock, tweak); Vector128<byte> tmp = Sse2.Xor(inBlock, tweak);
tmp = _dataAesCore.EncryptBlock(tmp); tmp = _dataAesCore.EncryptBlock(tmp);
@ -53,6 +97,7 @@ namespace LibHac.Crypto.Detail
inBlock = ref Unsafe.Add(ref inBlock, 1); inBlock = ref Unsafe.Add(ref inBlock, 1);
outBlock = ref Unsafe.Add(ref outBlock, 1); outBlock = ref Unsafe.Add(ref outBlock, 1);
remainingBlocks--;
} }
if (leftover != 0) if (leftover != 0)
@ -64,12 +109,12 @@ namespace LibHac.Crypto.Detail
public void Decrypt(ReadOnlySpan<byte> input, Span<byte> output) public void Decrypt(ReadOnlySpan<byte> input, Span<byte> output)
{ {
int length = Math.Min(input.Length, output.Length); int length = Math.Min(input.Length, output.Length);
int blockCount = length >> 4; int remainingBlocks = length >> 4;
int leftover = length & 0xF; int leftover = length & 0xF;
Debug.Assert(blockCount > 0); Debug.Assert(remainingBlocks > 0);
if (leftover != 0) blockCount--; if (leftover != 0) remainingBlocks--;
ref Vector128<byte> inBlock = ref Unsafe.As<byte, Vector128<byte>>(ref MemoryMarshal.GetReference(input)); ref Vector128<byte> inBlock = ref Unsafe.As<byte, Vector128<byte>>(ref MemoryMarshal.GetReference(input));
ref Vector128<byte> outBlock = ref Unsafe.As<byte, Vector128<byte>>(ref MemoryMarshal.GetReference(output)); ref Vector128<byte> outBlock = ref Unsafe.As<byte, Vector128<byte>>(ref MemoryMarshal.GetReference(output));
@ -78,7 +123,51 @@ namespace LibHac.Crypto.Detail
Vector128<byte> tweak = _tweakAesCore.EncryptBlock(_iv); Vector128<byte> tweak = _tweakAesCore.EncryptBlock(_iv);
for (int i = 0; i < blockCount; i++) while (remainingBlocks > 7)
{
Vector128<byte> b0 = Sse2.Xor(tweak, Unsafe.Add(ref inBlock, 0));
Vector128<byte> tweak1 = Gf128Mul(tweak, mask);
Vector128<byte> b1 = Sse2.Xor(tweak1, Unsafe.Add(ref inBlock, 1));
Vector128<byte> tweak2 = Gf128Mul(tweak1, mask);
Vector128<byte> b2 = Sse2.Xor(tweak2, Unsafe.Add(ref inBlock, 2));
Vector128<byte> tweak3 = Gf128Mul(tweak2, mask);
Vector128<byte> b3 = Sse2.Xor(tweak3, Unsafe.Add(ref inBlock, 3));
Vector128<byte> tweak4 = Gf128Mul(tweak3, mask);
Vector128<byte> b4 = Sse2.Xor(tweak4, Unsafe.Add(ref inBlock, 4));
Vector128<byte> tweak5 = Gf128Mul(tweak4, mask);
Vector128<byte> b5 = Sse2.Xor(tweak5, Unsafe.Add(ref inBlock, 5));
Vector128<byte> tweak6 = Gf128Mul(tweak5, mask);
Vector128<byte> b6 = Sse2.Xor(tweak6, Unsafe.Add(ref inBlock, 6));
Vector128<byte> tweak7 = Gf128Mul(tweak6, mask);
Vector128<byte> b7 = Sse2.Xor(tweak7, Unsafe.Add(ref inBlock, 7));
_dataAesCore.DecryptBlocks8(b0, b1, b2, b3, b4, b5, b6, b7,
out b0, out b1, out b2, out b3, out b4, out b5, out b6, out b7);
Unsafe.Add(ref outBlock, 0) = Sse2.Xor(tweak, b0);
Unsafe.Add(ref outBlock, 1) = Sse2.Xor(tweak1, b1);
Unsafe.Add(ref outBlock, 2) = Sse2.Xor(tweak2, b2);
Unsafe.Add(ref outBlock, 3) = Sse2.Xor(tweak3, b3);
Unsafe.Add(ref outBlock, 4) = Sse2.Xor(tweak4, b4);
Unsafe.Add(ref outBlock, 5) = Sse2.Xor(tweak5, b5);
Unsafe.Add(ref outBlock, 6) = Sse2.Xor(tweak6, b6);
Unsafe.Add(ref outBlock, 7) = Sse2.Xor(tweak7, b7);
tweak = Gf128Mul(tweak7, mask);
inBlock = ref Unsafe.Add(ref inBlock, 8);
outBlock = ref Unsafe.Add(ref outBlock, 8);
remainingBlocks -= 8;
}
while (remainingBlocks > 0)
{ {
Vector128<byte> tmp = Sse2.Xor(inBlock, tweak); Vector128<byte> tmp = Sse2.Xor(inBlock, tweak);
tmp = _dataAesCore.DecryptBlock(tmp); tmp = _dataAesCore.DecryptBlock(tmp);
@ -88,6 +177,7 @@ namespace LibHac.Crypto.Detail
inBlock = ref Unsafe.Add(ref inBlock, 1); inBlock = ref Unsafe.Add(ref inBlock, 1);
outBlock = ref Unsafe.Add(ref outBlock, 1); outBlock = ref Unsafe.Add(ref outBlock, 1);
remainingBlocks--;
} }
if (leftover != 0) if (leftover != 0)