Make AES crypto return the number of bytes written

This commit is contained in:
Alex Barney 2022-03-03 16:16:08 -07:00
parent b9e2e0863b
commit c9352fcb5a
21 changed files with 131 additions and 114 deletions

View file

@ -300,7 +300,7 @@ public class Package1
return Result.Success;
}
private delegate void Decryptor(ReadOnlySpan<byte> input, Span<byte> output, ReadOnlySpan<byte> key,
private delegate int Decryptor(ReadOnlySpan<byte> input, Span<byte> output, ReadOnlySpan<byte> key,
ReadOnlySpan<byte> iv, bool preferDotNetCrypto = false);
private bool TryFindEristaKeyRevision()

View file

@ -102,7 +102,7 @@ public static class Aes
return new AesXtsEncryptor(key1, key2, iv);
}
public static void EncryptEcb128(ReadOnlySpan<byte> input, Span<byte> output, ReadOnlySpan<byte> key,
public static int EncryptEcb128(ReadOnlySpan<byte> input, Span<byte> output, ReadOnlySpan<byte> key,
bool preferDotNetCrypto = false)
{
if (IsAesNiSupported() && !preferDotNetCrypto)
@ -110,16 +110,15 @@ public static class Aes
Unsafe.SkipInit(out AesEcbModeNi cipherNi);
cipherNi.Initialize(key, false);
cipherNi.Encrypt(input, output);
return;
return cipherNi.Encrypt(input, output);
}
ICipher cipher = CreateEcbEncryptor(key, preferDotNetCrypto);
cipher.Transform(input, output);
return cipher.Transform(input, output);
}
public static void DecryptEcb128(ReadOnlySpan<byte> input, Span<byte> output, ReadOnlySpan<byte> key,
public static int DecryptEcb128(ReadOnlySpan<byte> input, Span<byte> output, ReadOnlySpan<byte> key,
bool preferDotNetCrypto = false)
{
if (IsAesNiSupported() && !preferDotNetCrypto)
@ -127,16 +126,15 @@ public static class Aes
Unsafe.SkipInit(out AesEcbModeNi cipherNi);
cipherNi.Initialize(key, true);
cipherNi.Decrypt(input, output);
return;
return cipherNi.Decrypt(input, output);
}
ICipher cipher = CreateEcbDecryptor(key, preferDotNetCrypto);
cipher.Transform(input, output);
return cipher.Transform(input, output);
}
public static void EncryptCbc128(ReadOnlySpan<byte> input, Span<byte> output, ReadOnlySpan<byte> key,
public static int EncryptCbc128(ReadOnlySpan<byte> input, Span<byte> output, ReadOnlySpan<byte> key,
ReadOnlySpan<byte> iv, bool preferDotNetCrypto = false)
{
if (IsAesNiSupported() && !preferDotNetCrypto)
@ -144,16 +142,15 @@ public static class Aes
Unsafe.SkipInit(out AesCbcModeNi cipherNi);
cipherNi.Initialize(key, iv, false);
cipherNi.Encrypt(input, output);
return;
return cipherNi.Encrypt(input, output);
}
ICipher cipher = CreateCbcEncryptor(key, iv, preferDotNetCrypto);
cipher.Transform(input, output);
return cipher.Transform(input, output);
}
public static void DecryptCbc128(ReadOnlySpan<byte> input, Span<byte> output, ReadOnlySpan<byte> key,
public static int DecryptCbc128(ReadOnlySpan<byte> input, Span<byte> output, ReadOnlySpan<byte> key,
ReadOnlySpan<byte> iv, bool preferDotNetCrypto = false)
{
if (IsAesNiSupported() && !preferDotNetCrypto)
@ -161,16 +158,15 @@ public static class Aes
Unsafe.SkipInit(out AesCbcModeNi cipherNi);
cipherNi.Initialize(key, iv, true);
cipherNi.Decrypt(input, output);
return;
return cipherNi.Decrypt(input, output);
}
ICipher cipher = CreateCbcDecryptor(key, iv, preferDotNetCrypto);
cipher.Transform(input, output);
return cipher.Transform(input, output);
}
public static void EncryptCtr128(ReadOnlySpan<byte> input, Span<byte> output, ReadOnlySpan<byte> key,
public static int EncryptCtr128(ReadOnlySpan<byte> input, Span<byte> output, ReadOnlySpan<byte> key,
ReadOnlySpan<byte> iv, bool preferDotNetCrypto = false)
{
if (IsAesNiSupported() && !preferDotNetCrypto)
@ -178,16 +174,15 @@ public static class Aes
Unsafe.SkipInit(out AesCtrModeNi cipherNi);
cipherNi.Initialize(key, iv);
cipherNi.Transform(input, output);
return;
return cipherNi.Transform(input, output);
}
ICipher cipher = CreateCtrEncryptor(key, iv, preferDotNetCrypto);
cipher.Transform(input, output);
return cipher.Transform(input, output);
}
public static void DecryptCtr128(ReadOnlySpan<byte> input, Span<byte> output, ReadOnlySpan<byte> key,
public static int DecryptCtr128(ReadOnlySpan<byte> input, Span<byte> output, ReadOnlySpan<byte> key,
ReadOnlySpan<byte> iv, bool preferDotNetCrypto = false)
{
if (IsAesNiSupported() && !preferDotNetCrypto)
@ -195,16 +190,15 @@ public static class Aes
Unsafe.SkipInit(out AesCtrModeNi cipherNi);
cipherNi.Initialize(key, iv);
cipherNi.Transform(input, output);
return;
return cipherNi.Transform(input, output);
}
ICipher cipher = CreateCtrDecryptor(key, iv, preferDotNetCrypto);
cipher.Transform(input, output);
return cipher.Transform(input, output);
}
public static void EncryptXts128(ReadOnlySpan<byte> input, Span<byte> output, ReadOnlySpan<byte> key1,
public static int EncryptXts128(ReadOnlySpan<byte> input, Span<byte> output, ReadOnlySpan<byte> key1,
ReadOnlySpan<byte> key2, ReadOnlySpan<byte> iv, bool preferDotNetCrypto = false)
{
if (IsAesNiSupported() && !preferDotNetCrypto)
@ -212,16 +206,15 @@ public static class Aes
Unsafe.SkipInit(out AesXtsModeNi cipherNi);
cipherNi.Initialize(key1, key2, iv, false);
cipherNi.Encrypt(input, output);
return;
return cipherNi.Encrypt(input, output);
}
ICipher cipher = CreateXtsEncryptor(key1, key2, iv, preferDotNetCrypto);
cipher.Transform(input, output);
return cipher.Transform(input, output);
}
public static void DecryptXts128(ReadOnlySpan<byte> input, Span<byte> output, ReadOnlySpan<byte> key1,
public static int DecryptXts128(ReadOnlySpan<byte> input, Span<byte> output, ReadOnlySpan<byte> key1,
ReadOnlySpan<byte> key2, ReadOnlySpan<byte> iv, bool preferDotNetCrypto = false)
{
if (IsAesNiSupported() && !preferDotNetCrypto)
@ -229,13 +222,12 @@ public static class Aes
Unsafe.SkipInit(out AesXtsModeNi cipherNi);
cipherNi.Initialize(key1, key2, iv, true);
cipherNi.Decrypt(input, output);
return;
return cipherNi.Decrypt(input, output);
}
ICipher cipher = CreateXtsDecryptor(key1, key2, iv, preferDotNetCrypto);
cipher.Transform(input, output);
return cipher.Transform(input, output);
}
/// <summary>
@ -307,4 +299,4 @@ public static class Aes
carry = (byte)((b & 0xff00) >> 8);
}
}
}
}

View file

@ -13,9 +13,9 @@ public class AesCbcEncryptor : ICipher
_baseCipher.Initialize(key, iv, false);
}
public void Transform(ReadOnlySpan<byte> input, Span<byte> output)
public int Transform(ReadOnlySpan<byte> input, Span<byte> output)
{
_baseCipher.Encrypt(input, output);
return _baseCipher.Encrypt(input, output);
}
}
@ -29,8 +29,8 @@ public class AesCbcDecryptor : ICipher
_baseCipher.Initialize(key, iv, true);
}
public void Transform(ReadOnlySpan<byte> input, Span<byte> output)
public int Transform(ReadOnlySpan<byte> input, Span<byte> output)
{
_baseCipher.Decrypt(input, output);
return _baseCipher.Decrypt(input, output);
}
}
}

View file

@ -18,9 +18,9 @@ public class AesCbcEncryptorNi : ICipherWithIv
_baseCipher.Initialize(key, iv, false);
}
public void Transform(ReadOnlySpan<byte> input, Span<byte> output)
public int Transform(ReadOnlySpan<byte> input, Span<byte> output)
{
_baseCipher.Encrypt(input, output);
return _baseCipher.Encrypt(input, output);
}
}
@ -36,8 +36,8 @@ public class AesCbcDecryptorNi : ICipherWithIv
_baseCipher.Initialize(key, iv, true);
}
public void Transform(ReadOnlySpan<byte> input, Span<byte> output)
public int Transform(ReadOnlySpan<byte> input, Span<byte> output)
{
_baseCipher.Decrypt(input, output);
return _baseCipher.Decrypt(input, output);
}
}
}

View file

@ -16,8 +16,8 @@ public class AesCtrCipher : ICipherWithIv
_baseCipher.Initialize(key, iv);
}
public void Transform(ReadOnlySpan<byte> input, Span<byte> output)
public int Transform(ReadOnlySpan<byte> input, Span<byte> output)
{
_baseCipher.Transform(input, output);
return _baseCipher.Transform(input, output);
}
}
}

View file

@ -18,8 +18,8 @@ public class AesCtrCipherNi : ICipherWithIv
_baseCipher.Initialize(key, iv);
}
public void Transform(ReadOnlySpan<byte> input, Span<byte> output)
public int Transform(ReadOnlySpan<byte> input, Span<byte> output)
{
_baseCipher.Transform(input, output);
return _baseCipher.Transform(input, output);
}
}
}

View file

@ -13,9 +13,9 @@ public class AesEcbEncryptor : ICipher
_baseCipher.Initialize(key, false);
}
public void Transform(ReadOnlySpan<byte> input, Span<byte> output)
public int Transform(ReadOnlySpan<byte> input, Span<byte> output)
{
_baseCipher.Encrypt(input, output);
return _baseCipher.Encrypt(input, output);
}
}
@ -29,8 +29,8 @@ public class AesEcbDecryptor : ICipher
_baseCipher.Initialize(key, true);
}
public void Transform(ReadOnlySpan<byte> input, Span<byte> output)
public int Transform(ReadOnlySpan<byte> input, Span<byte> output)
{
_baseCipher.Decrypt(input, output);
return _baseCipher.Decrypt(input, output);
}
}
}

View file

@ -13,9 +13,9 @@ public class AesEcbEncryptorNi : ICipher
_baseCipher.Initialize(key, false);
}
public void Transform(ReadOnlySpan<byte> input, Span<byte> output)
public int Transform(ReadOnlySpan<byte> input, Span<byte> output)
{
_baseCipher.Encrypt(input, output);
return _baseCipher.Encrypt(input, output);
}
}
@ -29,8 +29,8 @@ public class AesEcbDecryptorNi : ICipher
_baseCipher.Initialize(key, true);
}
public void Transform(ReadOnlySpan<byte> input, Span<byte> output)
public int Transform(ReadOnlySpan<byte> input, Span<byte> output)
{
_baseCipher.Decrypt(input, output);
return _baseCipher.Decrypt(input, output);
}
}
}

View file

@ -16,9 +16,9 @@ public class AesXtsEncryptor : ICipherWithIv
_baseCipher.Initialize(key1, key2, iv, false);
}
public void Transform(ReadOnlySpan<byte> input, Span<byte> output)
public int Transform(ReadOnlySpan<byte> input, Span<byte> output)
{
_baseCipher.Encrypt(input, output);
return _baseCipher.Encrypt(input, output);
}
}
@ -34,8 +34,8 @@ public class AesXtsDecryptor : ICipherWithIv
_baseCipher.Initialize(key1, key2, iv, true);
}
public void Transform(ReadOnlySpan<byte> input, Span<byte> output)
public int Transform(ReadOnlySpan<byte> input, Span<byte> output)
{
_baseCipher.Decrypt(input, output);
return _baseCipher.Decrypt(input, output);
}
}
}

View file

@ -18,9 +18,9 @@ public class AesXtsEncryptorNi : ICipherWithIv
_baseCipher.Initialize(key1, key2, iv, false);
}
public void Transform(ReadOnlySpan<byte> input, Span<byte> output)
public int Transform(ReadOnlySpan<byte> input, Span<byte> output)
{
_baseCipher.Encrypt(input, output);
return _baseCipher.Encrypt(input, output);
}
}
@ -36,8 +36,8 @@ public class AesXtsDecryptorNi : ICipherWithIv
_baseCipher.Initialize(key1, key2, iv, true);
}
public void Transform(ReadOnlySpan<byte> input, Span<byte> output)
public int Transform(ReadOnlySpan<byte> input, Span<byte> output)
{
_baseCipher.Decrypt(input, output);
return _baseCipher.Decrypt(input, output);
}
}
}

View file

@ -5,10 +5,10 @@ namespace LibHac.Crypto;
public interface ICipher
{
void Transform(ReadOnlySpan<byte> input, Span<byte> output);
int Transform(ReadOnlySpan<byte> input, Span<byte> output);
}
public interface ICipherWithIv : ICipher
{
ref Buffer16 Iv { get; }
}
}

View file

@ -13,13 +13,13 @@ public struct AesCbcMode
_aesCore.Initialize(key, iv, CipherMode.CBC, isDecrypting);
}
public void Encrypt(ReadOnlySpan<byte> input, Span<byte> output)
public int Encrypt(ReadOnlySpan<byte> input, Span<byte> output)
{
_aesCore.Encrypt(input, output);
return _aesCore.Encrypt(input, output);
}
public void Decrypt(ReadOnlySpan<byte> input, Span<byte> output)
public int Decrypt(ReadOnlySpan<byte> input, Span<byte> output)
{
_aesCore.Decrypt(input, output);
return _aesCore.Decrypt(input, output);
}
}
}

View file

@ -22,7 +22,7 @@ public struct AesCbcModeNi
Iv = Unsafe.ReadUnaligned<Vector128<byte>>(ref MemoryMarshal.GetReference(iv));
}
public void Encrypt(ReadOnlySpan<byte> input, Span<byte> output)
public int Encrypt(ReadOnlySpan<byte> input, Span<byte> output)
{
int blockCount = Math.Min(input.Length, output.Length) >> 4;
@ -42,9 +42,11 @@ public struct AesCbcModeNi
}
Iv = iv;
return Math.Min(input.Length, output.Length) & ~0xF;
}
public void Decrypt(ReadOnlySpan<byte> input, Span<byte> output)
public int Decrypt(ReadOnlySpan<byte> input, Span<byte> output)
{
int remainingBlocks = Math.Min(input.Length, output.Length) >> 4;
@ -104,5 +106,7 @@ public struct AesCbcModeNi
}
Iv = iv;
return Math.Min(input.Length, output.Length) & ~0xF;
}
}

View file

@ -31,43 +31,45 @@ public struct AesCore
_isDecrypting = isDecrypting;
}
public void Encrypt(ReadOnlySpan<byte> input, Span<byte> output)
public int Encrypt(ReadOnlySpan<byte> input, Span<byte> output)
{
Debug.Assert(!_isDecrypting);
Transform(input, output);
return Transform(input, output);
}
public void Decrypt(ReadOnlySpan<byte> input, Span<byte> output)
public int Decrypt(ReadOnlySpan<byte> input, Span<byte> output)
{
Debug.Assert(_isDecrypting);
Transform(input, output);
return Transform(input, output);
}
public void Encrypt(byte[] input, byte[] output, int length)
public int Encrypt(byte[] input, byte[] output, int length)
{
Debug.Assert(!_isDecrypting);
Transform(input, output, length);
return Transform(input, output, length);
}
public void Decrypt(byte[] input, byte[] output, int length)
public int Decrypt(byte[] input, byte[] output, int length)
{
Debug.Assert(_isDecrypting);
Transform(input, output, length);
return Transform(input, output, length);
}
private void Transform(ReadOnlySpan<byte> input, Span<byte> output)
private int Transform(ReadOnlySpan<byte> input, Span<byte> output)
{
using var rented = new RentedArray<byte>(input.Length);
input.CopyTo(rented.Array);
Transform(rented.Array, rented.Array, input.Length);
int bytesWritten = Transform(rented.Array, rented.Array, input.Length);
rented.Array.CopyTo(output);
return bytesWritten;
}
private void Transform(byte[] input, byte[] output, int length)
private int Transform(byte[] input, byte[] output, int length)
{
_transform.TransformBlock(input, 0, length, output, 0);
return _transform.TransformBlock(input, 0, length, output, 0);
}
}
}

View file

@ -98,9 +98,10 @@ public struct AesCoreNi
return AesNi.DecryptLast(b, keys[0]);
}
public readonly void EncryptInterleaved8(ReadOnlySpan<byte> input, Span<byte> output)
public readonly int EncryptInterleaved8(ReadOnlySpan<byte> input, Span<byte> output)
{
int remainingBlocks = Math.Min(input.Length, output.Length) >> 4;
int length = remainingBlocks << 4;
ref Vector128<byte> inBlock = ref Unsafe.As<byte, Vector128<byte>>(ref MemoryMarshal.GetReference(input));
ref Vector128<byte> outBlock = ref Unsafe.As<byte, Vector128<byte>>(ref MemoryMarshal.GetReference(output));
@ -138,11 +139,14 @@ public struct AesCoreNi
outBlock = ref Unsafe.Add(ref outBlock, 1);
remainingBlocks -= 1;
}
return length;
}
public readonly void DecryptInterleaved8(ReadOnlySpan<byte> input, Span<byte> output)
public readonly int DecryptInterleaved8(ReadOnlySpan<byte> input, Span<byte> output)
{
int remainingBlocks = Math.Min(input.Length, output.Length) >> 4;
int length = remainingBlocks << 4;
ref Vector128<byte> inBlock = ref Unsafe.As<byte, Vector128<byte>>(ref MemoryMarshal.GetReference(input));
ref Vector128<byte> outBlock = ref Unsafe.As<byte, Vector128<byte>>(ref MemoryMarshal.GetReference(output));
@ -180,6 +184,8 @@ public struct AesCoreNi
outBlock = ref Unsafe.Add(ref outBlock, 1);
remainingBlocks -= 1;
}
return length;
}
// When inlining this function, RyuJIT will almost make the
@ -574,4 +580,4 @@ public struct AesCoreNi
return Sse2.Xor(s, t);
}
}
}

View file

@ -24,7 +24,7 @@ public struct AesCtrMode
Iv = Unsafe.ReadUnaligned<Buffer16>(ref MemoryMarshal.GetReference(iv));
}
public void Transform(ReadOnlySpan<byte> input, Span<byte> output)
public int Transform(ReadOnlySpan<byte> input, Span<byte> output)
{
int blockCount = BitUtil.DivideUp(input.Length, Aes.BlockSize);
int length = blockCount * Aes.BlockSize;
@ -34,6 +34,8 @@ public struct AesCtrMode
_aesCore.Encrypt(counterBuffer.Array, counterBuffer.Array, length);
Utilities.XorArrays(output, input, counterBuffer.Span);
return Math.Min(input.Length, output.Length);
}
private static void FillDecryptedCounter(Span<byte> counter, Span<byte> buffer)
@ -53,4 +55,4 @@ public struct AesCtrMode
counterL[1] = BinaryPrimitives.ReverseEndianness(lo);
}
}
}

View file

@ -23,9 +23,10 @@ public struct AesCtrModeNi
Iv = Unsafe.ReadUnaligned<Vector128<byte>>(ref MemoryMarshal.GetReference(iv));
}
public void Transform(ReadOnlySpan<byte> input, Span<byte> output)
public int Transform(ReadOnlySpan<byte> input, Span<byte> output)
{
int remaining = Math.Min(input.Length, output.Length);
int length = Math.Min(input.Length, output.Length);
int remaining = length;
int blockCount = remaining >> 4;
ref Vector128<byte> inBlock = ref Unsafe.As<byte, Vector128<byte>>(ref MemoryMarshal.GetReference(input));
@ -103,6 +104,8 @@ public struct AesCtrModeNi
{
EncryptCtrPartialBlock(input.Slice(blockCount * 0x10), output.Slice(blockCount * 0x10));
}
return length;
}
private void EncryptCtrPartialBlock(ReadOnlySpan<byte> input, Span<byte> output)

View file

@ -13,13 +13,13 @@ public struct AesEcbMode
_aesCore.Initialize(key, ReadOnlySpan<byte>.Empty, CipherMode.ECB, isDecrypting);
}
public void Encrypt(ReadOnlySpan<byte> input, Span<byte> output)
public int Encrypt(ReadOnlySpan<byte> input, Span<byte> output)
{
_aesCore.Encrypt(input, output);
return _aesCore.Encrypt(input, output);
}
public void Decrypt(ReadOnlySpan<byte> input, Span<byte> output)
public int Decrypt(ReadOnlySpan<byte> input, Span<byte> output)
{
_aesCore.Decrypt(input, output);
return _aesCore.Decrypt(input, output);
}
}
}

View file

@ -11,13 +11,13 @@ public struct AesEcbModeNi
_aesCore.Initialize(key, isDecrypting);
}
public void Encrypt(ReadOnlySpan<byte> input, Span<byte> output)
public int Encrypt(ReadOnlySpan<byte> input, Span<byte> output)
{
_aesCore.EncryptInterleaved8(input, output);
return _aesCore.EncryptInterleaved8(input, output);
}
public void Decrypt(ReadOnlySpan<byte> input, Span<byte> output)
public int Decrypt(ReadOnlySpan<byte> input, Span<byte> output)
{
_aesCore.DecryptInterleaved8(input, output);
return _aesCore.DecryptInterleaved8(input, output);
}
}
}

View file

@ -26,7 +26,7 @@ public struct AesXtsMode
Iv = Unsafe.ReadUnaligned<Buffer16>(ref MemoryMarshal.GetReference(iv));
}
public void Encrypt(ReadOnlySpan<byte> input, Span<byte> output)
public int Encrypt(ReadOnlySpan<byte> input, Span<byte> output)
{
int length = Math.Min(input.Length, output.Length);
int blockCount = length >> 4;
@ -74,9 +74,11 @@ public struct AesXtsMode
_dataAesCore.Encrypt(tmp, tmp);
XorBuffer(ref prevOutBlock, ref tmp, ref tweak);
}
return length;
}
public void Decrypt(ReadOnlySpan<byte> input, Span<byte> output)
public int Decrypt(ReadOnlySpan<byte> input, Span<byte> output)
{
int length = Math.Min(input.Length, output.Length);
int blockCount = length >> 4;
@ -139,6 +141,8 @@ public struct AesXtsMode
_dataAesCore.Decrypt(tmp, tmp);
XorBuffer(ref outBlock, ref tmp, ref tweak);
}
return length;
}
private static Buffer16 FillTweakBuffer(Buffer16 initialTweak, Span<Buffer16> tweakBuffer)

View file

@ -25,7 +25,7 @@ public struct AesXtsModeNi
Iv = Unsafe.ReadUnaligned<Vector128<byte>>(ref MemoryMarshal.GetReference(iv));
}
public void Encrypt(ReadOnlySpan<byte> input, Span<byte> output)
public int Encrypt(ReadOnlySpan<byte> input, Span<byte> output)
{
int length = Math.Min(input.Length, output.Length);
int remainingBlocks = length >> 4;
@ -101,9 +101,11 @@ public struct AesXtsModeNi
{
EncryptPartialFinalBlock(ref inBlock, ref outBlock, tweak, leftover);
}
return length;
}
public void Decrypt(ReadOnlySpan<byte> input, Span<byte> output)
public int Decrypt(ReadOnlySpan<byte> input, Span<byte> output)
{
int length = Math.Min(input.Length, output.Length);
int remainingBlocks = length >> 4;
@ -181,6 +183,8 @@ public struct AesXtsModeNi
{
DecryptPartialFinalBlock(ref inBlock, ref outBlock, tweak, mask, leftover);
}
return length;
}
// ReSharper disable once RedundantAssignment
@ -251,4 +255,4 @@ public struct AesXtsModeNi
return Sse2.Xor(tmp1, tmp2);
}
}
}