diff --git a/src/LibHac/Fs/Fsa/MountUtility.cs b/src/LibHac/Fs/Fsa/MountUtility.cs index 699b8914..65a0cdb5 100644 --- a/src/LibHac/Fs/Fsa/MountUtility.cs +++ b/src/LibHac/Fs/Fsa/MountUtility.cs @@ -1,16 +1,35 @@ using System; +using System.Runtime.CompilerServices; using LibHac.Common; using LibHac.Diag; using LibHac.Fs.Impl; +using LibHac.Fs.Shim; using LibHac.Os; using LibHac.Util; -using static LibHac.Fs.StringTraits; using static LibHac.Fs.Impl.AccessLogStrings; +using static LibHac.Fs.StringTraits; namespace LibHac.Fs.Fsa; +/// +/// Contains functions for managing mounted file systems. +/// +/// Based on FS 13.1.0 (nnSdk 13.4.0) public static class MountUtility { + /// + /// Gets the mount name and non-mounted path components from a path that has a mount name. + /// + /// If the method returns successfully, contains the mount name of the provided path; + /// otherwise the contents are undefined. + /// If the method returns successfully, contains the provided path without the + /// mount name; otherwise the contents are undefined. + /// The to process. + /// : The operation was successful.
+ /// : does not contain a sub path after + /// the mount name that begins with / or \.
+ /// : contains an invalid mount name + /// or does not have a mount name.
private static Result GetMountNameAndSubPath(out MountName mountName, out U8Span subPath, U8Span path) { UnsafeHelpers.SkipParamInit(out mountName); @@ -82,8 +101,7 @@ public static class MountUtility return false; } - // Todo: VerifyUtf8String - return true; + return Utf8StringUtil.VerifyUtf8String(name); } public static bool IsUsedReservedMountName(this FileSystemClientImpl fs, U8Span name) @@ -144,7 +162,12 @@ public static class MountUtility if (fileSystem.IsFileDataCacheAttachable()) { - // Todo: Data cache purge + using var fileDataCacheAccessor = new GlobalFileDataCacheAccessorReadableScopedPointer(); + + if (fs.TryGetGlobalFileDataCacheAccessor(ref Unsafe.AsRef(in fileDataCacheAccessor))) + { + fileSystem.PurgeFileDataCache(fileDataCacheAccessor.Get()); + } } fs.Unregister(mountName); @@ -226,13 +249,23 @@ public static class MountUtility public static Result ConvertToFsCommonPath(this FileSystemClient fs, U8SpanMutable commonPathBuffer, U8Span path) { + Result rc; + if (commonPathBuffer.IsNull()) - return ResultFs.NullptrArgument.Log(); + { + rc = ResultFs.NullptrArgument.Value; + fs.Impl.AbortIfNeeded(rc); + return rc; + } if (path.IsNull()) - return ResultFs.NullptrArgument.Log(); + { + rc = ResultFs.NullptrArgument.Value; + fs.Impl.AbortIfNeeded(rc); + return rc; + } - Result rc = GetMountNameAndSubPath(out MountName mountName, out U8Span subPath, path); + rc = GetMountNameAndSubPath(out MountName mountName, out U8Span subPath, path); fs.Impl.AbortIfNeeded(rc); if (rc.IsFailure()) return rc; diff --git a/src/LibHac/Util/Utf8StringUtil.cs b/src/LibHac/Util/Utf8StringUtil.cs new file mode 100644 index 00000000..c4fe798a --- /dev/null +++ b/src/LibHac/Util/Utf8StringUtil.cs @@ -0,0 +1,160 @@ +using System; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using LibHac.Common; +using LibHac.Diag; + +namespace LibHac.Util; + +/// +/// Contains functions for verifying and copying UTF-8 strings. +/// +/// Based on nnSdk 13.4.0 +public static class Utf8StringUtil +{ + private static ReadOnlySpan CodePointByteLengthTable => new byte[] + { + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + }; + + public static bool VerifyUtf8String(U8Span str) + { + return GetCodePointCountOfUtf8String(str) != -1; + } + + public static int GetCodePointCountOfUtf8String(U8Span str) + { + Assert.SdkRequiresGreater(str.Length, 0); + + ReadOnlySpan currentStr = str.Value; + int codePointCount = 0; + + while (currentStr.Length != 0) + { + int codePointByteLength = GetCodePointByteLength(currentStr[0]); + + if (codePointByteLength > currentStr.Length) + return -1; + + if (!VerifyCode(currentStr.Slice(0, codePointByteLength))) + return -1; + + currentStr = currentStr.Slice(codePointByteLength); + + codePointCount++; + } + + return codePointCount; + } + + public static int CopyUtf8String(Span output, ReadOnlySpan input, int maxCount) + { + Assert.SdkRequiresGreater(output.Length, 0); + Assert.SdkRequiresGreater(input.Length, 0); + Assert.SdkRequiresGreater(maxCount, 0); + + ReadOnlySpan currentInput = input; + int remainingCount = maxCount; + + while (remainingCount > 0 && currentInput.Length != 0) + { + // Verify the current code point + int codePointLength = GetCodePointByteLength(currentInput[0]); + if (codePointLength > currentInput.Length) + break; + + if (!VerifyCode(currentInput.Slice(0, codePointLength))) + break; + + // Ensure the output is large enough to hold the additional code point + int currentOutputLength = + Unsafe.ByteOffset(ref MemoryMarshal.GetReference(input), ref MemoryMarshal.GetReference(currentInput)) + .ToInt32() + codePointLength; + + if (currentOutputLength + 1 > output.Length) + break; + + // Advance to the next code point + currentInput = currentInput.Slice(codePointLength); + remainingCount--; + } + + // Copy the valid UTF-8 to the output buffer + int byteLength = Unsafe + .ByteOffset(ref MemoryMarshal.GetReference(input), ref MemoryMarshal.GetReference(currentInput)).ToInt32(); + + Assert.SdkAssert(byteLength + 1 <= output.Length); + + if (byteLength != 0) + input.Slice(0, byteLength).CopyTo(output); + + output[byteLength] = 0; + return byteLength; + } + + private static int GetCodePointByteLength(byte head) + { + return CodePointByteLengthTable[head]; + } + + private static bool IsValidTail(byte tail) + { + return (tail & 0xC0) == 0x80; + } + + private static bool VerifyCode(ReadOnlySpan str) + { + if (str.Length == 1) + return true; + + switch (str.Length) + { + case 2: + if (!IsValidTail(str[1])) + return false; + + break; + case 3: + if (str[0] == 0xE0 && (str[1] & 0x20) == 0) + return false; + + if (str[0] == 0xED && (str[1] & 0x20) != 0) + return false; + + if (!IsValidTail(str[1]) || !IsValidTail(str[2])) + return false; + + break; + case 4: + if (str[0] == 0xF0 && (str[1] & 0x30) == 0) + return false; + + if (str[0] == 0xFD && (str[1] & 0x30) != 0) + return false; + + if (!IsValidTail(str[1]) || !IsValidTail(str[2]) || !IsValidTail(str[3])) + return false; + + break; + default: + return false; + } + + return true; + } +} \ No newline at end of file