diff --git a/src/LibHac/Diag/Assert.cs b/src/LibHac/Diag/Assert.cs index 318502da..4c22a26a 100644 --- a/src/LibHac/Diag/Assert.cs +++ b/src/LibHac/Diag/Assert.cs @@ -2,6 +2,7 @@ using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; +using LibHac.Common; using LibHac.Diag.Impl; using LibHac.Os; @@ -395,6 +396,98 @@ public static class Assert NotNullImpl(AssertionType.SdkRequires, value, valueText, functionName, fileName, lineNumber); } + // --------------------------------------------------------------------- + // Not null UniqueRef + // --------------------------------------------------------------------- + + private static void NotNullImpl(AssertionType assertionType, in UniqueRef value, + string valueText, string functionName, string fileName, int lineNumber) where T : class, IDisposable + { + if (AssertImpl.NotNull(in value)) + return; + + AssertImpl.InvokeAssertionNotNull(assertionType, valueText, functionName, fileName, lineNumber); + } + + [Conditional(AssertCondition)] + public static void NotNull(in UniqueRef value, + [CallerArgumentExpression("value")] string valueText = "", + [CallerMemberName] string functionName = "", + [CallerFilePath] string fileName = "", + [CallerLineNumber] int lineNumber = 0) + where T : class, IDisposable + { + NotNullImpl(AssertionType.UserAssert, in value, valueText, functionName, fileName, lineNumber); + } + + [Conditional(AssertCondition)] + internal static void SdkNotNull(in UniqueRef value, + [CallerArgumentExpression("value")] string valueText = "", + [CallerMemberName] string functionName = "", + [CallerFilePath] string fileName = "", + [CallerLineNumber] int lineNumber = 0) + where T : class, IDisposable + { + NotNullImpl(AssertionType.SdkAssert, in value, valueText, functionName, fileName, lineNumber); + } + + [Conditional(AssertCondition)] + internal static void SdkRequiresNotNull(in UniqueRef value, + [CallerArgumentExpression("value")] string valueText = "", + [CallerMemberName] string functionName = "", + [CallerFilePath] string fileName = "", + [CallerLineNumber] int lineNumber = 0) + where T : class, IDisposable + { + NotNullImpl(AssertionType.SdkRequires, in value, valueText, functionName, fileName, lineNumber); + } + + // --------------------------------------------------------------------- + // Not null SharedRef + // --------------------------------------------------------------------- + + private static void NotNullImpl(AssertionType assertionType, in SharedRef value, + string valueText, string functionName, string fileName, int lineNumber) where T : class, IDisposable + { + if (AssertImpl.NotNull(in value)) + return; + + AssertImpl.InvokeAssertionNotNull(assertionType, valueText, functionName, fileName, lineNumber); + } + + [Conditional(AssertCondition)] + public static void NotNull(in SharedRef value, + [CallerArgumentExpression("value")] string valueText = "", + [CallerMemberName] string functionName = "", + [CallerFilePath] string fileName = "", + [CallerLineNumber] int lineNumber = 0) + where T : class, IDisposable + { + NotNullImpl(AssertionType.UserAssert, in value, valueText, functionName, fileName, lineNumber); + } + + [Conditional(AssertCondition)] + internal static void SdkNotNull(in SharedRef value, + [CallerArgumentExpression("value")] string valueText = "", + [CallerMemberName] string functionName = "", + [CallerFilePath] string fileName = "", + [CallerLineNumber] int lineNumber = 0) + where T : class, IDisposable + { + NotNullImpl(AssertionType.SdkAssert, in value, valueText, functionName, fileName, lineNumber); + } + + [Conditional(AssertCondition)] + internal static void SdkRequiresNotNull(in SharedRef value, + [CallerArgumentExpression("value")] string valueText = "", + [CallerMemberName] string functionName = "", + [CallerFilePath] string fileName = "", + [CallerLineNumber] int lineNumber = 0) + where T : class, IDisposable + { + NotNullImpl(AssertionType.SdkRequires, in value, valueText, functionName, fileName, lineNumber); + } + // --------------------------------------------------------------------- // Null // --------------------------------------------------------------------- diff --git a/src/LibHac/Diag/Impl/AssertImpl.cs b/src/LibHac/Diag/Impl/AssertImpl.cs index 63ce1c12..cc16becc 100644 --- a/src/LibHac/Diag/Impl/AssertImpl.cs +++ b/src/LibHac/Diag/Impl/AssertImpl.cs @@ -1,6 +1,7 @@ using System; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; +using LibHac.Common; using LibHac.Util; namespace LibHac.Diag.Impl; @@ -145,6 +146,16 @@ internal static class AssertImpl return !Unsafe.IsNullRef(ref MemoryMarshal.GetReference(span)) || span.Length == 0; } + public static bool NotNull(in UniqueRef item) where T : class, IDisposable + { + return item.HasValue; + } + + public static bool NotNull(in SharedRef item) where T : class, IDisposable + { + return item.HasValue; + } + public static bool WithinRange(int value, int lowerInclusive, int upperExclusive) { return lowerInclusive <= value && value < upperExclusive; diff --git a/src/LibHac/Fs/SubStorage.cs b/src/LibHac/Fs/SubStorage.cs index 115d4a4b..7cd86327 100644 --- a/src/LibHac/Fs/SubStorage.cs +++ b/src/LibHac/Fs/SubStorage.cs @@ -201,6 +201,7 @@ public class SubStorage : IStorage if (rc.IsFailure()) return rc; _size = size; + return Result.Success; } @@ -217,9 +218,13 @@ public class SubStorage : IStorage protected override Result DoOperateRange(Span outBuffer, OperationId operationId, long offset, long size, ReadOnlySpan inBuffer) { if (!IsValid()) return ResultFs.NotInitialized.Log(); - if (size == 0) return Result.Success; - if (!CheckOffsetAndSize(_offset, size)) return ResultFs.OutOfRange.Log(); - return base.DoOperateRange(outBuffer, operationId, _offset + offset, size, inBuffer); + if (operationId != OperationId.InvalidateCache) + { + if (size == 0) return Result.Success; + if (!CheckOffsetAndSize(_offset, size)) return ResultFs.OutOfRange.Log(); + } + + return BaseStorage.OperateRange(outBuffer, operationId, _offset + offset, size, inBuffer); } } diff --git a/src/LibHac/Fs/ValueSubStorage.cs b/src/LibHac/Fs/ValueSubStorage.cs new file mode 100644 index 00000000..1a03624f --- /dev/null +++ b/src/LibHac/Fs/ValueSubStorage.cs @@ -0,0 +1,179 @@ +using System; +using System.Runtime.CompilerServices; +using LibHac.Common; +using LibHac.Diag; + +namespace LibHac.Fs; + +[NonCopyableDisposable] +public struct ValueSubStorage : IDisposable +{ + private IStorage _baseStorage; + private long _offset; + private long _size; + private bool _isResizable; + private SharedRef _sharedBaseStorage; + + public ValueSubStorage() + { + _baseStorage = null; + _offset = 0; + _size = 0; + _isResizable = false; + _sharedBaseStorage = new SharedRef(); + } + + public ValueSubStorage(in ValueSubStorage other) + { + _baseStorage = other._baseStorage; + _offset = other._offset; + _size = other._size; + _isResizable = other._isResizable; + _sharedBaseStorage = SharedRef.CreateCopy(in other._sharedBaseStorage); + } + + public ValueSubStorage(IStorage baseStorage, long offset, long size) + { + _baseStorage = baseStorage; + _offset = offset; + _size = size; + _isResizable = false; + _sharedBaseStorage = new SharedRef(); + + Assert.SdkRequiresNotNull(baseStorage); + Assert.SdkRequiresLessEqual(0, offset); + Assert.SdkRequiresLessEqual(0, size); + } + + public ValueSubStorage(in ValueSubStorage subStorage, long offset, long size) + { + _baseStorage = subStorage._baseStorage; + _offset = subStorage._offset + offset; + _size = size; + _isResizable = false; + _sharedBaseStorage = SharedRef.CreateCopy(in subStorage._sharedBaseStorage); + + Assert.SdkRequiresLessEqual(0, offset); + Assert.SdkRequiresLessEqual(0, size); + Assert.SdkRequires(subStorage.IsValid()); + Assert.SdkRequiresGreaterEqual(subStorage._size, offset + size); + } + + public ValueSubStorage(in SharedRef baseStorage, long offset, long size) + { + _baseStorage = baseStorage.Get; + _offset = offset; + _size = size; + _isResizable = false; + _sharedBaseStorage = SharedRef.CreateCopy(in baseStorage); + + Assert.SdkRequiresNotNull(in baseStorage); + Assert.SdkRequiresLessEqual(0, _offset); + Assert.SdkRequiresLessEqual(0, _size); + } + + public void Dispose() + { + _baseStorage = null; + _sharedBaseStorage.Destroy(); + } + + public void Set(in ValueSubStorage other) + { + if (!Unsafe.AreSame(ref Unsafe.AsRef(in this), ref Unsafe.AsRef(in other))) + { + _baseStorage = other._baseStorage; + _offset = other._offset; + _size = other._size; + _isResizable = other._isResizable; + _sharedBaseStorage.SetByCopy(in other._sharedBaseStorage); + } + } + + private readonly bool IsValid() => _baseStorage is not null; + + public void SetResizable(bool isResizable) + { + _isResizable = isResizable; + } + + public readonly Result Read(long offset, Span destination) + { + if (!IsValid()) return ResultFs.NotInitialized.Log(); + if (destination.Length == 0) return Result.Success; + + if (!IStorage.CheckAccessRange(offset, destination.Length, _size)) + return ResultFs.OutOfRange.Log(); + + return _baseStorage.Read(_offset + offset, destination); + } + + public readonly Result Write(long offset, ReadOnlySpan source) + { + if (!IsValid()) return ResultFs.NotInitialized.Log(); + if (source.Length == 0) return Result.Success; + + if (!IStorage.CheckAccessRange(offset, source.Length, _size)) + return ResultFs.OutOfRange.Log(); + + return _baseStorage.Write(_offset + offset, source); + } + + public readonly Result Flush() + { + if (!IsValid()) return ResultFs.NotInitialized.Log(); + + return _baseStorage.Flush(); + } + + public Result SetSize(long size) + { + if (!IsValid()) return ResultFs.NotInitialized.Log(); + if (!_isResizable) return ResultFs.UnsupportedSetSizeForNotResizableSubStorage.Log(); + if (!IStorage.CheckOffsetAndSize(_offset, size)) return ResultFs.InvalidSize.Log(); + + Result rc = _baseStorage.GetSize(out long currentSize); + if (rc.IsFailure()) return rc; + + if (currentSize != _offset + _size) + { + // SubStorage cannot be resized unless it is located at the end of the base storage. + return ResultFs.UnsupportedSetSizeForResizableSubStorage.Log(); + } + + rc = _baseStorage.SetSize(_offset + size); + if (rc.IsFailure()) return rc; + + _size = size; + + return Result.Success; + } + + public readonly Result GetSize(out long size) + { + UnsafeHelpers.SkipParamInit(out size); + + if (!IsValid()) return ResultFs.NotInitialized.Log(); + + size = _size; + return Result.Success; + } + + public readonly Result OperateRange(OperationId operationId, long offset, long size) + { + return OperateRange(Span.Empty, operationId, offset, size, ReadOnlySpan.Empty); + } + + public readonly Result OperateRange(Span outBuffer, OperationId operationId, long offset, long size, ReadOnlySpan inBuffer) + { + if (!IsValid()) return ResultFs.NotInitialized.Log(); + + if (operationId != OperationId.InvalidateCache) + { + if (size == 0) return Result.Success; + if (!IStorage.CheckOffsetAndSize(_offset, size)) return ResultFs.OutOfRange.Log(); + } + + return _baseStorage.OperateRange(outBuffer, operationId, _offset + offset, size, inBuffer); + } +}