From 3a05e779f944975da49500f8504ee23b6365a17a Mon Sep 17 00:00:00 2001 From: Alex Barney Date: Mon, 14 Mar 2022 13:34:52 -0700 Subject: [PATCH] Add ReadOnlyBlockCacheStorage --- src/LibHac/Diag/Assert.cs | 52 ++++- src/LibHac/Diag/Impl/AssertImpl.cs | 7 +- src/LibHac/FsSystem/BitmapUtils.cs | 1 + src/LibHac/FsSystem/LruListCache.cs | 87 ++++++++ .../FsSystem/ReadOnlyBlockCacheStorage.cs | 141 ++++++++++++ .../ReadOnlyBlockCacheStorageTests.cs | 200 ++++++++++++++++++ 6 files changed, 486 insertions(+), 2 deletions(-) create mode 100644 src/LibHac/FsSystem/LruListCache.cs create mode 100644 src/LibHac/FsSystem/ReadOnlyBlockCacheStorage.cs create mode 100644 tests/LibHac.Tests/FsSystem/ReadOnlyBlockCacheStorageTests.cs diff --git a/src/LibHac/Diag/Assert.cs b/src/LibHac/Diag/Assert.cs index 5b8a9bf0..8f87b05b 100644 --- a/src/LibHac/Diag/Assert.cs +++ b/src/LibHac/Diag/Assert.cs @@ -1272,7 +1272,57 @@ public static class Assert } // --------------------------------------------------------------------- - // Aligned + // Aligned long + // --------------------------------------------------------------------- + + private static void AlignedImpl(AssertionType assertionType, long value, int alignment, string valueText, + string alignmentText, string functionName, string fileName, int lineNumber) + { + if (AssertImpl.IsAligned(value, alignment)) + return; + + AssertImpl.InvokeAssertionAligned(assertionType, value, alignment, valueText, alignmentText, functionName, fileName, + lineNumber); + } + + [Conditional(AssertCondition)] + public static void Aligned(long value, int alignment, + [CallerArgumentExpression("value")] string valueText = "", + [CallerArgumentExpression("alignment")] string alignmentText = "", + [CallerMemberName] string functionName = "", + [CallerFilePath] string fileName = "", + [CallerLineNumber] int lineNumber = 0) + { + AlignedImpl(AssertionType.UserAssert, value, alignment, valueText, alignmentText, functionName, fileName, + lineNumber); + } + + [Conditional(AssertCondition)] + internal static void SdkAligned(long value, int alignment, + [CallerArgumentExpression("value")] string valueText = "", + [CallerArgumentExpression("alignment")] string alignmentText = "", + [CallerMemberName] string functionName = "", + [CallerFilePath] string fileName = "", + [CallerLineNumber] int lineNumber = 0) + { + AlignedImpl(AssertionType.SdkAssert, value, alignment, valueText, alignmentText, functionName, fileName, + lineNumber); + } + + [Conditional(AssertCondition)] + internal static void SdkRequiresAligned(long value, int alignment, + [CallerArgumentExpression("value")] string valueText = "", + [CallerArgumentExpression("alignment")] string alignmentText = "", + [CallerMemberName] string functionName = "", + [CallerFilePath] string fileName = "", + [CallerLineNumber] int lineNumber = 0) + { + AlignedImpl(AssertionType.SdkRequires, value, alignment, valueText, alignmentText, functionName, fileName, + lineNumber); + } + + // --------------------------------------------------------------------- + // Aligned ulong // --------------------------------------------------------------------- private static void AlignedImpl(AssertionType assertionType, ulong value, int alignment, string valueText, diff --git a/src/LibHac/Diag/Impl/AssertImpl.cs b/src/LibHac/Diag/Impl/AssertImpl.cs index cc2e262e..e27cc9bd 100644 --- a/src/LibHac/Diag/Impl/AssertImpl.cs +++ b/src/LibHac/Diag/Impl/AssertImpl.cs @@ -106,7 +106,7 @@ internal static class AssertImpl Assert.OnAssertionFailure(assertionType, "GreaterEqual", functionName, fileName, lineNumber, message); } - internal static void InvokeAssertionAligned(AssertionType assertionType, ulong value, int alignment, + internal static void InvokeAssertionAligned(AssertionType assertionType, T value, int alignment, string valueText, string alignmentText, string functionName, string fileName, int lineNumber) { string message = @@ -246,6 +246,11 @@ internal static class AssertImpl return lhs.CompareTo(rhs) >= 0; } + public static bool IsAligned(long value, int alignment) + { + return Alignment.IsAlignedPow2(value, (uint)alignment); + } + public static bool IsAligned(ulong value, int alignment) { return Alignment.IsAlignedPow2(value, (uint)alignment); diff --git a/src/LibHac/FsSystem/BitmapUtils.cs b/src/LibHac/FsSystem/BitmapUtils.cs index c425ec86..0cdb4bfc 100644 --- a/src/LibHac/FsSystem/BitmapUtils.cs +++ b/src/LibHac/FsSystem/BitmapUtils.cs @@ -4,6 +4,7 @@ namespace LibHac.FsSystem; public static class BitmapUtils { + // ReSharper disable once InconsistentNaming public static uint ILog2(uint value) { Assert.SdkRequiresGreater(value, 0u); diff --git a/src/LibHac/FsSystem/LruListCache.cs b/src/LibHac/FsSystem/LruListCache.cs new file mode 100644 index 00000000..9ffb20c0 --- /dev/null +++ b/src/LibHac/FsSystem/LruListCache.cs @@ -0,0 +1,87 @@ +using System; +using System.Collections.Generic; +using LibHac.Diag; + +namespace LibHac.FsSystem; + +/// +/// Represents a list of key/value pairs that are ordered by when they were last accessed. +/// +/// The type of the keys in the list. +/// The type of the values in the list. +/// Based on FS 13.1.0 (nnSdk 13.4.0) +public class LruListCache where TKey : IEquatable +{ + public struct Node + { + public TKey Key; + public TValue Value; + + public Node(TValue value) + { + Key = default; + Value = value; + } + } + + private LinkedList _list; + + public LruListCache() + { + _list = new LinkedList(); + } + + public bool FindValueAndUpdateMru(out TValue value, TKey key) + { + LinkedListNode currentNode = _list.First; + + while (currentNode is not null) + { + if (currentNode.ValueRef.Key.Equals(key)) + { + value = currentNode.ValueRef.Value; + + _list.Remove(currentNode); + _list.AddFirst(currentNode); + + return true; + } + + currentNode = currentNode.Next; + } + + value = default; + return false; + } + + public LinkedListNode PopLruNode() + { + Abort.DoAbortUnless(_list.Count != 0); + + LinkedListNode lru = _list.Last; + _list.RemoveLast(); + + return lru; + } + + public void PushMruNode(LinkedListNode node, TKey key) + { + node.ValueRef.Key = key; + _list.AddFirst(node); + } + + public void DeleteAllNodes() + { + _list.Clear(); + } + + public int GetSize() + { + return _list.Count; + } + + public bool IsEmpty() + { + return _list.Count == 0; + } +} \ No newline at end of file diff --git a/src/LibHac/FsSystem/ReadOnlyBlockCacheStorage.cs b/src/LibHac/FsSystem/ReadOnlyBlockCacheStorage.cs new file mode 100644 index 00000000..64f10ecf --- /dev/null +++ b/src/LibHac/FsSystem/ReadOnlyBlockCacheStorage.cs @@ -0,0 +1,141 @@ +using System; +using System.Collections.Generic; +using LibHac.Common; +using LibHac.Diag; +using LibHac.Fs; +using LibHac.Os; +using LibHac.Util; + +using BlockCache = LibHac.FsSystem.LruListCache>; + +namespace LibHac.FsSystem; + +/// +/// Caches reads to a base using a least-recently-used cache of data blocks. +/// The offset and size read from the storage must be aligned to multiples of the block size. +/// Only reads that access a single block will use the cache. Reads that access multiple blocks will +/// be passed down to the base to be handled without caching. +/// +/// Based on FS 13.1.0 (nnSdk 13.4.0) +public class ReadOnlyBlockCacheStorage : IStorage +{ + private SdkMutexType _mutex; + private BlockCache _blockCache; + private SharedRef _baseStorage; + private int _blockSize; + + public ReadOnlyBlockCacheStorage(ref SharedRef baseStorage, int blockSize, Memory buffer, + int cacheBlockCount) + { + _baseStorage = SharedRef.CreateMove(ref baseStorage); + _blockSize = blockSize; + _blockCache = new BlockCache(); + _mutex = new SdkMutexType(); + + Assert.SdkRequiresGreaterEqual(buffer.Length, _blockSize); + Assert.SdkRequires(BitUtil.IsPowerOfTwo(blockSize), $"{nameof(blockSize)} must be power of 2."); + Assert.SdkRequiresGreater(cacheBlockCount, 0); + Assert.SdkRequiresGreaterEqual(buffer.Length, blockSize * cacheBlockCount); + + for (int i = 0; i < cacheBlockCount; i++) + { + Memory nodeBuffer = buffer.Slice(i * blockSize, blockSize); + var node = new LinkedListNode(new BlockCache.Node(nodeBuffer)); + Assert.SdkNotNull(node); + + _blockCache.PushMruNode(node, -1); + } + } + + public override void Dispose() + { + _blockCache.DeleteAllNodes(); + _baseStorage.Destroy(); + + base.Dispose(); + } + + public override Result Read(long offset, Span destination) + { + Assert.SdkRequiresAligned(offset, _blockSize); + Assert.SdkRequiresAligned(destination.Length, _blockSize); + + if (destination.Length == _blockSize) + { + // Search the cache for the requested block. + using (new ScopedLock(ref _mutex)) + { + bool found = _blockCache.FindValueAndUpdateMru(out Memory cachedBuffer, offset / _blockSize); + if (found) + { + cachedBuffer.Span.CopyTo(destination); + return Result.Success; + } + } + + // The block wasn't in the cache. Read from the base storage. + Result rc = _baseStorage.Get.Read(offset, destination); + if (rc.IsFailure()) return rc.Miss(); + + // Add the block to the cache. + using (new ScopedLock(ref _mutex)) + { + LinkedListNode lru = _blockCache.PopLruNode(); + destination.CopyTo(lru.ValueRef.Value.Span); + _blockCache.PushMruNode(lru, offset / _blockSize); + } + + return Result.Success; + } + else + { + return _baseStorage.Get.Read(offset, destination); + } + } + + public override Result Write(long offset, ReadOnlySpan source) + { + // Missing: Log output + return ResultFs.UnsupportedWriteForReadOnlyBlockCacheStorage.Log(); + } + + public override Result Flush() + { + return Result.Success; + } + + public override Result SetSize(long size) + { + return ResultFs.UnsupportedSetSizeForReadOnlyBlockCacheStorage.Log(); + } + + public override Result GetSize(out long size) + { + return _baseStorage.Get.GetSize(out size); + } + + public override Result OperateRange(Span outBuffer, OperationId operationId, long offset, long size, + ReadOnlySpan inBuffer) + { + if (operationId == OperationId.InvalidateCache) + { + // Invalidate all the blocks in our cache. + using var scopedLock = new ScopedLock(ref _mutex); + + int cacheBlockCount = _blockCache.GetSize(); + for (int i = 0; i < cacheBlockCount; i++) + { + LinkedListNode lru = _blockCache.PopLruNode(); + _blockCache.PushMruNode(lru, -1); + } + } + else + { + Assert.SdkRequiresAligned(offset, _blockSize); + Assert.SdkRequiresAligned(size, _blockSize); + } + + // Pass the request to the base storage. + return _baseStorage.Get.OperateRange(outBuffer, operationId, offset, size, inBuffer); + } +} \ No newline at end of file diff --git a/tests/LibHac.Tests/FsSystem/ReadOnlyBlockCacheStorageTests.cs b/tests/LibHac.Tests/FsSystem/ReadOnlyBlockCacheStorageTests.cs new file mode 100644 index 00000000..15ec81d3 --- /dev/null +++ b/tests/LibHac.Tests/FsSystem/ReadOnlyBlockCacheStorageTests.cs @@ -0,0 +1,200 @@ +using System; +using System.Buffers.Binary; +using LibHac.Common; +using LibHac.Fs; +using LibHac.FsSystem; +using Xunit; + +namespace LibHac.Tests.FsSystem; + +public class ReadOnlyBlockCacheStorageTests +{ + private class TestContext + { + private int _blockSize; + private int _cacheBlockCount; + + public byte[] BaseData; + public byte[] ModifiedBaseData; + public byte[] CacheBuffer; + + public ReadOnlyBlockCacheStorage CacheStorage; + + public TestContext(int blockSize, int cacheBlockCount, int storageBlockCount, ulong rngSeed) + { + _blockSize = blockSize; + _cacheBlockCount = cacheBlockCount; + + BaseData = new byte[_blockSize * storageBlockCount]; + ModifiedBaseData = new byte[_blockSize * storageBlockCount]; + CacheBuffer = new byte[_blockSize * _cacheBlockCount]; + + new Random(rngSeed).NextBytes(BaseData); + BaseData.AsSpan().CopyTo(ModifiedBaseData); + + for (int i = 0; i < storageBlockCount; i++) + { + ModifyBlock(GetModifiedBaseDataBlock(i)); + } + + using var baseStorage = new SharedRef(new MemoryStorage(BaseData)); + CacheStorage = new ReadOnlyBlockCacheStorage(ref baseStorage.Ref(), _blockSize, CacheBuffer, _cacheBlockCount); + } + + public Span GetBaseDataBlock(int index) => BaseData.AsSpan(_blockSize * index, _blockSize); + public Span GetModifiedBaseDataBlock(int index) => ModifiedBaseData.AsSpan(_blockSize * index, _blockSize); + public Span GetCacheDataBlock(int index) => CacheBuffer.AsSpan(_blockSize * index, _blockSize); + + private void ModifyBlock(Span block) + { + BinaryPrimitives.WriteUInt64LittleEndian(block, ulong.MaxValue); + } + + public void ModifyAllCacheBlocks() + { + for (int i = 0; i < _cacheBlockCount; i++) + { + ModifyBlock(GetCacheDataBlock(i)); + } + } + + public Span ReadCachedStorage(int blockIndex) + { + byte[] buffer = new byte[_blockSize]; + Assert.Success(CacheStorage.Read(_blockSize * blockIndex, buffer)); + return buffer; + } + + public Span ReadCachedStorage(long offset, int size) + { + byte[] buffer = new byte[size]; + Assert.Success(CacheStorage.Read(offset, buffer)); + return buffer; + } + + public void InvalidateCache() + { + Assert.Success(CacheStorage.OperateRange(OperationId.InvalidateCache, 0, long.MaxValue)); + } + } + + private const int BlockSize = 0x4000; + private const int CacheBlockCount = 4; + private const int StorageBlockCount = 16; + + [Fact] + public void Read_CompleteBlocks_ReadsCorrectData() + { + var context = new TestContext(BlockSize, CacheBlockCount, StorageBlockCount, 21341); + + for (int i = 0; i < StorageBlockCount; i++) + { + Assert.True(context.GetBaseDataBlock(i).SequenceEqual(context.ReadCachedStorage(i))); + Assert.True(context.GetBaseDataBlock(i).SequenceEqual(context.ReadCachedStorage(i))); + } + } + + [Fact] + public void Read_PreviouslyCachedBlock_ReturnsDataFromCache() + { + const int index = 4; + var context = new TestContext(BlockSize, CacheBlockCount, StorageBlockCount, 21341); + + // Cache the block + context.ReadCachedStorage(index); + + // Directly modify the cache buffer + context.ModifyAllCacheBlocks(); + + // Next read should return the modified data from the cache buffer + Assert.True(context.GetModifiedBaseDataBlock(index).SequenceEqual(context.ReadCachedStorage(index))); + } + + [Fact] + public void Read_BlockEvictedFromCache_ReturnsDataFromBaseStorage() + { + const int index = 4; + var context = new TestContext(BlockSize, CacheBlockCount, StorageBlockCount, 21341); + + context.ReadCachedStorage(index); + context.ModifyAllCacheBlocks(); + + // Read enough additional blocks to push the initial block out of the cache + context.ReadCachedStorage(6); + context.ReadCachedStorage(7); + context.ReadCachedStorage(8); + context.ReadCachedStorage(9); + + // Reading the initial block should now return the original data + Assert.True(context.GetBaseDataBlock(index).SequenceEqual(context.ReadCachedStorage(index))); + } + + [Fact] + public void Read_ReadMultipleBlocks_BlocksAreEvictedAtTheRightTime() + { + const int index = 4; + var context = new TestContext(BlockSize, CacheBlockCount, StorageBlockCount, 21341); + + context.ReadCachedStorage(index); + context.ModifyAllCacheBlocks(); + + context.ReadCachedStorage(6); + context.ReadCachedStorage(7); + context.ReadCachedStorage(8); + + // Reading the initial block should return the cached data + Assert.True(context.GetModifiedBaseDataBlock(index).SequenceEqual(context.ReadCachedStorage(index))); + + for (int i = 0; i < 3; i++) + context.ReadCachedStorage(9 + i); + + context.ModifyAllCacheBlocks(); + + // The initial block should have been moved to the top of the cache when it was last accessed + Assert.True(context.GetModifiedBaseDataBlock(index).SequenceEqual(context.ReadCachedStorage(index))); + + // Access all the other blocks in the cache so the initial block is the least recently accessed + for (int i = 0; i < 3; i++) + Assert.True(context.GetModifiedBaseDataBlock(9 + i).SequenceEqual(context.ReadCachedStorage(9 + i))); + + // Add a new block to the cache + Assert.True(context.GetBaseDataBlock(2).SequenceEqual(context.ReadCachedStorage(2))); + + // The initial block should have been removed from the cache + Assert.True(context.GetBaseDataBlock(index).SequenceEqual(context.ReadCachedStorage(index))); + } + + [Fact] + public void Read_UnalignedBlock_ReturnsOriginalData() + { + const int index = 4; + var context = new TestContext(BlockSize, CacheBlockCount, StorageBlockCount, 21341); + + context.ReadCachedStorage(index); + context.ModifyAllCacheBlocks(); + + // Read two blocks at once + int offset = index * BlockSize; + int size = BlockSize * 2; + + // The cache should be bypassed, returning the original data + Assert.True(context.BaseData.AsSpan(offset, size).SequenceEqual(context.ReadCachedStorage(offset, size))); + } + + [Fact] + public void OperateRange_InvalidateCache_PreviouslyCachedBlockReturnsDataFromBaseStorage() + { + const int index = 4; + var context = new TestContext(BlockSize, CacheBlockCount, StorageBlockCount, 21341); + + context.ReadCachedStorage(index); + context.ModifyAllCacheBlocks(); + + // Next read should return the modified data from the cache buffer + Assert.True(context.GetModifiedBaseDataBlock(index).SequenceEqual(context.ReadCachedStorage(index))); + + // Reading after invalidating the cache should return the original data + context.InvalidateCache(); + Assert.True(context.GetBaseDataBlock(index).SequenceEqual(context.ReadCachedStorage(index))); + } +} \ No newline at end of file