From e8d54159e371b52bc469432920887b02534e4fcc Mon Sep 17 00:00:00 2001 From: Alex Barney Date: Fri, 4 Mar 2022 17:35:03 -0700 Subject: [PATCH] Add AesCtrStorage --- src/LibHac/FsSystem/AesCtrStorage.cs | 235 +++++++++++++++++++++++++++ src/LibHac/FsSystem/PooledBuffer.cs | 26 ++- 2 files changed, 260 insertions(+), 1 deletion(-) create mode 100644 src/LibHac/FsSystem/AesCtrStorage.cs diff --git a/src/LibHac/FsSystem/AesCtrStorage.cs b/src/LibHac/FsSystem/AesCtrStorage.cs new file mode 100644 index 00000000..750e4477 --- /dev/null +++ b/src/LibHac/FsSystem/AesCtrStorage.cs @@ -0,0 +1,235 @@ +using System; +using System.Buffers.Binary; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using LibHac.Common; +using LibHac.Common.FixedArrays; +using LibHac.Crypto; +using LibHac.Diag; +using LibHac.Fs; +using LibHac.Util; + +namespace LibHac.FsSystem; + +/// +/// Reads and writes to an that's encrypted with AES-CTR-128. +/// +/// Based on FS 13.1.0 (nnSdk 13.4.0) +public class AesCtrStorage : IStorage +{ + public static readonly int BlockSize = Aes.BlockSize; + public static readonly int KeySize = Aes.KeySize128; + public static readonly int IvSize = Aes.KeySize128; + + private IStorage _baseStorage; + private Array16 _key; + private Array16 _iv; + + // LibHac addition: This field goes unused if initialized with a plain IStorage. + // The original class uses a template for both the shared and non-shared IStorage which avoids needing this field. + private SharedRef _baseStorageShared; + + public static void MakeIv(Span outIv, ulong upperIv, long offset) + { + Assert.SdkRequiresEqual(outIv.Length, IvSize); + Assert.SdkRequiresGreaterEqual(offset, 0); + + BinaryPrimitives.WriteUInt64BigEndian(outIv, upperIv); + BinaryPrimitives.WriteInt64BigEndian(outIv.Slice(sizeof(long)), offset / BlockSize); + } + + public AesCtrStorage(IStorage baseStorage, ReadOnlySpan key, ReadOnlySpan iv) + { + Assert.SdkRequiresNotNull(baseStorage); + Assert.SdkRequiresEqual(key.Length, KeySize); + Assert.SdkRequiresEqual(iv.Length, IvSize); + + _baseStorage = baseStorage; + + key.CopyTo(_key.Items); + iv.CopyTo(_iv.Items); + } + + public AesCtrStorage(in SharedRef baseStorage, ReadOnlySpan key, ReadOnlySpan iv) + { + Assert.SdkRequiresNotNull(in baseStorage); + Assert.SdkRequiresEqual(key.Length, KeySize); + Assert.SdkRequiresEqual(iv.Length, IvSize); + + _baseStorage = baseStorage.Get; + _baseStorageShared = SharedRef.CreateCopy(in baseStorage); + + key.CopyTo(_key.Items); + iv.CopyTo(_iv.Items); + } + + public override void Dispose() + { + _baseStorageShared.Destroy(); + + base.Dispose(); + } + + public override Result Read(long offset, Span destination) + { + if (destination.Length == 0) + return Result.Success; + + // Reads cannot contain any partial blocks. + if (!Alignment.IsAlignedPow2(offset, (uint)BlockSize)) + return ResultFs.InvalidArgument.Log(); + + if (!Alignment.IsAlignedPow2(destination.Length, (uint)BlockSize)) + return ResultFs.InvalidArgument.Log(); + + Result rc = _baseStorage.Read(offset, destination); + if (rc.IsFailure()) return rc.Miss(); + + using var changePriority = new ScopedThreadPriorityChanger(1, ScopedThreadPriorityChanger.Mode.Relative); + + Array16 counter = _iv; + Utility.AddCounter(counter.Items, (ulong)offset / (uint)BlockSize); + + int decSize = Aes.DecryptCtr128(destination, destination, _key, counter); + if (decSize != destination.Length) + return ResultFs.UnexpectedInAesCtrStorageA.Log(); + + return Result.Success; + } + + public override Result Write(long offset, ReadOnlySpan source) + { + if (source.Length == 0) + return Result.Success; + + // We can only write full, aligned blocks. + if (!Alignment.IsAlignedPow2(offset, (uint)BlockSize)) + return ResultFs.InvalidArgument.Log(); + + if (!Alignment.IsAlignedPow2(source.Length, (uint)BlockSize)) + return ResultFs.InvalidArgument.Log(); + + // Get a pooled buffer. + // Note: The original code will const_cast the input buffer and encrypt the data in-place if the provided + // buffer is from the pooled buffer heap. This seems very error-prone since the data in buffers you pass + // as const might unexpectedly be modified. We make IsDeviceAddress() always return false + // so this won't happen, but the code that does the encryption in-place will be left in as a reference. + using var pooledBuffer = new PooledBuffer(); + bool useWorkBuffer = PooledBufferGlobalMethods.IsDeviceAddress(source); + if (useWorkBuffer) + pooledBuffer.Allocate(source.Length, BlockSize); + + // Setup the counter. + var counter = new Array16(); + Utility.AddCounter(counter.Items, (ulong)offset / (uint)BlockSize); + + // Loop until all data is written. + int remaining = source.Length; + int currentOffset = 0; + + while (remaining > 0) + { + // Determine data we're writing and where. + int writeSize = useWorkBuffer ? Math.Max(pooledBuffer.GetSize(), remaining) : remaining; + Span writeBuffer = useWorkBuffer + ? pooledBuffer.GetBuffer().Slice(0, writeSize) + : MemoryMarshal.CreateSpan(ref MemoryMarshal.GetReference(source), source.Length).Slice(0, writeSize); + + // Encrypt the data, with temporarily increased priority. + using (new ScopedThreadPriorityChanger(1, ScopedThreadPriorityChanger.Mode.Relative)) + { + int encSize = Aes.EncryptCtr128(source.Slice(currentOffset, writeSize), writeBuffer, _key, _iv); + if (encSize != writeSize) + return ResultFs.UnexpectedInAesCtrStorageA.Log(); + } + + // Write the encrypted data. + Result rc = _baseStorage.Write(offset + currentOffset, writeBuffer); + if (rc.IsFailure()) return rc.Miss(); + + // Advance. + currentOffset += writeSize; + remaining -= writeSize; + if (remaining > 0) + { + Utility.AddCounter(counter.Items, (uint)writeSize / (uint)BlockSize); + } + } + + return Result.Success; + } + + public override Result Flush() + { + return _baseStorage.Flush(); + } + + public override Result SetSize(long size) + { + return ResultFs.UnsupportedSetSizeForAesCtrStorage.Log(); + } + + public override Result GetSize(out long size) + { + return _baseStorage.GetSize(out size); + } + + public override Result OperateRange(Span outBuffer, OperationId operationId, long offset, long size, + ReadOnlySpan inBuffer) + { + if (operationId != OperationId.InvalidateCache) + { + if (size == 0) + { + if (operationId == OperationId.QueryRange) + { + if (outBuffer.Length != Unsafe.SizeOf()) + return ResultFs.InvalidSize.Log(); + + Unsafe.As(ref MemoryMarshal.GetReference(outBuffer)).Clear(); + } + + return Result.Success; + } + + if (!Alignment.IsAlignedPow2(offset, (uint)BlockSize)) + return ResultFs.InvalidArgument.Log(); + + if (!Alignment.IsAlignedPow2(size, (uint)BlockSize)) + return ResultFs.InvalidArgument.Log(); + } + + switch (operationId) + { + case OperationId.QueryRange: + { + if (outBuffer.Length != Unsafe.SizeOf()) + return ResultFs.InvalidSize.Log(); + + ref QueryRangeInfo outInfo = + ref Unsafe.As(ref MemoryMarshal.GetReference(outBuffer)); + + // Get the QueryRangeInfo of the underlying base storage. + Result rc = _baseStorage.OperateRange(outBuffer, operationId, offset, size, inBuffer); + if (rc.IsFailure()) return rc.Miss(); + + Unsafe.SkipInit(out QueryRangeInfo info); + info.Clear(); + info.AesCtrKeyType = (int)QueryRangeInfo.AesCtrKeyTypeFlag.InternalKeyForSoftwareAes; + + outInfo.Merge(in info); + + break; + } + default: + { + Result rc = _baseStorage.OperateRange(outBuffer, operationId, offset, size, inBuffer); + if (rc.IsFailure()) return rc.Miss(); + + break; + } + } + + return Result.Success; + } +} \ No newline at end of file diff --git a/src/LibHac/FsSystem/PooledBuffer.cs b/src/LibHac/FsSystem/PooledBuffer.cs index 1c0437f1..154db9c9 100644 --- a/src/LibHac/FsSystem/PooledBuffer.cs +++ b/src/LibHac/FsSystem/PooledBuffer.cs @@ -8,6 +8,12 @@ namespace LibHac.FsSystem; public static class PooledBufferGlobalMethods { + // ReSharper disable once UnusedParameter.Global + public static bool IsPooledBuffer(ReadOnlySpan buffer) + { + return false; + } + public static long GetPooledBufferRetriedCount(this FileSystemServer fsSrv) { return fsSrv.Globals.PooledBuffer.CountRetried; @@ -42,6 +48,24 @@ public static class PooledBufferGlobalMethods g.CountReduceAllocation = 0; g.CountFailedIdealAllocationOnAsyncAccess = 0; } + + public static bool IsAdditionalDeviceAddress(ReadOnlySpan buffer) + { + return false; + } + + // ReSharper disable once UnusedParameter.Global + /// + /// Checks if the provided buffer is located at a "device address". + /// + /// The buffer to check. + /// if this is a device address; otherwise . + /// A device address is one that is either located in the pooled buffer heap + /// or in any of the registered additional device address ranges. + public static bool IsDeviceAddress(ReadOnlySpan buffer) + { + return IsPooledBuffer(buffer) || IsAdditionalDeviceAddress(buffer); + } } internal struct PooledBufferGlobals @@ -168,4 +192,4 @@ public struct PooledBuffer : IDisposable { Deallocate(); } -} +} \ No newline at end of file