diff --git a/src/LibHac/Common/RentedArray.cs b/src/LibHac/Common/RentedArray.cs new file mode 100644 index 00000000..dde343fb --- /dev/null +++ b/src/LibHac/Common/RentedArray.cs @@ -0,0 +1,41 @@ +using System; +using System.Buffers; +using System.Runtime.CompilerServices; + +namespace LibHac.Common +{ + public readonly ref struct RentedArray + { + // It's faster to create new smaller arrays than rent them + private const int RentThresholdBytes = 512; + private static int RentThresholdElements => RentThresholdBytes / Unsafe.SizeOf(); + + private readonly Span _span; + + public T[] Array { get; } + public Span Span => _span; + + public RentedArray(int minimumSize) + { + if (minimumSize >= RentThresholdElements) + { + Array = ArrayPool.Shared.Rent(minimumSize); + } + else + { + Array = new T[minimumSize]; + } + + _span = Array.AsSpan(0, minimumSize); + } + + public void Dispose() + { + // Only return if array was rented + if (_span.Length >= RentThresholdElements) + { + ArrayPool.Shared.Return(Array); + } + } + } +} diff --git a/src/LibHac/Crypto2/Aes.cs b/src/LibHac/Crypto2/Aes.cs index d985ec3c..5f2a5f44 100644 --- a/src/LibHac/Crypto2/Aes.cs +++ b/src/LibHac/Crypto2/Aes.cs @@ -86,7 +86,7 @@ namespace LibHac.Crypto2 return new AesCtrCipherNi(key, iv); } #endif - return new AesCtrEncryptor(key, iv); + return new AesCtrCipher(key, iv); } public static ICipher CreateXtsDecryptor(ReadOnlySpan key1, ReadOnlySpan key2, @@ -98,7 +98,7 @@ namespace LibHac.Crypto2 return new AesXtsDecryptorNi(key1, key2, iv); } #endif - return new AesXtsCipher(key1, key2, iv, true); + return new AesXtsDecryptor(key1, key2, iv); } public static ICipher CreateXtsEncryptor(ReadOnlySpan key1, ReadOnlySpan key2, @@ -110,7 +110,7 @@ namespace LibHac.Crypto2 return new AesXtsEncryptorNi(key1, key2, iv); } #endif - return new AesXtsCipher(key1, key2, iv, false); + return new AesXtsEncryptor(key1, key2, iv); } public static void EncryptEcb128(ReadOnlySpan input, Span output, ReadOnlySpan key, diff --git a/src/LibHac/Crypto2/AesCbcCipher.cs b/src/LibHac/Crypto2/AesCbcCipher.cs index 26e0ee33..03a1ef16 100644 --- a/src/LibHac/Crypto2/AesCbcCipher.cs +++ b/src/LibHac/Crypto2/AesCbcCipher.cs @@ -1,59 +1,37 @@ using System; -using System.Security.Cryptography; +using LibHac.Crypto2.Detail; namespace LibHac.Crypto2 { public class AesCbcEncryptor : ICipher { - private ICryptoTransform _encryptor; + private AesCbcMode _baseCipher; public AesCbcEncryptor(ReadOnlySpan key, ReadOnlySpan iv) { - Aes aes = Aes.Create(); - - if (aes == null) throw new CryptographicException("Unable to create AES object"); - aes.Key = key.ToArray(); - aes.IV = iv.ToArray(); - aes.Mode = CipherMode.CBC; - aes.Padding = PaddingMode.None; - - _encryptor = aes.CreateEncryptor(); + _baseCipher = new AesCbcMode(); + _baseCipher.Initialize(key, iv, false); } public void Transform(ReadOnlySpan input, Span output) { - var outputBuffer = new byte[input.Length]; - - _encryptor.TransformBlock(input.ToArray(), 0, input.Length, outputBuffer, 0); - - outputBuffer.CopyTo(output); + _baseCipher.Encrypt(input, output); } } public class AesCbcDecryptor : ICipher { - private ICryptoTransform _decryptor; + private AesCbcMode _baseCipher; public AesCbcDecryptor(ReadOnlySpan key, ReadOnlySpan iv) { - Aes aes = Aes.Create(); - - if (aes == null) throw new CryptographicException("Unable to create AES object"); - aes.Key = key.ToArray(); - aes.IV = iv.ToArray(); - aes.Mode = CipherMode.CBC; - aes.Padding = PaddingMode.None; - - _decryptor = aes.CreateDecryptor(); + _baseCipher = new AesCbcMode(); + _baseCipher.Initialize(key, iv, true); } public void Transform(ReadOnlySpan input, Span output) { - var outputBuffer = new byte[input.Length]; - - _decryptor.TransformBlock(input.ToArray(), 0, input.Length, outputBuffer, 0); - - outputBuffer.CopyTo(output); + _baseCipher.Decrypt(input, output); } } } diff --git a/src/LibHac/Crypto2/AesCtrCipher.cs b/src/LibHac/Crypto2/AesCtrCipher.cs index 0f1ed6e5..f0e792c9 100644 --- a/src/LibHac/Crypto2/AesCtrCipher.cs +++ b/src/LibHac/Crypto2/AesCtrCipher.cs @@ -1,69 +1,21 @@ using System; -using System.Buffers; -using System.Buffers.Binary; -using System.Runtime.InteropServices; -using System.Security.Cryptography; +using LibHac.Crypto2.Detail; namespace LibHac.Crypto2 { - public class AesCtrEncryptor : ICipher + public class AesCtrCipher : ICipher { - private const int BlockSize = 128; - private const int BlockSizeBytes = BlockSize / 8; + private AesCtrMode _baseCipher; - private readonly ICryptoTransform _encryptor; - private readonly byte[] _counter = new byte[0x10]; - - public AesCtrEncryptor(ReadOnlySpan key, ReadOnlySpan iv) + public AesCtrCipher(ReadOnlySpan key, ReadOnlySpan iv) { - Aes aes = Aes.Create(); - if (aes == null) throw new CryptographicException("Unable to create AES object"); - - aes.Mode = CipherMode.ECB; - aes.Padding = PaddingMode.None; - - _encryptor = aes.CreateEncryptor(key.ToArray(), new byte[0x10]); - - iv.CopyTo(_counter); + _baseCipher = new AesCtrMode(); + _baseCipher.Initialize(key, iv); } public void Transform(ReadOnlySpan input, Span output) { - int blockCount = Util.DivideByRoundUp(input.Length, BlockSizeBytes); - int length = blockCount * BlockSizeBytes; - - byte[] counterXor = ArrayPool.Shared.Rent(length); - try - { - FillDecryptedCounter(_counter, counterXor.AsSpan(0, length)); - - _encryptor.TransformBlock(counterXor, 0, length, counterXor, 0); - - input.CopyTo(output); - Util.XorArrays(output, counterXor); - } - finally - { - ArrayPool.Shared.Return(counterXor); - } - } - - private static void FillDecryptedCounter(Span counter, Span buffer) - { - Span bufL = MemoryMarshal.Cast(buffer); - Span counterL = MemoryMarshal.Cast(counter); - - ulong hi = counterL[0]; - ulong lo = BinaryPrimitives.ReverseEndianness(counterL[1]); - - for (int i = 0; i < bufL.Length; i += 2) - { - bufL[i] = hi; - bufL[i + 1] = BinaryPrimitives.ReverseEndianness(lo); - lo++; - } - - counterL[1] = BinaryPrimitives.ReverseEndianness(lo); + _baseCipher.Transform(input, output); } } } diff --git a/src/LibHac/Crypto2/AesEcbCipher.cs b/src/LibHac/Crypto2/AesEcbCipher.cs index 48dda455..179a0cd4 100644 --- a/src/LibHac/Crypto2/AesEcbCipher.cs +++ b/src/LibHac/Crypto2/AesEcbCipher.cs @@ -1,94 +1,37 @@ using System; -using System.Buffers; -using System.Security.Cryptography; +using LibHac.Crypto2.Detail; namespace LibHac.Crypto2 { public class AesEcbEncryptor : ICipher { - private const int BufferRentThreshold = 1024; - private ICryptoTransform _encryptor; + private AesEcbMode _baseCipher; public AesEcbEncryptor(ReadOnlySpan key) { - Aes aes = Aes.Create(); - - if (aes == null) throw new CryptographicException("Unable to create AES object"); - aes.Key = key.ToArray(); - aes.Mode = CipherMode.ECB; - aes.Padding = PaddingMode.None; - - _encryptor = aes.CreateEncryptor(); + _baseCipher = new AesEcbMode(); + _baseCipher.Initialize(key, false); } public void Transform(ReadOnlySpan input, Span output) { - if (input.Length < BufferRentThreshold) - { - var outputBuffer = new byte[input.Length]; - input.CopyTo(outputBuffer); - - _encryptor.TransformBlock(outputBuffer, 0, input.Length, outputBuffer, 0); - - outputBuffer.CopyTo(output); - } - else - { - byte[] outputBuffer = ArrayPool.Shared.Rent(input.Length); - try - { - input.CopyTo(outputBuffer); - - _encryptor.TransformBlock(outputBuffer, 0, input.Length, outputBuffer, 0); - - outputBuffer.CopyTo(output); - } - finally { ArrayPool.Shared.Return(outputBuffer); } - } + _baseCipher.Encrypt(input, output); } } public class AesEcbDecryptor : ICipher { - private const int BufferRentThreshold = 1024; - private ICryptoTransform _decryptor; + private AesEcbMode _baseCipher; public AesEcbDecryptor(ReadOnlySpan key) { - Aes aes = Aes.Create(); - - if (aes == null) throw new CryptographicException("Unable to create AES object"); - aes.Key = key.ToArray(); - aes.Mode = CipherMode.ECB; - aes.Padding = PaddingMode.None; - - _decryptor = aes.CreateDecryptor(); + _baseCipher = new AesEcbMode(); + _baseCipher.Initialize(key, true); } public void Transform(ReadOnlySpan input, Span output) { - if (input.Length < BufferRentThreshold) - { - var outputBuffer = new byte[input.Length]; - input.CopyTo(outputBuffer); - - _decryptor.TransformBlock(outputBuffer, 0, input.Length, outputBuffer, 0); - - outputBuffer.CopyTo(output); - } - else - { - byte[] outputBuffer = ArrayPool.Shared.Rent(input.Length); - try - { - input.CopyTo(outputBuffer); - - _decryptor.TransformBlock(outputBuffer, 0, input.Length, outputBuffer, 0); - - outputBuffer.CopyTo(output); - } - finally { ArrayPool.Shared.Return(outputBuffer); } - } + _baseCipher.Decrypt(input, output); } } } diff --git a/src/LibHac/Crypto2/AesXtsCipher.cs b/src/LibHac/Crypto2/AesXtsCipher.cs index d2e84f3f..f7cd4cd4 100644 --- a/src/LibHac/Crypto2/AesXtsCipher.cs +++ b/src/LibHac/Crypto2/AesXtsCipher.cs @@ -1,208 +1,37 @@ using System; -using System.Buffers; -using System.Diagnostics; -using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; -using LibHac.Common; +using LibHac.Crypto2.Detail; namespace LibHac.Crypto2 { - public class AesXtsCipher : ICipher + public class AesXtsEncryptor : ICipher { - private ICipher _dataCipher; - private ICipher _tweakCipher; - private Buffer16 _iv; - private bool _decrypting; + private AesXtsMode _baseCipher; - public AesXtsCipher(ReadOnlySpan key1, ReadOnlySpan key2, ReadOnlySpan iv, bool decrypting) + public AesXtsEncryptor(ReadOnlySpan key1, ReadOnlySpan key2, ReadOnlySpan iv) { - Debug.Assert(key1.Length == AesCrypto.KeySize128); - Debug.Assert(key2.Length == AesCrypto.KeySize128); - Debug.Assert(iv.Length == AesCrypto.KeySize128); - - if (decrypting) - { - _dataCipher = new AesEcbDecryptor(key1); - } - else - { - _dataCipher = new AesEcbEncryptor(key1); - } - - _tweakCipher = new AesEcbEncryptor(key2); - - _iv = new Buffer16(); - iv.CopyTo(_iv); - - _decrypting = decrypting; - } - - public void Encrypt(ReadOnlySpan input, Span output) - { - int length = Math.Min(input.Length, output.Length); - int blockCount = length >> 4; - int leftover = length & 0xF; - - // Data units must be at least 1 block long. - if (length < AesCrypto.BlockSize) - throw new ArgumentException(); - - var tweak = new Buffer16(); - - _tweakCipher.Transform(_iv, tweak); - - byte[] tweakBufferRented = ArrayPool.Shared.Rent(blockCount * AesCrypto.BlockSize); - try - { - Span tweakBuffer = tweakBufferRented.AsSpan(0, blockCount * AesCrypto.BlockSize); - tweak = FillTweakBuffer(tweak, MemoryMarshal.Cast(tweakBuffer)); - - Util.XorArrays(output, input, tweakBuffer); - _dataCipher.Transform(output.Slice(0, blockCount * AesCrypto.BlockSize), output); - Util.XorArrays(output, output, tweakBuffer); - } - finally { ArrayPool.Shared.Return(tweakBufferRented); } - - if (leftover != 0) - { - ref Buffer16 inBlock = - ref Unsafe.Add(ref Unsafe.As(ref MemoryMarshal.GetReference(input)), blockCount); - - ref Buffer16 outBlock = - ref Unsafe.Add(ref Unsafe.As(ref MemoryMarshal.GetReference(output)), blockCount); - - ref Buffer16 prevOutBlock = ref Unsafe.Subtract(ref outBlock, 1); - - var tmp = new Buffer16(); - - for (int i = 0; i < leftover; i++) - { - outBlock[i] = prevOutBlock[i]; - tmp[i] = inBlock[i]; - } - - for (int i = leftover; i < AesCrypto.BlockSize; i++) - { - tmp[i] = prevOutBlock[i]; - } - - XorBuffer(ref tmp, ref tmp, ref tweak); - _dataCipher.Transform(tmp, tmp); - XorBuffer(ref prevOutBlock, ref tmp, ref tweak); - } - } - - public void Decrypt(ReadOnlySpan input, Span output) - { - int length = Math.Min(input.Length, output.Length); - int blockCount = length >> 4; - int leftover = length & 0xF; - - // Data units must be at least 1 block long. - if (length < AesCrypto.BlockSize) - throw new ArgumentException(); - - if (leftover != 0) blockCount--; - - var tweak = new Buffer16(); - - _tweakCipher.Transform(_iv, tweak); - - if (blockCount > 0) - { - byte[] tweakBufferRented = ArrayPool.Shared.Rent(blockCount * AesCrypto.BlockSize); - try - { - Span tweakBuffer = tweakBufferRented.AsSpan(0, blockCount * AesCrypto.BlockSize); - tweak = FillTweakBuffer(tweak, MemoryMarshal.Cast(tweakBuffer)); - - Util.XorArrays(output, input, tweakBuffer); - _dataCipher.Transform(output.Slice(0, blockCount * AesCrypto.BlockSize), output); - Util.XorArrays(output, output, tweakBuffer); - } - finally { ArrayPool.Shared.Return(tweakBufferRented); } - } - - if (leftover != 0) - { - Buffer16 finalTweak = tweak; - Gf128Mul(ref finalTweak); - - ref Buffer16 inBlock = - ref Unsafe.Add(ref Unsafe.As(ref MemoryMarshal.GetReference(input)), blockCount); - - ref Buffer16 outBlock = - ref Unsafe.Add(ref Unsafe.As(ref MemoryMarshal.GetReference(output)), blockCount); - - var tmp = new Buffer16(); - - XorBuffer(ref tmp, ref inBlock, ref finalTweak); - _dataCipher.Transform(tmp, tmp); - XorBuffer(ref outBlock, ref tmp, ref finalTweak); - - ref Buffer16 finalOutBlock = ref Unsafe.Add(ref outBlock, 1); - ref Buffer16 finalInBlock = ref Unsafe.Add(ref inBlock, 1); - - for (int i = 0; i < leftover; i++) - { - finalOutBlock[i] = outBlock[i]; - tmp[i] = finalInBlock[i]; - } - - for (int i = leftover; i < AesCrypto.BlockSize; i++) - { - tmp[i] = outBlock[i]; - } - - XorBuffer(ref tmp, ref tmp, ref tweak); - _dataCipher.Transform(tmp, tmp); - XorBuffer(ref outBlock, ref tmp, ref tweak); - } + _baseCipher = new AesXtsMode(); + _baseCipher.Initialize(key1, key2, iv, false); } public void Transform(ReadOnlySpan input, Span output) { - if (_decrypting) - { - Decrypt(input, output); - } - else - { - Encrypt(input, output); - } + _baseCipher.Encrypt(input, output); + } + } + + public class AesXtsDecryptor : ICipher + { + private AesXtsMode _baseCipher; + + public AesXtsDecryptor(ReadOnlySpan key1, ReadOnlySpan key2, ReadOnlySpan iv) + { + _baseCipher = new AesXtsMode(); + _baseCipher.Initialize(key1, key2, iv, true); } - private static Buffer16 FillTweakBuffer(Buffer16 initialTweak, Span tweakBuffer) + public void Transform(ReadOnlySpan input, Span output) { - for (int i = 0; i < tweakBuffer.Length; i++) - { - tweakBuffer[i] = initialTweak; - Gf128Mul(ref initialTweak); - } - - return initialTweak; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void Gf128Mul(ref Buffer16 buffer) - { - Span b = buffer.AsSpan(); - - ulong tt = (ulong)((long)b[1] >> 63) & 0x87; - - b[1] = (b[1] << 1) | (b[0] >> 63); - b[0] = (b[0] << 1) ^ tt; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void XorBuffer(ref Buffer16 output, ref Buffer16 input1, ref Buffer16 input2) - { - Span outputS = output.AsSpan(); - Span input1S = input1.AsSpan(); - Span input2S = input2.AsSpan(); - - outputS[0] = input1S[0] ^ input2S[0]; - outputS[1] = input1S[1] ^ input2S[1]; + _baseCipher.Decrypt(input, output); } } } diff --git a/src/LibHac/Crypto2/Detail/AesCbcMode.cs b/src/LibHac/Crypto2/Detail/AesCbcMode.cs new file mode 100644 index 00000000..302c4ff1 --- /dev/null +++ b/src/LibHac/Crypto2/Detail/AesCbcMode.cs @@ -0,0 +1,26 @@ +using System; +using System.Security.Cryptography; + +namespace LibHac.Crypto2.Detail +{ + public struct AesCbcMode + { + private AesCore _aesCore; + + public void Initialize(ReadOnlySpan key, ReadOnlySpan iv, bool isDecrypting) + { + _aesCore = new AesCore(); + _aesCore.Initialize(key, iv, CipherMode.CBC, isDecrypting); + } + + public void Encrypt(ReadOnlySpan input, Span output) + { + _aesCore.Encrypt(input, output); + } + + public void Decrypt(ReadOnlySpan input, Span output) + { + _aesCore.Decrypt(input, output); + } + } +} diff --git a/src/LibHac/Crypto2/Detail/AesCore.cs b/src/LibHac/Crypto2/Detail/AesCore.cs new file mode 100644 index 00000000..9f1bc2f0 --- /dev/null +++ b/src/LibHac/Crypto2/Detail/AesCore.cs @@ -0,0 +1,74 @@ +using System; +using System.Diagnostics; +using System.Security.Cryptography; +using LibHac.Common; + +namespace LibHac.Crypto2.Detail +{ + public struct AesCore + { + private ICryptoTransform _transform; + private bool _isDecrypting; + + public void Initialize(ReadOnlySpan key, ReadOnlySpan iv, CipherMode mode, bool isDecrypting) + { + Debug.Assert(key.Length == AesCrypto.KeySize128); + Debug.Assert(iv.IsEmpty || iv.Length == AesCrypto.BlockSize); + + Aes aes = Aes.Create(); + + if (aes == null) throw new CryptographicException("Unable to create AES object"); + aes.Key = key.ToArray(); + aes.Mode = mode; + aes.Padding = PaddingMode.None; + + if (!iv.IsEmpty) + { + aes.IV = iv.ToArray(); + } + + _transform = isDecrypting ? aes.CreateDecryptor() : aes.CreateEncryptor(); + _isDecrypting = isDecrypting; + } + + public void Encrypt(ReadOnlySpan input, Span output) + { + Debug.Assert(!_isDecrypting); + Transform(input, output); + } + + public void Decrypt(ReadOnlySpan input, Span output) + { + Debug.Assert(_isDecrypting); + Transform(input, output); + } + + public void Encrypt(byte[] input, byte[] output, int length) + { + Debug.Assert(!_isDecrypting); + Transform(input, output, length); + } + + public void Decrypt(byte[] input, byte[] output, int length) + { + Debug.Assert(_isDecrypting); + Transform(input, output, length); + } + + private void Transform(ReadOnlySpan input, Span output) + { + using var rented = new RentedArray(input.Length); + + input.CopyTo(rented.Array); + + Transform(rented.Array, rented.Array, input.Length); + + rented.Array.CopyTo(output); + } + + private void Transform(byte[] input, byte[] output, int length) + { + _transform.TransformBlock(input, 0, length, output, 0); + } + } +} diff --git a/src/LibHac/Crypto2/Detail/AesCtrMode.cs b/src/LibHac/Crypto2/Detail/AesCtrMode.cs new file mode 100644 index 00000000..fc4943ba --- /dev/null +++ b/src/LibHac/Crypto2/Detail/AesCtrMode.cs @@ -0,0 +1,55 @@ +using System; +using System.Buffers.Binary; +using System.Diagnostics; +using System.Runtime.InteropServices; +using System.Security.Cryptography; +using LibHac.Common; + +namespace LibHac.Crypto2.Detail +{ + public struct AesCtrMode + { + private AesCore _aesCore; + private byte[] _counter; + + public void Initialize(ReadOnlySpan key, ReadOnlySpan iv) + { + Debug.Assert(iv.Length == AesCrypto.BlockSize); + + _aesCore = new AesCore(); + _aesCore.Initialize(key, ReadOnlySpan.Empty, CipherMode.ECB, false); + + _counter = iv.ToArray(); + } + + public void Transform(ReadOnlySpan input, Span output) + { + int blockCount = Util.DivideByRoundUp(input.Length, AesCrypto.BlockSize); + int length = blockCount * AesCrypto.BlockSize; + + using var counterBuffer = new RentedArray(length); + FillDecryptedCounter(_counter, counterBuffer.Span); + + _aesCore.Encrypt(counterBuffer.Array, counterBuffer.Array, length); + Util.XorArrays(output, input, counterBuffer.Span); + } + + private static void FillDecryptedCounter(Span counter, Span buffer) + { + Span bufL = MemoryMarshal.Cast(buffer); + Span counterL = MemoryMarshal.Cast(counter); + + ulong hi = counterL[0]; + ulong lo = BinaryPrimitives.ReverseEndianness(counterL[1]); + + for (int i = 0; i < bufL.Length; i += 2) + { + bufL[i] = hi; + bufL[i + 1] = BinaryPrimitives.ReverseEndianness(lo); + lo++; + } + + counterL[1] = BinaryPrimitives.ReverseEndianness(lo); + } + } +} diff --git a/src/LibHac/Crypto2/Detail/AesEcbMode.cs b/src/LibHac/Crypto2/Detail/AesEcbMode.cs new file mode 100644 index 00000000..8365be0c --- /dev/null +++ b/src/LibHac/Crypto2/Detail/AesEcbMode.cs @@ -0,0 +1,26 @@ +using System; +using System.Security.Cryptography; + +namespace LibHac.Crypto2.Detail +{ + public struct AesEcbMode + { + private AesCore _aesCore; + + public void Initialize(ReadOnlySpan key, bool isDecrypting) + { + _aesCore = new AesCore(); + _aesCore.Initialize(key, ReadOnlySpan.Empty, CipherMode.ECB, isDecrypting); + } + + public void Encrypt(ReadOnlySpan input, Span output) + { + _aesCore.Encrypt(input, output); + } + + public void Decrypt(ReadOnlySpan input, Span output) + { + _aesCore.Decrypt(input, output); + } + } +} diff --git a/src/LibHac/Crypto2/Detail/AesXtsMode.cs b/src/LibHac/Crypto2/Detail/AesXtsMode.cs new file mode 100644 index 00000000..ca8805bb --- /dev/null +++ b/src/LibHac/Crypto2/Detail/AesXtsMode.cs @@ -0,0 +1,175 @@ +using System; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Security.Cryptography; +using LibHac.Common; + +namespace LibHac.Crypto2.Detail +{ + public struct AesXtsMode + { + private AesCore _dataAesCore; + private AesCore _tweakAesCore; + private byte[] _iv; + + public void Initialize(ReadOnlySpan key1, ReadOnlySpan key2, ReadOnlySpan iv, bool isDecrypting) + { + Debug.Assert(iv.Length == AesCrypto.BlockSize); + + _dataAesCore = new AesCore(); + _tweakAesCore = new AesCore(); + + _dataAesCore.Initialize(key1, ReadOnlySpan.Empty, CipherMode.ECB, isDecrypting); + _tweakAesCore.Initialize(key2, ReadOnlySpan.Empty, CipherMode.ECB, false); + + _iv = iv.ToArray(); + } + + public void Encrypt(ReadOnlySpan input, Span output) + { + int length = Math.Min(input.Length, output.Length); + int blockCount = length >> 4; + int leftover = length & 0xF; + + // Data units must be at least 1 block long. + if (length < AesCrypto.BlockSize) + throw new ArgumentException(); + + var tweak = new Buffer16(); + + _tweakAesCore.Encrypt(_iv, tweak); + + using var tweakBuffer = new RentedArray(blockCount * AesCrypto.BlockSize); + tweak = FillTweakBuffer(tweak, MemoryMarshal.Cast(tweakBuffer.Span)); + + Util.XorArrays(output, input, tweakBuffer.Span); + _dataAesCore.Encrypt(output.Slice(0, blockCount * AesCrypto.BlockSize), output); + Util.XorArrays(output, output, tweakBuffer.Array); + + if (leftover != 0) + { + ref Buffer16 inBlock = + ref Unsafe.Add(ref Unsafe.As(ref MemoryMarshal.GetReference(input)), blockCount); + + ref Buffer16 outBlock = + ref Unsafe.Add(ref Unsafe.As(ref MemoryMarshal.GetReference(output)), blockCount); + + ref Buffer16 prevOutBlock = ref Unsafe.Subtract(ref outBlock, 1); + + var tmp = new Buffer16(); + + for (int i = 0; i < leftover; i++) + { + outBlock[i] = prevOutBlock[i]; + tmp[i] = inBlock[i]; + } + + for (int i = leftover; i < AesCrypto.BlockSize; i++) + { + tmp[i] = prevOutBlock[i]; + } + + XorBuffer(ref tmp, ref tmp, ref tweak); + _dataAesCore.Encrypt(tmp, tmp); + XorBuffer(ref prevOutBlock, ref tmp, ref tweak); + } + } + + public void Decrypt(ReadOnlySpan input, Span output) + { + int length = Math.Min(input.Length, output.Length); + int blockCount = length >> 4; + int leftover = length & 0xF; + + // Data units must be at least 1 block long. + if (length < AesCrypto.BlockSize) + throw new ArgumentException(); + + if (leftover != 0) blockCount--; + + var tweak = new Buffer16(); + + _tweakAesCore.Encrypt(_iv, tweak); + + if (blockCount > 0) + { + using var tweakBuffer = new RentedArray(blockCount * AesCrypto.BlockSize); + tweak = FillTweakBuffer(tweak, MemoryMarshal.Cast(tweakBuffer.Span)); + + Util.XorArrays(output, input, tweakBuffer.Span); + _dataAesCore.Decrypt(output.Slice(0, blockCount * AesCrypto.BlockSize), output); + Util.XorArrays(output, output, tweakBuffer.Span); + } + + if (leftover != 0) + { + Buffer16 finalTweak = tweak; + Gf128Mul(ref finalTweak); + + ref Buffer16 inBlock = + ref Unsafe.Add(ref Unsafe.As(ref MemoryMarshal.GetReference(input)), blockCount); + + ref Buffer16 outBlock = + ref Unsafe.Add(ref Unsafe.As(ref MemoryMarshal.GetReference(output)), blockCount); + + var tmp = new Buffer16(); + + XorBuffer(ref tmp, ref inBlock, ref finalTweak); + _dataAesCore.Decrypt(tmp, tmp); + XorBuffer(ref outBlock, ref tmp, ref finalTweak); + + ref Buffer16 finalOutBlock = ref Unsafe.Add(ref outBlock, 1); + ref Buffer16 finalInBlock = ref Unsafe.Add(ref inBlock, 1); + + for (int i = 0; i < leftover; i++) + { + finalOutBlock[i] = outBlock[i]; + tmp[i] = finalInBlock[i]; + } + + for (int i = leftover; i < AesCrypto.BlockSize; i++) + { + tmp[i] = outBlock[i]; + } + + XorBuffer(ref tmp, ref tmp, ref tweak); + _dataAesCore.Decrypt(tmp, tmp); + XorBuffer(ref outBlock, ref tmp, ref tweak); + } + } + + private static Buffer16 FillTweakBuffer(Buffer16 initialTweak, Span tweakBuffer) + { + for (int i = 0; i < tweakBuffer.Length; i++) + { + tweakBuffer[i] = initialTweak; + Gf128Mul(ref initialTweak); + } + + return initialTweak; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void Gf128Mul(ref Buffer16 buffer) + { + Span b = buffer.AsSpan(); + + ulong tt = (ulong)((long)b[1] >> 63) & 0x87; + + b[1] = (b[1] << 1) | (b[0] >> 63); + b[0] = (b[0] << 1) ^ tt; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void XorBuffer(ref Buffer16 output, ref Buffer16 input1, ref Buffer16 input2) + { + Span outputS = output.AsSpan(); + Span input1S = input1.AsSpan(); + Span input2S = input2.AsSpan(); + + outputS[0] = input1S[0] ^ input2S[0]; + outputS[1] = input1S[1] ^ input2S[1]; + } + } +}