Add AesCtrStorage

This commit is contained in:
Alex Barney 2022-03-04 17:35:03 -07:00
parent c9352fcb5a
commit e8d54159e3
2 changed files with 260 additions and 1 deletions

View file

@ -0,0 +1,235 @@
using System;
using System.Buffers.Binary;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using LibHac.Common;
using LibHac.Common.FixedArrays;
using LibHac.Crypto;
using LibHac.Diag;
using LibHac.Fs;
using LibHac.Util;
namespace LibHac.FsSystem;
/// <summary>
/// Reads and writes to an <see cref="IStorage"/> that's encrypted with AES-CTR-128.
/// </summary>
/// <remarks>Based on FS 13.1.0 (nnSdk 13.4.0)</remarks>
public class AesCtrStorage : IStorage
{
public static readonly int BlockSize = Aes.BlockSize;
public static readonly int KeySize = Aes.KeySize128;
public static readonly int IvSize = Aes.KeySize128;
private IStorage _baseStorage;
private Array16<byte> _key;
private Array16<byte> _iv;
// LibHac addition: This field goes unused if initialized with a plain IStorage.
// The original class uses a template for both the shared and non-shared IStorage which avoids needing this field.
private SharedRef<IStorage> _baseStorageShared;
public static void MakeIv(Span<byte> outIv, ulong upperIv, long offset)
{
Assert.SdkRequiresEqual(outIv.Length, IvSize);
Assert.SdkRequiresGreaterEqual(offset, 0);
BinaryPrimitives.WriteUInt64BigEndian(outIv, upperIv);
BinaryPrimitives.WriteInt64BigEndian(outIv.Slice(sizeof(long)), offset / BlockSize);
}
public AesCtrStorage(IStorage baseStorage, ReadOnlySpan<byte> key, ReadOnlySpan<byte> iv)
{
Assert.SdkRequiresNotNull(baseStorage);
Assert.SdkRequiresEqual(key.Length, KeySize);
Assert.SdkRequiresEqual(iv.Length, IvSize);
_baseStorage = baseStorage;
key.CopyTo(_key.Items);
iv.CopyTo(_iv.Items);
}
public AesCtrStorage(in SharedRef<IStorage> baseStorage, ReadOnlySpan<byte> key, ReadOnlySpan<byte> iv)
{
Assert.SdkRequiresNotNull(in baseStorage);
Assert.SdkRequiresEqual(key.Length, KeySize);
Assert.SdkRequiresEqual(iv.Length, IvSize);
_baseStorage = baseStorage.Get;
_baseStorageShared = SharedRef<IStorage>.CreateCopy(in baseStorage);
key.CopyTo(_key.Items);
iv.CopyTo(_iv.Items);
}
public override void Dispose()
{
_baseStorageShared.Destroy();
base.Dispose();
}
public override Result Read(long offset, Span<byte> destination)
{
if (destination.Length == 0)
return Result.Success;
// Reads cannot contain any partial blocks.
if (!Alignment.IsAlignedPow2(offset, (uint)BlockSize))
return ResultFs.InvalidArgument.Log();
if (!Alignment.IsAlignedPow2(destination.Length, (uint)BlockSize))
return ResultFs.InvalidArgument.Log();
Result rc = _baseStorage.Read(offset, destination);
if (rc.IsFailure()) return rc.Miss();
using var changePriority = new ScopedThreadPriorityChanger(1, ScopedThreadPriorityChanger.Mode.Relative);
Array16<byte> counter = _iv;
Utility.AddCounter(counter.Items, (ulong)offset / (uint)BlockSize);
int decSize = Aes.DecryptCtr128(destination, destination, _key, counter);
if (decSize != destination.Length)
return ResultFs.UnexpectedInAesCtrStorageA.Log();
return Result.Success;
}
public override Result Write(long offset, ReadOnlySpan<byte> source)
{
if (source.Length == 0)
return Result.Success;
// We can only write full, aligned blocks.
if (!Alignment.IsAlignedPow2(offset, (uint)BlockSize))
return ResultFs.InvalidArgument.Log();
if (!Alignment.IsAlignedPow2(source.Length, (uint)BlockSize))
return ResultFs.InvalidArgument.Log();
// Get a pooled buffer.
// Note: The original code will const_cast the input buffer and encrypt the data in-place if the provided
// buffer is from the pooled buffer heap. This seems very error-prone since the data in buffers you pass
// as const might unexpectedly be modified. We make IsDeviceAddress() always return false
// so this won't happen, but the code that does the encryption in-place will be left in as a reference.
using var pooledBuffer = new PooledBuffer();
bool useWorkBuffer = PooledBufferGlobalMethods.IsDeviceAddress(source);
if (useWorkBuffer)
pooledBuffer.Allocate(source.Length, BlockSize);
// Setup the counter.
var counter = new Array16<byte>();
Utility.AddCounter(counter.Items, (ulong)offset / (uint)BlockSize);
// Loop until all data is written.
int remaining = source.Length;
int currentOffset = 0;
while (remaining > 0)
{
// Determine data we're writing and where.
int writeSize = useWorkBuffer ? Math.Max(pooledBuffer.GetSize(), remaining) : remaining;
Span<byte> writeBuffer = useWorkBuffer
? pooledBuffer.GetBuffer().Slice(0, writeSize)
: MemoryMarshal.CreateSpan(ref MemoryMarshal.GetReference(source), source.Length).Slice(0, writeSize);
// Encrypt the data, with temporarily increased priority.
using (new ScopedThreadPriorityChanger(1, ScopedThreadPriorityChanger.Mode.Relative))
{
int encSize = Aes.EncryptCtr128(source.Slice(currentOffset, writeSize), writeBuffer, _key, _iv);
if (encSize != writeSize)
return ResultFs.UnexpectedInAesCtrStorageA.Log();
}
// Write the encrypted data.
Result rc = _baseStorage.Write(offset + currentOffset, writeBuffer);
if (rc.IsFailure()) return rc.Miss();
// Advance.
currentOffset += writeSize;
remaining -= writeSize;
if (remaining > 0)
{
Utility.AddCounter(counter.Items, (uint)writeSize / (uint)BlockSize);
}
}
return Result.Success;
}
public override Result Flush()
{
return _baseStorage.Flush();
}
public override Result SetSize(long size)
{
return ResultFs.UnsupportedSetSizeForAesCtrStorage.Log();
}
public override Result GetSize(out long size)
{
return _baseStorage.GetSize(out size);
}
public override Result OperateRange(Span<byte> outBuffer, OperationId operationId, long offset, long size,
ReadOnlySpan<byte> inBuffer)
{
if (operationId != OperationId.InvalidateCache)
{
if (size == 0)
{
if (operationId == OperationId.QueryRange)
{
if (outBuffer.Length != Unsafe.SizeOf<QueryRangeInfo>())
return ResultFs.InvalidSize.Log();
Unsafe.As<byte, QueryRangeInfo>(ref MemoryMarshal.GetReference(outBuffer)).Clear();
}
return Result.Success;
}
if (!Alignment.IsAlignedPow2(offset, (uint)BlockSize))
return ResultFs.InvalidArgument.Log();
if (!Alignment.IsAlignedPow2(size, (uint)BlockSize))
return ResultFs.InvalidArgument.Log();
}
switch (operationId)
{
case OperationId.QueryRange:
{
if (outBuffer.Length != Unsafe.SizeOf<QueryRangeInfo>())
return ResultFs.InvalidSize.Log();
ref QueryRangeInfo outInfo =
ref Unsafe.As<byte, QueryRangeInfo>(ref MemoryMarshal.GetReference(outBuffer));
// Get the QueryRangeInfo of the underlying base storage.
Result rc = _baseStorage.OperateRange(outBuffer, operationId, offset, size, inBuffer);
if (rc.IsFailure()) return rc.Miss();
Unsafe.SkipInit(out QueryRangeInfo info);
info.Clear();
info.AesCtrKeyType = (int)QueryRangeInfo.AesCtrKeyTypeFlag.InternalKeyForSoftwareAes;
outInfo.Merge(in info);
break;
}
default:
{
Result rc = _baseStorage.OperateRange(outBuffer, operationId, offset, size, inBuffer);
if (rc.IsFailure()) return rc.Miss();
break;
}
}
return Result.Success;
}
}

View file

@ -8,6 +8,12 @@ namespace LibHac.FsSystem;
public static class PooledBufferGlobalMethods public static class PooledBufferGlobalMethods
{ {
// ReSharper disable once UnusedParameter.Global
public static bool IsPooledBuffer(ReadOnlySpan<byte> buffer)
{
return false;
}
public static long GetPooledBufferRetriedCount(this FileSystemServer fsSrv) public static long GetPooledBufferRetriedCount(this FileSystemServer fsSrv)
{ {
return fsSrv.Globals.PooledBuffer.CountRetried; return fsSrv.Globals.PooledBuffer.CountRetried;
@ -42,6 +48,24 @@ public static class PooledBufferGlobalMethods
g.CountReduceAllocation = 0; g.CountReduceAllocation = 0;
g.CountFailedIdealAllocationOnAsyncAccess = 0; g.CountFailedIdealAllocationOnAsyncAccess = 0;
} }
public static bool IsAdditionalDeviceAddress(ReadOnlySpan<byte> buffer)
{
return false;
}
// ReSharper disable once UnusedParameter.Global
/// <summary>
/// Checks if the provided buffer is located at a "device address".
/// </summary>
/// <param name="buffer">The buffer to check.</param>
/// <returns><see langword="true"/> if this is a device address; otherwise <see langword="false"/>.</returns>
/// <remarks>A device address is one that is either located in the pooled buffer heap
/// or in any of the registered additional device address ranges.</remarks>
public static bool IsDeviceAddress(ReadOnlySpan<byte> buffer)
{
return IsPooledBuffer(buffer) || IsAdditionalDeviceAddress(buffer);
}
} }
internal struct PooledBufferGlobals internal struct PooledBufferGlobals
@ -168,4 +192,4 @@ public struct PooledBuffer : IDisposable
{ {
Deallocate(); Deallocate();
} }
} }