Add functions to encrypt/decrypt entire buffers

This commit is contained in:
Alex Barney 2019-11-16 14:32:01 -07:00
parent 8b47be19c2
commit df646fb503
3 changed files with 202 additions and 33 deletions

View file

@ -104,5 +104,69 @@ namespace LibHac.Crypto2
#endif #endif
return new AesXtsCipher(key1, key2, iv, false); return new AesXtsCipher(key1, key2, iv, false);
} }
public static void EncryptEcb128(ReadOnlySpan<byte> input, Span<byte> output, ReadOnlySpan<byte> key,
bool preferDotNetCrypto = false)
{
ICipher cipher = CreateEcbEncryptor(key, preferDotNetCrypto);
cipher.Transform(input, output);
}
public static void DecryptEcb128(ReadOnlySpan<byte> input, Span<byte> output, ReadOnlySpan<byte> key,
bool preferDotNetCrypto = false)
{
ICipher cipher = CreateEcbDecryptor(key, preferDotNetCrypto);
cipher.Transform(input, output);
}
public static void EncryptCbc128(ReadOnlySpan<byte> input, Span<byte> output, ReadOnlySpan<byte> key,
ReadOnlySpan<byte> iv, bool preferDotNetCrypto = false)
{
ICipher cipher = CreateCbcEncryptor(key, iv, preferDotNetCrypto);
cipher.Transform(input, output);
}
public static void DecryptCbc128(ReadOnlySpan<byte> input, Span<byte> output, ReadOnlySpan<byte> key,
ReadOnlySpan<byte> iv, bool preferDotNetCrypto = false)
{
ICipher cipher = CreateCbcDecryptor(key, iv, preferDotNetCrypto);
cipher.Transform(input, output);
}
public static void EncryptCtr128(ReadOnlySpan<byte> input, Span<byte> output, ReadOnlySpan<byte> key,
ReadOnlySpan<byte> iv, bool preferDotNetCrypto = false)
{
ICipher cipher = CreateCtrEncryptor(key, iv, preferDotNetCrypto);
cipher.Transform(input, output);
}
public static void DecryptCtr128(ReadOnlySpan<byte> input, Span<byte> output, ReadOnlySpan<byte> key,
ReadOnlySpan<byte> iv, bool preferDotNetCrypto = false)
{
ICipher cipher = CreateCtrDecryptor(key, iv, preferDotNetCrypto);
cipher.Transform(input, output);
}
public static void EncryptXts128(ReadOnlySpan<byte> input, Span<byte> output, ReadOnlySpan<byte> key1,
ReadOnlySpan<byte> key2, ReadOnlySpan<byte> iv, bool preferDotNetCrypto = false)
{
ICipher cipher = CreateXtsEncryptor(key1, key2, iv, preferDotNetCrypto);
cipher.Transform(input, output);
}
public static void DecryptXts128(ReadOnlySpan<byte> input, Span<byte> output, ReadOnlySpan<byte> key1,
ReadOnlySpan<byte> key2, ReadOnlySpan<byte> iv, bool preferDotNetCrypto = false)
{
ICipher cipher = CreateXtsDecryptor(key1, key2, iv, preferDotNetCrypto);
cipher.Transform(input, output);
}
} }
} }

View file

@ -104,18 +104,38 @@ namespace LibHac.Crypto2
[MethodImpl(MethodImplOptions.AggressiveOptimization)] [MethodImpl(MethodImplOptions.AggressiveOptimization)]
private static void KeyExpansion(ReadOnlySpan<byte> key, Span<Vector128<byte>> roundKeys, bool isDecrypting) private static void KeyExpansion(ReadOnlySpan<byte> key, Span<Vector128<byte>> roundKeys, bool isDecrypting)
{ {
roundKeys[0] = Unsafe.ReadUnaligned<Vector128<byte>>(ref MemoryMarshal.GetReference(key)); var curKey = Unsafe.ReadUnaligned<Vector128<byte>>(ref MemoryMarshal.GetReference(key));
roundKeys[0] = curKey;
MakeRoundKey(roundKeys, 1, 0x01); curKey = KeyExpansion(curKey, Aes.KeygenAssist(curKey, 0x01));
MakeRoundKey(roundKeys, 2, 0x02); roundKeys[1] = curKey;
MakeRoundKey(roundKeys, 3, 0x04);
MakeRoundKey(roundKeys, 4, 0x08); curKey = KeyExpansion(curKey, Aes.KeygenAssist(curKey, 0x02));
MakeRoundKey(roundKeys, 5, 0x10); roundKeys[2] = curKey;
MakeRoundKey(roundKeys, 6, 0x20);
MakeRoundKey(roundKeys, 7, 0x40); curKey = KeyExpansion(curKey, Aes.KeygenAssist(curKey, 0x04));
MakeRoundKey(roundKeys, 8, 0x80); roundKeys[3] = curKey;
MakeRoundKey(roundKeys, 9, 0x1b);
MakeRoundKey(roundKeys, 10, 0x36); curKey = KeyExpansion(curKey, Aes.KeygenAssist(curKey, 0x08));
roundKeys[4] = curKey;
curKey = KeyExpansion(curKey, Aes.KeygenAssist(curKey, 0x10));
roundKeys[5] = curKey;
curKey = KeyExpansion(curKey, Aes.KeygenAssist(curKey, 0x20));
roundKeys[6] = curKey;
curKey = KeyExpansion(curKey, Aes.KeygenAssist(curKey, 0x40));
roundKeys[7] = curKey;
curKey = KeyExpansion(curKey, Aes.KeygenAssist(curKey, 0x80));
roundKeys[8] = curKey;
curKey = KeyExpansion(curKey, Aes.KeygenAssist(curKey, 0x1b));
roundKeys[9] = curKey;
curKey = KeyExpansion(curKey, Aes.KeygenAssist(curKey, 0x36));
roundKeys[10] = curKey;
if (isDecrypting) if (isDecrypting)
{ {
@ -126,19 +146,15 @@ namespace LibHac.Crypto2
} }
} }
[MethodImpl(MethodImplOptions.AggressiveOptimization)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void MakeRoundKey(Span<Vector128<byte>> keys, int i, byte rcon) private static Vector128<byte> KeyExpansion(Vector128<byte> s, Vector128<byte> t)
{ {
Vector128<byte> s = keys[i - 1];
Vector128<byte> t = keys[i - 1];
t = Aes.KeygenAssist(t, rcon);
t = Sse2.Shuffle(t.AsUInt32(), 0xFF).AsByte(); t = Sse2.Shuffle(t.AsUInt32(), 0xFF).AsByte();
s = Sse2.Xor(s, Sse2.ShiftLeftLogical128BitLane(s, 4)); s = Sse2.Xor(s, Sse2.ShiftLeftLogical128BitLane(s, 4));
s = Sse2.Xor(s, Sse2.ShiftLeftLogical128BitLane(s, 8)); s = Sse2.Xor(s, Sse2.ShiftLeftLogical128BitLane(s, 8));
keys[i] = Sse2.Xor(s, t); return Sse2.Xor(s, t);
} }
} }
} }

View file

@ -12,6 +12,8 @@ namespace hactoolnet
{ {
private const int Size = 1024 * 1024 * 10; private const int Size = 1024 * 1024 * 10;
private const int Iterations = 100; private const int Iterations = 100;
private const int BlockSizeBlocked = 0x10;
private const int BlockSizeSeparate = 0x10;
private static void CopyBenchmark(IStorage src, IStorage dst, int iterations, string label, IProgressReport logger) private static void CopyBenchmark(IStorage src, IStorage dst, int iterations, string label, IProgressReport logger)
{ {
@ -82,7 +84,7 @@ namespace hactoolnet
logger.SetTotal(iterations); logger.SetTotal(iterations);
int blockCount = src.Length / 0x10; int blockCount = src.Length / BlockSizeBlocked;
for (int i = 0; i < iterations; i++) for (int i = 0; i < iterations; i++)
{ {
@ -92,7 +94,8 @@ namespace hactoolnet
for (int b = 0; b < blockCount; b++) for (int b = 0; b < blockCount; b++)
{ {
cipher.Transform(src.Slice(b * 0x10, 0x10), dst.Slice(b * 0x10, 0x10)); cipher.Transform(src.Slice(b * BlockSizeBlocked, BlockSizeBlocked),
dst.Slice(b * BlockSizeBlocked, BlockSizeBlocked));
} }
watch.Stop(); watch.Stop();
@ -116,8 +119,60 @@ namespace hactoolnet
logger.LogMessage($"{label}{averageRate}/s, fastest run: {fastestRate}/s, slowest run: {slowestRate}/s"); logger.LogMessage($"{label}{averageRate}/s, fastest run: {fastestRate}/s, slowest run: {slowestRate}/s");
} }
private static void RunCipherBenchmark(Func<ICipher> cipherNet, Func<ICipher> cipherLibHac, bool benchBlocked, private delegate void CipherTaskSeparate(ReadOnlySpan<byte> input, Span<byte> output, ReadOnlySpan<byte> key1,
string label, IProgressReport logger) ReadOnlySpan<byte> key2, ReadOnlySpan<byte> iv, bool preferDotNetCrypto = false);
// Benchmarks encrypting each block separately, initializing a new cipher object for each one
private static void CipherBenchmarkSeparate(ReadOnlySpan<byte> src, Span<byte> dst, CipherTaskSeparate function,
int iterations, string label, bool dotNetCrypto, IProgressReport logger)
{
Debug.Assert(src.Length == dst.Length);
var watch = new Stopwatch();
var runTimes = new double[iterations];
ReadOnlySpan<byte> key1 = stackalloc byte[0x10];
ReadOnlySpan<byte> key2 = stackalloc byte[0x10];
ReadOnlySpan<byte> iv = stackalloc byte[0x10];
logger.SetTotal(iterations);
const int blockSize = BlockSizeSeparate;
int blockCount = src.Length / blockSize;
for (int i = 0; i < iterations; i++)
{
watch.Restart();
for (int b = 0; b < blockCount; b++)
{
function(src.Slice(b * blockSize, blockSize), dst.Slice(b * blockSize, blockSize),
key1, key2, iv, dotNetCrypto);
}
watch.Stop();
logger.ReportAdd(1);
runTimes[i] = watch.Elapsed.TotalSeconds;
}
logger.SetTotal(0);
long srcSize = src.Length;
double fastestRun = runTimes.Min();
double averageRun = runTimes.Average();
double slowestRun = runTimes.Max();
string fastestRate = Util.GetBytesReadable((long)(srcSize / fastestRun));
string averageRate = Util.GetBytesReadable((long)(srcSize / averageRun));
string slowestRate = Util.GetBytesReadable((long)(srcSize / slowestRun));
logger.LogMessage($"{label}{averageRate}/s, fastest run: {fastestRate}/s, slowest run: {slowestRate}/s");
}
private static void RunCipherBenchmark(Func<ICipher> cipherNet, Func<ICipher> cipherLibHac,
CipherTaskSeparate function, bool benchBlocked, string label, IProgressReport logger)
{ {
var srcData = new byte[Size]; var srcData = new byte[Size];
@ -125,19 +180,34 @@ namespace hactoolnet
var dstDataNet = new byte[Size]; var dstDataNet = new byte[Size];
var dstDataBlockedLh = new byte[Size]; var dstDataBlockedLh = new byte[Size];
var dstDataBlockedNet = new byte[Size]; var dstDataBlockedNet = new byte[Size];
var dstDataSeparateLh = new byte[Size];
var dstDataSeparateNet = new byte[Size];
logger.LogMessage(string.Empty); logger.LogMessage(string.Empty);
logger.LogMessage(label); logger.LogMessage(label);
if (AesCrypto.IsAesNiSupported()) CipherBenchmark(srcData, dstDataLh, cipherLibHac, Iterations, "LibHac impl: ", logger); if (AesCrypto.IsAesNiSupported())
CipherBenchmark(srcData, dstDataLh, cipherLibHac, Iterations, "LibHac impl: ", logger);
CipherBenchmark(srcData, dstDataNet, cipherNet, Iterations, ".NET impl: ", logger); CipherBenchmark(srcData, dstDataNet, cipherNet, Iterations, ".NET impl: ", logger);
if (benchBlocked) if (benchBlocked)
{ {
if (AesCrypto.IsAesNiSupported()) if (AesCrypto.IsAesNiSupported())
CipherBenchmarkBlocked(srcData, dstDataBlockedLh, cipherLibHac, Iterations / 5, "LibHac impl (blocked): ", logger); CipherBenchmarkBlocked(srcData, dstDataBlockedLh, cipherLibHac, Iterations / 5,
"LibHac impl (blocked): ", logger);
CipherBenchmarkBlocked(srcData, dstDataBlockedNet, cipherNet, Iterations / 5, ".NET impl (blocked): ", logger); CipherBenchmarkBlocked(srcData, dstDataBlockedNet, cipherNet, Iterations / 5, ".NET impl (blocked): ",
logger);
}
if (function != null)
{
if (AesCrypto.IsAesNiSupported())
CipherBenchmarkSeparate(srcData, dstDataSeparateLh, function, Iterations / 5,
"LibHac impl (separate): ", false, logger);
CipherBenchmarkSeparate(srcData, dstDataSeparateNet, function, Iterations / 20,
".NET impl (separate): ", true, logger);
} }
if (AesCrypto.IsAesNiSupported()) if (AesCrypto.IsAesNiSupported())
@ -149,6 +219,12 @@ namespace hactoolnet
logger.LogMessage($"{dstDataLh.SequenceEqual(dstDataBlockedLh)}"); logger.LogMessage($"{dstDataLh.SequenceEqual(dstDataBlockedLh)}");
logger.LogMessage($"{dstDataLh.SequenceEqual(dstDataBlockedNet)}"); logger.LogMessage($"{dstDataLh.SequenceEqual(dstDataBlockedNet)}");
} }
if (function != null)
{
logger.LogMessage($"{dstDataLh.SequenceEqual(dstDataSeparateLh)}");
logger.LogMessage($"{dstDataLh.SequenceEqual(dstDataSeparateNet)}");
}
} }
} }
@ -158,7 +234,6 @@ namespace hactoolnet
{ {
case "aesctr": case "aesctr":
{ {
IStorage decStorage = new MemoryStorage(new byte[Size]); IStorage decStorage = new MemoryStorage(new byte[Size]);
IStorage encStorage = new Aes128CtrStorage(new MemoryStorage(new byte[Size]), new byte[0x10], new byte[0x10], true); IStorage encStorage = new Aes128CtrStorage(new MemoryStorage(new byte[Size]), new byte[0x10], new byte[0x10], true);
@ -206,13 +281,17 @@ namespace hactoolnet
{ {
Func<ICipher> encryptorNet = () => AesCrypto.CreateEcbEncryptor(new byte[0x10], true); Func<ICipher> encryptorNet = () => AesCrypto.CreateEcbEncryptor(new byte[0x10], true);
Func<ICipher> encryptorLh = () => AesCrypto.CreateEcbEncryptor(new byte[0x10]); Func<ICipher> encryptorLh = () => AesCrypto.CreateEcbEncryptor(new byte[0x10]);
CipherTaskSeparate encrypt = (input, output, key1, key2, iv, crypto) =>
AesCrypto.EncryptEcb128(input, output, key1, crypto);
RunCipherBenchmark(encryptorNet, encryptorLh, true, "AES-ECB encrypt", ctx.Logger); RunCipherBenchmark(encryptorNet, encryptorLh, encrypt, true, "AES-ECB encrypt", ctx.Logger);
Func<ICipher> decryptorNet = () => AesCrypto.CreateEcbDecryptor(new byte[0x10], true); Func<ICipher> decryptorNet = () => AesCrypto.CreateEcbDecryptor(new byte[0x10], true);
Func<ICipher> decryptorLh = () => AesCrypto.CreateEcbDecryptor(new byte[0x10]); Func<ICipher> decryptorLh = () => AesCrypto.CreateEcbDecryptor(new byte[0x10]);
CipherTaskSeparate decrypt = (input, output, key1, key2, iv, crypto) =>
AesCrypto.DecryptEcb128(input, output, key1, crypto);
RunCipherBenchmark(decryptorNet, decryptorLh, true, "AES-ECB decrypt", ctx.Logger); RunCipherBenchmark(decryptorNet, decryptorLh, decrypt, true, "AES-ECB decrypt", ctx.Logger);
break; break;
} }
@ -220,13 +299,17 @@ namespace hactoolnet
{ {
Func<ICipher> encryptorNet = () => AesCrypto.CreateCbcEncryptor(new byte[0x10], new byte[0x10], true); Func<ICipher> encryptorNet = () => AesCrypto.CreateCbcEncryptor(new byte[0x10], new byte[0x10], true);
Func<ICipher> encryptorLh = () => AesCrypto.CreateCbcEncryptor(new byte[0x10], new byte[0x10]); Func<ICipher> encryptorLh = () => AesCrypto.CreateCbcEncryptor(new byte[0x10], new byte[0x10]);
CipherTaskSeparate encrypt = (input, output, key1, key2, iv, crypto) =>
AesCrypto.EncryptCbc128(input, output, key1, iv, crypto);
RunCipherBenchmark(encryptorNet, encryptorLh, true, "AES-CBC encrypt", ctx.Logger); RunCipherBenchmark(encryptorNet, encryptorLh, encrypt, true, "AES-CBC encrypt", ctx.Logger);
Func<ICipher> decryptorNet = () => AesCrypto.CreateCbcDecryptor(new byte[0x10], new byte[0x10], true); Func<ICipher> decryptorNet = () => AesCrypto.CreateCbcDecryptor(new byte[0x10], new byte[0x10], true);
Func<ICipher> decryptorLh = () => AesCrypto.CreateCbcDecryptor(new byte[0x10], new byte[0x10]); Func<ICipher> decryptorLh = () => AesCrypto.CreateCbcDecryptor(new byte[0x10], new byte[0x10]);
CipherTaskSeparate decrypt = (input, output, key1, key2, iv, crypto) =>
AesCrypto.DecryptCbc128(input, output, key1, iv, crypto);
RunCipherBenchmark(decryptorNet, decryptorLh, true, "AES-CBC decrypt", ctx.Logger); RunCipherBenchmark(decryptorNet, decryptorLh, decrypt, true, "AES-CBC decrypt", ctx.Logger);
break; break;
} }
@ -235,8 +318,10 @@ namespace hactoolnet
{ {
Func<ICipher> encryptorNet = () => AesCrypto.CreateCtrEncryptor(new byte[0x10], new byte[0x10], true); Func<ICipher> encryptorNet = () => AesCrypto.CreateCtrEncryptor(new byte[0x10], new byte[0x10], true);
Func<ICipher> encryptorLh = () => AesCrypto.CreateCtrEncryptor(new byte[0x10], new byte[0x10]); Func<ICipher> encryptorLh = () => AesCrypto.CreateCtrEncryptor(new byte[0x10], new byte[0x10]);
CipherTaskSeparate encrypt = (input, output, key1, key2, iv, crypto) =>
AesCrypto.EncryptCtr128(input, output, key1, iv, crypto);
RunCipherBenchmark(encryptorNet, encryptorLh, true, "AES-CTR", ctx.Logger); RunCipherBenchmark(encryptorNet, encryptorLh, encrypt, true, "AES-CTR", ctx.Logger);
break; break;
} }
@ -244,13 +329,17 @@ namespace hactoolnet
{ {
Func<ICipher> encryptorNet = () => AesCrypto.CreateXtsEncryptor(new byte[0x10], new byte[0x10], new byte[0x10], true); Func<ICipher> encryptorNet = () => AesCrypto.CreateXtsEncryptor(new byte[0x10], new byte[0x10], new byte[0x10], true);
Func<ICipher> encryptorLh = () => AesCrypto.CreateXtsEncryptor(new byte[0x10], new byte[0x10], new byte[0x10]); Func<ICipher> encryptorLh = () => AesCrypto.CreateXtsEncryptor(new byte[0x10], new byte[0x10], new byte[0x10]);
CipherTaskSeparate encrypt = (input, output, key1, key2, iv, crypto) =>
AesCrypto.EncryptXts128(input, output, key1, key2, iv, crypto);
RunCipherBenchmark(encryptorNet, encryptorLh, false, "AES-XTS encrypt", ctx.Logger); RunCipherBenchmark(encryptorNet, encryptorLh, encrypt, false, "AES-XTS encrypt", ctx.Logger);
Func<ICipher> decryptorNet = () => AesCrypto.CreateXtsDecryptor(new byte[0x10], new byte[0x10], new byte[0x10], true); Func<ICipher> decryptorNet = () => AesCrypto.CreateXtsDecryptor(new byte[0x10], new byte[0x10], new byte[0x10], true);
Func<ICipher> decryptorLh = () => AesCrypto.CreateXtsDecryptor(new byte[0x10], new byte[0x10], new byte[0x10]); Func<ICipher> decryptorLh = () => AesCrypto.CreateXtsDecryptor(new byte[0x10], new byte[0x10], new byte[0x10]);
CipherTaskSeparate decrypt = (input, output, key1, key2, iv, crypto) =>
AesCrypto.DecryptXts128(input, output, key1, key2, iv, crypto);
RunCipherBenchmark(decryptorNet, decryptorLh, false, "AES-XTS decrypt", ctx.Logger); RunCipherBenchmark(decryptorNet, decryptorLh, decrypt, false, "AES-XTS decrypt", ctx.Logger);
break; break;
} }