Add multiple wait APIs

This adds the base work that will be needed for waiting on multiple objects. Some tweaks will still probably be necessary to make it work nicely with the .NET APIs.
This commit is contained in:
Alex Barney 2022-04-11 14:22:00 -07:00
parent 4e5e9a4627
commit e46c1f0231
12 changed files with 838 additions and 4 deletions

View file

@ -0,0 +1,8 @@
namespace LibHac.Os;
public enum TriBool
{
False = 0,
True = 1,
Undefined = 2
}

View file

@ -1,8 +1,9 @@
using System.Threading;
using System;
using System.Threading;
namespace LibHac.Os.Impl;
internal struct InternalCriticalSectionImpl
internal struct InternalCriticalSectionImpl : IDisposable
{
private object _obj;
@ -11,6 +12,8 @@ internal struct InternalCriticalSectionImpl
_obj = new object();
}
public void Dispose() { }
public void Initialize()
{
_obj = new object();

View file

@ -1,6 +1,8 @@
namespace LibHac.Os.Impl;
using System;
public struct InternalCriticalSection : ILockable
namespace LibHac.Os.Impl;
public struct InternalCriticalSection : ILockable, IDisposable
{
private InternalCriticalSectionImpl _impl;
@ -9,6 +11,11 @@ public struct InternalCriticalSection : ILockable
_impl = new InternalCriticalSectionImpl();
}
public void Dispose()
{
_impl.Dispose();
}
public void Initialize() => _impl.Initialize();
public void FinalizeObject() => _impl.FinalizeObject();

View file

@ -0,0 +1,14 @@
namespace LibHac.Os.Impl;
public class MultiWaitHolderImpl
{
private MultiWaitHolderBase _holder;
public MultiWaitHolderBase HolderBase => _holder;
public MultiWaitHolderOfNativeHandle HolderOfNativeHandle => (MultiWaitHolderOfNativeHandle)_holder;
public MultiWaitHolderImpl(MultiWaitHolderOfNativeHandle holder)
{
_holder = holder;
}
}

View file

@ -0,0 +1,59 @@
namespace LibHac.Os.Impl;
public abstract class MultiWaitHolderBase
{
private MultiWaitImpl _multiWait;
// LibHac addition because we can't reinterpret_cast a MultiWaitHolderBase
// to a MultiWaitHolderType like the original does in c++
public MultiWaitHolderType Holder { get; protected set; }
public abstract TriBool IsSignaled();
public abstract TriBool AddToObjectList();
public abstract void RemoveFromObjectList();
public abstract bool GetNativeHandle(out OsNativeHandle handle);
public virtual TimeSpan GetAbsoluteTimeToWakeup()
{
return TimeSpan.FromNanoSeconds(long.MaxValue);
}
public void SetMultiWait(MultiWaitImpl multiWait)
{
_multiWait = multiWait;
}
public MultiWaitImpl GetMultiWait()
{
return _multiWait;
}
public bool IsLinked()
{
return _multiWait is not null;
}
public bool IsNotLinked()
{
return _multiWait is null;
}
}
public abstract class MultiWaitHolderOfUserWaitObject : MultiWaitHolderBase
{
public override bool GetNativeHandle(out OsNativeHandle handle)
{
handle = default;
return false;
}
}
public abstract class MultiWaitHolderOfNativeWaitObject : MultiWaitHolderBase
{
public override TriBool AddToObjectList()
{
return TriBool.Undefined;
}
public override void RemoveFromObjectList() { /* ... */ }
}

View file

@ -0,0 +1,22 @@
namespace LibHac.Os.Impl;
public class MultiWaitHolderOfNativeHandle : MultiWaitHolderOfNativeWaitObject
{
private OsNativeHandle _handle;
internal MultiWaitHolderOfNativeHandle(OsNativeHandle handle)
{
_handle = handle;
}
public override TriBool IsSignaled()
{
return TriBool.Undefined;
}
public override bool GetNativeHandle(out OsNativeHandle handle)
{
handle = _handle;
return false;
}
}

View file

@ -0,0 +1,349 @@
using System;
using System.Collections.Generic;
using LibHac.Diag;
namespace LibHac.Os.Impl;
public class MultiWaitImpl : IDisposable
{
public const int MaximumHandleCount = 64;
public const int WaitInvalid = -3;
public const int WaitCancelled = -2;
public const int WaitTimedOut = -1;
private LinkedList<MultiWaitHolderBase> _multiWaitList;
private MultiWaitHolderBase _signaledHolder;
private TimeSpan _currentTime;
private InternalCriticalSection _csWait;
private MultiWaitTargetImpl _targetImpl;
// LibHac additions
private OsState _os;
private MultiWaitType _parent;
public MultiWaitType GetMultiWaitType() => _parent;
public MultiWaitImpl(OsState os, MultiWaitType parent)
{
_multiWaitList = new LinkedList<MultiWaitHolderBase>();
_currentTime = new TimeSpan(0);
_csWait = new InternalCriticalSection();
_targetImpl = new MultiWaitTargetImpl(os);
_os = os;
_parent = parent;
}
public void Dispose()
{
_csWait.Dispose();
_targetImpl.Dispose();
}
public MultiWaitHolderBase WaitAny()
{
return WaitAnyImpl(infinite: true, TimeSpan.FromNanoSeconds(long.MaxValue));
}
public MultiWaitHolderBase TryWaitAny()
{
return WaitAnyImpl(infinite: false, new TimeSpan(0));
}
public MultiWaitHolderBase TimedWaitAny(TimeSpan timeout)
{
return WaitAnyImpl(infinite: false, timeout);
}
public Result ReplyAndReceive(out MultiWaitHolderBase outHolder, OsNativeHandle replyTarget)
{
return WaitAnyImpl(out outHolder, infinite: true, TimeSpan.FromNanoSeconds(long.MaxValue), reply: true,
replyTarget);
}
public bool IsListEmpty()
{
return _multiWaitList.Count == 0;
}
public bool IsListNotEmpty()
{
return _multiWaitList.Count != 0;
}
public void PushBackToList(MultiWaitHolderBase holder)
{
_multiWaitList.AddLast(holder);
}
public void EraseFromList(MultiWaitHolderBase holder)
{
bool wasInList = _multiWaitList.Remove(holder);
Assert.SdkAssert(wasInList);
}
public void EraseAllFromList()
{
_multiWaitList.Clear();
}
public void MoveAllFromOther(MultiWaitImpl other)
{
// Set ourselves as multi wait for all of the other's holders.
foreach (MultiWaitHolderBase holder in other._multiWaitList)
{
holder.SetMultiWait(this);
}
LinkedListNode<MultiWaitHolderBase> node = other._multiWaitList.First;
while (node is not null)
{
other._multiWaitList.Remove(node);
_multiWaitList.AddLast(node);
node = other._multiWaitList.First;
}
}
public TimeSpan GetCurrentTime()
{
return _currentTime;
}
public void NotifyAndWakeupThread(MultiWaitHolderBase holder)
{
using ScopedLock<InternalCriticalSection> lk = ScopedLock.Lock(ref _csWait);
if (_signaledHolder is null)
{
_signaledHolder = holder;
_targetImpl.CancelWait();
}
}
private MultiWaitHolderBase WaitAnyImpl(bool infinite, TimeSpan timeout)
{
Result waitResult = WaitAnyImpl(out MultiWaitHolderBase holder, infinite, timeout, false, OsTypes.InvalidNativeHandle);
Assert.SdkAssert(waitResult.IsSuccess());
return holder;
}
private Result WaitAnyImpl(out MultiWaitHolderBase outHolder, bool infinite, TimeSpan timeout, bool reply,
OsNativeHandle replyTarget)
{
// Prepare for processing.
_signaledHolder = null;
_targetImpl.SetCurrentThreadHandleForCancelWait();
MultiWaitHolderBase holder = AddToEachObjectListAndCheckObjectState();
// Check if we've been signaled.
using (ScopedLock.Lock(ref _csWait))
{
if (_signaledHolder is not null)
holder = _signaledHolder;
}
// Process object array.
Result waitResult = Result.Success;
if (holder is null)
{
waitResult = InternalWaitAnyImpl(out holder, infinite, timeout, reply, replyTarget);
}
else if (reply && replyTarget != OsTypes.InvalidNativeHandle)
{
waitResult = _targetImpl.TimedReplyAndReceive(out int _, null, num: 0, replyTarget,
new TimeSpan(0));
if (waitResult.IsFailure())
holder = null;
}
// Unlink holders from the current object list.
RemoveFromEachObjectList();
_targetImpl.ClearCurrentThreadHandleForCancelWait();
outHolder = holder;
return waitResult;
}
private Result InternalWaitAnyImpl(out MultiWaitHolderBase outHolder, bool infinite, TimeSpan timeout, bool reply,
OsNativeHandle replyTarget)
{
var objectsArray = new OsNativeHandle[MaximumHandleCount];
var objectsArrayToHolder = new MultiWaitHolderBase[MaximumHandleCount];
int objectCount = ConstructObjectsArray(objectsArray, objectsArrayToHolder, MaximumHandleCount);
TimeSpan absoluteEndTime = infinite
? TimeSpan.FromNanoSeconds(long.MaxValue)
: _os.GetCurrentTick().ToTimeSpan(_os) + timeout;
while (true)
{
_currentTime = _os.GetCurrentTick().ToTimeSpan(_os);
MultiWaitHolderBase minTimeoutObject = RecalcMultiWaitTimeout(out TimeSpan timeoutMin, absoluteEndTime);
int index;
Result waitResult = Result.Success;
if (reply)
{
if (infinite && minTimeoutObject is null)
{
waitResult = _targetImpl.ReplyAndReceive(out index, objectsArray, objectCount, replyTarget);
}
else
{
waitResult = _targetImpl.TimedReplyAndReceive(out index, objectsArray, objectCount, replyTarget, timeoutMin);
}
}
else
{
if (infinite && minTimeoutObject is null)
{
waitResult = _targetImpl.WaitAny(out index, objectsArray, objectCount);
}
else
{
if (objectCount == 0 && timeoutMin == new TimeSpan(0))
{
index = WaitTimedOut;
}
else
{
waitResult = _targetImpl.TimedWaitAny(out index, objectsArray, objectCount, timeoutMin);
}
}
}
switch (index)
{
case WaitTimedOut:
if (minTimeoutObject is not null)
{
_currentTime = _os.GetCurrentTick().ToTimeSpan(_os);
if (minTimeoutObject.IsSignaled() == TriBool.True)
{
using ScopedLock<InternalCriticalSection> lk = ScopedLock.Lock(ref _csWait);
_signaledHolder = minTimeoutObject;
outHolder = minTimeoutObject;
return waitResult;
}
}
else
{
outHolder = null;
return waitResult;
}
break;
case WaitCancelled:
{
using ScopedLock<InternalCriticalSection> lk = ScopedLock.Lock(ref _csWait);
if (_signaledHolder is not null)
{
outHolder = _signaledHolder;
return waitResult;
}
break;
}
case WaitInvalid:
outHolder = null;
return waitResult;
default:
{
Assert.SdkAssert(index >= 0 && index < objectCount);
using ScopedLock<InternalCriticalSection> lk = ScopedLock.Lock(ref _csWait);
_signaledHolder = objectsArrayToHolder[index];
outHolder = _signaledHolder;
return waitResult;
}
}
replyTarget = OsTypes.InvalidNativeHandle;
}
}
public int ConstructObjectsArray(Span<OsNativeHandle> outHandles, Span<MultiWaitHolderBase> outObjects, int num)
{
Assert.SdkRequiresGreaterEqual(outHandles.Length, num);
Assert.SdkRequiresGreaterEqual(outObjects.Length, num);
int count = 0;
foreach (MultiWaitHolderBase holderBase in _multiWaitList)
{
if (holderBase.GetNativeHandle(out OsNativeHandle handle))
{
Abort.DoAbortUnless(count < num);
outHandles[count] = handle;
outObjects[count] = holderBase;
count++;
}
}
return count;
}
private MultiWaitHolderBase AddToEachObjectListAndCheckObjectState()
{
MultiWaitHolderBase signaledHolder = null;
foreach (MultiWaitHolderBase holderBase in _multiWaitList)
{
TriBool isSignaled = holderBase.AddToObjectList();
if (signaledHolder is null && isSignaled == TriBool.True)
{
signaledHolder = holderBase;
}
}
return signaledHolder;
}
private void RemoveFromEachObjectList()
{
foreach (MultiWaitHolderBase holderBase in _multiWaitList)
{
holderBase.RemoveFromObjectList();
}
}
public MultiWaitHolderBase RecalcMultiWaitTimeout(out TimeSpan outMinTimeout, TimeSpan endTime)
{
MultiWaitHolderBase minTimeoutHolder = null;
TimeSpan endTimeMin = endTime;
foreach (MultiWaitHolderBase holderBase in _multiWaitList)
{
TimeSpan wakeupTime = holderBase.GetAbsoluteTimeToWakeup();
if (wakeupTime < endTimeMin)
{
endTimeMin = wakeupTime;
minTimeoutHolder = holderBase;
}
}
if (endTimeMin < _currentTime)
{
outMinTimeout = new TimeSpan(0);
}
else
{
outMinTimeout = endTimeMin - _currentTime;
}
return minTimeoutHolder;
}
}

View file

@ -0,0 +1,47 @@
using System.Collections.Generic;
using LibHac.Diag;
namespace LibHac.Os.Impl;
public class MultiWaitObjectList
{
private LinkedList<MultiWaitHolderBase> _objectList;
public MultiWaitObjectList()
{
_objectList = new LinkedList<MultiWaitHolderBase>();
}
public void WakeupAllMultiWaitThreadsUnsafe()
{
foreach (MultiWaitHolderBase holderBase in _objectList)
{
holderBase.GetMultiWait().NotifyAndWakeupThread(holderBase);
}
}
public void BroadcastToUpdateObjectStateUnsafe()
{
foreach (MultiWaitHolderBase holderBase in _objectList)
{
holderBase.GetMultiWait().NotifyAndWakeupThread(null);
}
}
public bool IsEmpty()
{
return _objectList.Count == 0;
}
public void PushBackToList(MultiWaitHolderBase holderBase)
{
_objectList.AddLast(holderBase);
}
public void EraseFromList(MultiWaitHolderBase holderBase)
{
Assert.SdkRequires(_objectList.Contains(holderBase));
_objectList.Remove(holderBase);
}
}

View file

@ -0,0 +1,116 @@
using System;
using System.Threading;
using LibHac.Common;
using LibHac.Diag;
using LibHac.Fs;
namespace LibHac.Os.Impl;
public class MultiWaitTargetImpl : IDisposable
{
private EventWaitHandle _cancelEvent;
// LibHac addition
private OsState _os;
public MultiWaitTargetImpl(OsState os)
{
_cancelEvent = new EventWaitHandle(false, EventResetMode.AutoReset);
_os = os;
}
public void Dispose()
{
_cancelEvent.Dispose();
}
public void CancelWait()
{
_cancelEvent.Set();
}
public Result WaitAny(out int index, Span<WaitHandle> handles, int num)
{
return WaitForAnyObjects(out index, num, handles, int.MaxValue);
}
public Result TimedWaitAny(out int outIndex, Span<WaitHandle> handles, int num, TimeSpan timeout)
{
UnsafeHelpers.SkipParamInit(out outIndex);
var timeoutHelper = new TimeoutHelper(_os, timeout);
do
{
Result rc = WaitForAnyObjects(out int index, num, handles,
(int)timeoutHelper.GetTimeLeftOnTarget().GetMilliSeconds());
if (rc.IsFailure()) return rc.Miss();
if (index == MultiWaitImpl.WaitTimedOut)
{
outIndex = index;
return Result.Success;
}
} while (!timeoutHelper.TimedOut());
outIndex = MultiWaitImpl.WaitTimedOut;
return Result.Success;
}
public Result ReplyAndReceive(out int index, Span<WaitHandle> handles, int num, WaitHandle replyTarget)
{
return ReplyAndReceiveImpl(out index, handles, num, replyTarget, TimeSpan.FromNanoSeconds(long.MaxValue));
}
public Result TimedReplyAndReceive(out int index, Span<WaitHandle> handles, int num, WaitHandle replyTarget,
TimeSpan timeout)
{
return ReplyAndReceiveImpl(out index, handles, num, replyTarget, timeout);
}
public void SetCurrentThreadHandleForCancelWait()
{
/* ... */
}
public void ClearCurrentThreadHandleForCancelWait()
{
/* ... */
}
private Result WaitForAnyObjects(out int outIndex, int num, Span<WaitHandle> handles, int timeoutMs)
{
// Check that we can add our cancel handle to the wait.
Abort.DoAbortUnless(num + 1 < handles.Length);
handles[num] = _cancelEvent;
int index = WaitHandle.WaitAny(handles.Slice(0, num + 1).ToArray(), timeoutMs);
if (index == WaitHandle.WaitTimeout)
{
outIndex = MultiWaitImpl.WaitTimedOut;
return Result.Success;
}
Assert.SdkAssert(index >= 0 && index <= num);
if (index == num)
{
outIndex = MultiWaitImpl.WaitCancelled;
return Result.Success;
}
outIndex = index;
return Result.Success;
}
private Result ReplyAndReceiveImpl(out int outIndex, Span<WaitHandle> handles, int num, WaitHandle replyTarget,
TimeSpan timeout)
{
UnsafeHelpers.SkipParamInit(out outIndex);
Abort.DoAbortUnlessSuccess(ResultFs.NotImplemented.Value);
return ResultFs.NotImplemented.Log();
}
}

View file

@ -0,0 +1,179 @@
using LibHac.Diag;
using LibHac.Os.Impl;
namespace LibHac.Os;
// Todo: Handling waiting in .NET in an OS-agnostic way requires using WaitHandles.
// I'm not sure if this might cause issues in the future if we need to wait on objects other than
// those supported by WaitHandle.
public static class MultipleWait
{
private static MultiWaitImpl GetMultiWaitImpl(MultiWaitType multiWait)
{
return multiWait.Impl;
}
private static MultiWaitHolderType CastToMultiWaitHolder(MultiWaitHolderBase holderBase)
{
return holderBase.Holder;
}
// Note: The "IsWaiting" field is only used in develop builds
public static void InitializeMultiWait(this OsState os, MultiWaitType multiWait)
{
multiWait.Impl = new MultiWaitImpl(os, multiWait);
multiWait.IsWaiting = false;
MemoryFenceApi.FenceMemoryStoreStore();
multiWait.CurrentState = MultiWaitType.State.Initialized;
}
public static void FinalizeMultiWait(this OsState os, MultiWaitType multiWait)
{
MultiWaitImpl impl = GetMultiWaitImpl(multiWait);
Assert.SdkRequires(multiWait.CurrentState == MultiWaitType.State.Initialized);
Assert.SdkRequires(impl.IsListEmpty());
multiWait.CurrentState = MultiWaitType.State.NotInitialized;
impl.Dispose();
}
public static MultiWaitHolderType WaitAny(this OsState os, MultiWaitType multiWait)
{
MultiWaitImpl impl = GetMultiWaitImpl(multiWait);
Assert.SdkRequires(multiWait.CurrentState == MultiWaitType.State.Initialized);
Assert.SdkRequires(impl.IsListNotEmpty());
multiWait.IsWaiting = true;
MemoryFenceApi.FenceMemoryStoreAny();
MultiWaitHolderType holder = CastToMultiWaitHolder(impl.WaitAny());
MemoryFenceApi.FenceMemoryAnyStore();
multiWait.IsWaiting = false;
Assert.SdkAssert(holder is not null);
return holder;
}
public static MultiWaitHolderType TryWaitAny(this OsState os, MultiWaitType multiWait)
{
MultiWaitImpl impl = GetMultiWaitImpl(multiWait);
Assert.SdkRequires(multiWait.CurrentState == MultiWaitType.State.Initialized);
Assert.SdkRequires(impl.IsListNotEmpty());
multiWait.IsWaiting = true;
MemoryFenceApi.FenceMemoryStoreAny();
MultiWaitHolderType holder = CastToMultiWaitHolder(impl.TryWaitAny());
MemoryFenceApi.FenceMemoryAnyStore();
multiWait.IsWaiting = false;
return holder;
}
public static MultiWaitHolderType TimedWaitAny(this OsState os, MultiWaitType multiWait, TimeSpan timeout)
{
MultiWaitImpl impl = GetMultiWaitImpl(multiWait);
Assert.SdkRequires(multiWait.CurrentState == MultiWaitType.State.Initialized);
Assert.SdkRequires(impl.IsListNotEmpty());
Assert.SdkRequires(timeout.GetNanoSeconds() >= 0);
multiWait.IsWaiting = true;
MemoryFenceApi.FenceMemoryStoreAny();
MultiWaitHolderType holder = CastToMultiWaitHolder(impl.TimedWaitAny(timeout));
MemoryFenceApi.FenceMemoryAnyStore();
multiWait.IsWaiting = false;
return holder;
}
public static void FinalizeMultiWaitHolder(this OsState os, MultiWaitHolderType holder)
{
MultiWaitHolderBase holderBase = holder.Impl.HolderBase;
Assert.SdkRequires(holderBase.IsNotLinked());
}
public static void LinkMultiWaitHolder(this OsState os, MultiWaitType multiWait, MultiWaitHolderType holder)
{
MultiWaitImpl impl = GetMultiWaitImpl(multiWait);
MultiWaitHolderBase holderBase = holder.Impl.HolderBase;
Assert.SdkRequires(multiWait.CurrentState == MultiWaitType.State.Initialized);
Assert.SdkRequires(holderBase.IsNotLinked());
Assert.SdkEqual(false, multiWait.IsWaiting);
MemoryFenceApi.FenceMemoryLoadAny();
impl.PushBackToList(holderBase);
holderBase.SetMultiWait(impl);
}
public static void UnlinkMultiWaitHolder(this OsState os, MultiWaitHolderType holder)
{
MultiWaitHolderBase holderBase = holder.Impl.HolderBase;
Assert.SdkRequires(holderBase.IsLinked());
Assert.SdkEqual(false, holderBase.GetMultiWait().GetMultiWaitType().IsWaiting);
MemoryFenceApi.FenceMemoryLoadAny();
holderBase.GetMultiWait().EraseFromList(holderBase);
holderBase.SetMultiWait(null);
}
public static void UnlinkAllMultiWaitHolder(this OsState os, MultiWaitType multiWait)
{
MultiWaitImpl impl = GetMultiWaitImpl(multiWait);
Assert.SdkRequires(multiWait.CurrentState == MultiWaitType.State.Initialized);
Assert.SdkEqual(false, multiWait.IsWaiting);
MemoryFenceApi.FenceMemoryLoadAny();
impl.EraseAllFromList();
}
public static void MoveAllMultiWaitHolder(this OsState os, MultiWaitType dest, MultiWaitType source)
{
MultiWaitImpl dstImpl = GetMultiWaitImpl(dest);
MultiWaitImpl srcImpl = GetMultiWaitImpl(source);
Assert.SdkRequires(dest.CurrentState == MultiWaitType.State.Initialized);
Assert.SdkRequires(source.CurrentState == MultiWaitType.State.Initialized);
Assert.SdkEqual(false, dest.IsWaiting);
MemoryFenceApi.FenceMemoryLoadAny();
Assert.SdkEqual(false, source.IsWaiting);
MemoryFenceApi.FenceMemoryLoadAny();
dstImpl.MoveAllFromOther(srcImpl);
}
public static void SetMultiWaitHolderUserData(this OsState os, MultiWaitHolderType holder, object userData)
{
holder.UserData = userData;
}
public static object GetMultiWaitHolderUserData(this OsState os, MultiWaitHolderType holder)
{
return holder.UserData;
}
public static void InitializeMultiWaitHolder(this OsState os, MultiWaitHolderType holder, OsNativeHandle handle)
{
Assert.SdkRequires(handle != OsTypes.InvalidNativeHandle);
holder.Impl = new MultiWaitHolderImpl(new MultiWaitHolderOfNativeHandle(handle));
holder.UserData = null;
}
}

View file

@ -0,0 +1,22 @@
using LibHac.Os.Impl;
namespace LibHac.Os;
public class MultiWaitType
{
public enum State : byte
{
NotInitialized,
Initialized
}
public State CurrentState;
public bool IsWaiting;
public MultiWaitImpl Impl;
}
public class MultiWaitHolderType
{
public MultiWaitHolderImpl Impl;
public object UserData;
}

8
src/LibHac/Os/OsTypes.cs Normal file
View file

@ -0,0 +1,8 @@
global using OsNativeHandle = System.Threading.WaitHandle;
namespace LibHac.Os;
public static class OsTypes
{
public static OsNativeHandle InvalidNativeHandle => default;
}