Add ReaderWriterLock

This commit is contained in:
Alex Barney 2021-03-11 02:56:46 -07:00
parent 6360a18fbf
commit 7950a91dd0
6 changed files with 617 additions and 1 deletions

View file

@ -10,4 +10,11 @@
{
bool TryLock();
}
public interface ISharedMutex : ILockable
{
void LockShared();
bool TryLockShared();
void UnlockShared();
}
}

View file

@ -0,0 +1,212 @@
using System;
using LibHac.Diag;
namespace LibHac.Os.Impl
{
internal static partial class ReaderWriterLockImpl
{
public static void AcquireReadLockImpl(this OsStateImpl os, ref ReaderWriterLockType rwLock)
{
ref InternalCriticalSection cs = ref GetLockCount(ref rwLock).Cs;
using ScopedLock<InternalCriticalSection> lk = ScopedLock.Lock(ref cs);
// If we already own the lock, no additional action is needed
if (rwLock.OwnerThread == Environment.CurrentManagedThreadId)
{
Assert.True(GetWriteLocked(in GetLockCount(ref rwLock)) == 1);
}
// Otherwise we might need to block until we can acquire the read lock
else
{
// Wait until there aren't any writers or waiting writers
while (GetWriteLocked(in GetLockCount(ref rwLock)) == 1 ||
GetWriteLockWaiterCount(in GetLockCount(ref rwLock)) != 0)
{
IncReadLockWaiterCount(ref GetLockCount(ref rwLock));
rwLock.CvReadLockWaiter.Wait(ref cs);
DecReadLockWaiterCount(ref GetLockCount(ref rwLock));
}
Assert.True(GetWriteLockCount(in rwLock) == 0);
Assert.True(rwLock.OwnerThread == 0);
}
IncReadLockCount(ref GetLockCount(ref rwLock));
}
public static bool TryAcquireReadLockImpl(this OsStateImpl os, ref ReaderWriterLockType rwLock)
{
using ScopedLock<InternalCriticalSection> lk = ScopedLock.Lock(ref GetLockCount(ref rwLock).Cs);
// Acquire the lock if we already have write access
if (rwLock.OwnerThread == Environment.CurrentManagedThreadId)
{
Assert.True(GetWriteLocked(in GetLockCount(ref rwLock)) == 1);
IncReadLockCount(ref GetLockCount(ref rwLock));
return true;
}
// Fail to acquire if there are any writers or waiting writers
if (GetWriteLocked(in GetLockCount(ref rwLock)) == 1 ||
GetWriteLockWaiterCount(in GetLockCount(ref rwLock)) != 0)
{
return false;
}
// Otherwise acquire the lock
Assert.True(GetWriteLockCount(in rwLock) == 0);
Assert.True(rwLock.OwnerThread == 0);
IncReadLockCount(ref GetLockCount(ref rwLock));
return true;
}
public static void ReleaseReadLockImpl(this OsStateImpl os, ref ReaderWriterLockType rwLock)
{
using ScopedLock<InternalCriticalSection> lk = ScopedLock.Lock(ref GetLockCount(ref rwLock).Cs);
Assert.True(GetReadLockCount(in GetLockCount(ref rwLock)) > 0);
DecReadLockWaiterCount(ref GetLockCount(ref rwLock));
// If we own the lock, check if we need to release ownership and signal any waiting threads
if (rwLock.OwnerThread == Environment.CurrentManagedThreadId)
{
Assert.True(GetWriteLocked(in GetLockCount(ref rwLock)) == 1);
// Return if we still hold any locks
if (GetWriteLockCount(in rwLock) != 0 || GetReadLockCount(in GetLockCount(ref rwLock)) != 0)
{
return;
}
// We don't hold any more locks. Release our ownership of the lock
rwLock.OwnerThread = 0;
ClearWriteLocked(ref GetLockCount(ref rwLock));
// Signal the next writer if any are waiting
if (GetWriteLockWaiterCount(in GetLockCount(ref rwLock)) != 0)
{
rwLock.CvWriteLockWaiter.Signal();
}
// Otherwise signal any waiting readers
else if (GetReadLockWaiterCount(in GetLockCount(ref rwLock)) != 0)
{
rwLock.CvReadLockWaiter.Broadcast();
}
}
// Otherwise we need to signal the next writer if we were the only reader
else
{
Assert.True(GetWriteLockCount(in rwLock) == 0);
Assert.True(GetWriteLocked(in GetLockCount(ref rwLock)) == 0);
Assert.True(rwLock.OwnerThread == 0);
// Signal the next writer if no readers are left
if (GetReadLockCount(in GetLockCount(ref rwLock)) == 0 &&
GetWriteLockWaiterCount(in GetLockCount(ref rwLock)) != 0)
{
rwLock.CvWriteLockWaiter.Signal();
}
}
}
public static void AcquireWriteLockImpl(this OsStateImpl os, ref ReaderWriterLockType rwLock)
{
ref InternalCriticalSection cs = ref GetLockCount(ref rwLock).Cs;
using ScopedLock<InternalCriticalSection> lk = ScopedLock.Lock(ref cs);
int currentThread = Environment.CurrentManagedThreadId;
// Increase the write lock count if we already own the lock
if (rwLock.OwnerThread == currentThread)
{
Assert.True(GetWriteLocked(in GetLockCount(ref rwLock)) == 1);
IncWriteLockCount(ref rwLock);
return;
}
// Otherwise wait until there aren't any readers or writers
while (GetReadLockCount(in GetLockCount(ref rwLock)) != 0 ||
GetWriteLocked(in GetLockCount(ref rwLock)) == 1)
{
IncWriteLockWaiterCount(ref GetLockCount(ref rwLock));
rwLock.CvWriteLockWaiter.Wait(ref cs);
DecWriteLockWaiterCount(ref GetLockCount(ref rwLock));
}
Assert.True(GetWriteLockCount(in rwLock) == 0);
Assert.True(rwLock.OwnerThread == 0);
// Acquire the lock
IncWriteLockCount(ref rwLock);
SetWriteLocked(ref GetLockCount(ref rwLock));
rwLock.OwnerThread = currentThread;
}
public static bool TryAcquireWriteLockImpl(this OsStateImpl os, ref ReaderWriterLockType rwLock)
{
using ScopedLock<InternalCriticalSection> lk = ScopedLock.Lock(ref GetLockCount(ref rwLock).Cs);
int currentThread = Environment.CurrentManagedThreadId;
// Acquire the lock if we already have write access
if (rwLock.OwnerThread == currentThread)
{
Assert.True(GetWriteLocked(in GetLockCount(ref rwLock)) == 1);
IncWriteLockCount(ref rwLock);
return true;
}
// Fail to acquire if there are any readers or writers
if (GetReadLockCount(in GetLockCount(ref rwLock)) != 0 ||
GetWriteLocked(in GetLockCount(ref rwLock)) == 1)
{
return false;
}
// Otherwise acquire the lock
Assert.True(GetWriteLockCount(in rwLock) == 0);
Assert.True(rwLock.OwnerThread == 0);
IncWriteLockCount(ref rwLock);
SetWriteLocked(ref GetLockCount(ref rwLock));
rwLock.OwnerThread = currentThread;
return true;
}
public static void ReleaseWriteLockImpl(this OsStateImpl os, ref ReaderWriterLockType rwLock)
{
using ScopedLock<InternalCriticalSection> lk = ScopedLock.Lock(ref GetLockCount(ref rwLock).Cs);
Assert.True(GetWriteLockCount(in rwLock) > 0);
Assert.True(GetWriteLocked(in GetLockCount(ref rwLock)) != 0);
Assert.True(rwLock.OwnerThread == Environment.CurrentManagedThreadId);
DecWriteLockCount(ref rwLock);
// Return if we still hold any locks
if (GetWriteLockCount(in rwLock) != 0 || GetReadLockCount(in GetLockCountRo(in rwLock)) != 0)
{
return;
}
// We don't hold any more locks. Release our ownership of the lock
rwLock.OwnerThread = 0;
ClearWriteLocked(ref GetLockCount(ref rwLock));
// Signal the next writer if any are waiting
if (GetWriteLockWaiterCount(in GetLockCount(ref rwLock)) != 0)
{
rwLock.CvWriteLockWaiter.Signal();
}
// Otherwise signal any waiting readers
else if (GetReadLockWaiterCount(in GetLockCount(ref rwLock)) != 0)
{
rwLock.CvReadLockWaiter.Broadcast();
}
}
}
}

View file

@ -0,0 +1,128 @@
using LibHac.Diag;
namespace LibHac.Os.Impl
{
internal static partial class ReaderWriterLockImpl
{
public static void ClearReadLockCount(ref ReaderWriterLockType.LockCountType lc)
{
lc.Counter.ReadLockCount = 0;
}
public static void ClearWriteLocked(ref ReaderWriterLockType.LockCountType lc)
{
lc.Counter.WriteLocked = 0;
}
public static void ClearReadLockWaiterCount(ref ReaderWriterLockType.LockCountType lc)
{
lc.Counter.ReadLockWaiterCount = 0;
}
public static void ClearWriteLockWaiterCount(ref ReaderWriterLockType.LockCountType lc)
{
lc.Counter.WriteLockWaiterCount = 0;
}
public static void ClearWriteLockCount(ref ReaderWriterLockType rwLock)
{
rwLock.LockCount.WriteLockCount = 0;
}
public static ref ReaderWriterLockType.LockCountType GetLockCount(ref ReaderWriterLockType rwLock)
{
return ref rwLock.LockCount;
}
public static ref readonly ReaderWriterLockType.LockCountType GetLockCountRo(in ReaderWriterLockType rwLock)
{
return ref rwLock.LockCount;
}
public static uint GetReadLockCount(in ReaderWriterLockType.LockCountType lc)
{
return lc.Counter.ReadLockCount;
}
public static uint GetWriteLocked(in ReaderWriterLockType.LockCountType lc)
{
return lc.Counter.WriteLocked;
}
public static uint GetReadLockWaiterCount(in ReaderWriterLockType.LockCountType lc)
{
return lc.Counter.ReadLockWaiterCount;
}
public static uint GetWriteLockWaiterCount(in ReaderWriterLockType.LockCountType lc)
{
return lc.Counter.WriteLockWaiterCount;
}
public static uint GetWriteLockCount(in ReaderWriterLockType rwLock)
{
return rwLock.LockCount.WriteLockCount;
}
public static void IncReadLockCount(ref ReaderWriterLockType.LockCountType lc)
{
uint readLockCount = lc.Counter.ReadLockCount;
Assert.True(readLockCount < ReaderWriterLock.ReaderWriterLockCountMax);
lc.Counter.ReadLockCount = readLockCount + 1;
}
public static void DecReadLockCount(ref ReaderWriterLockType.LockCountType lc)
{
uint readLockCount = lc.Counter.ReadLockCount;
Assert.True(readLockCount > 0);
lc.Counter.ReadLockCount = readLockCount - 1;
}
public static void IncReadLockWaiterCount(ref ReaderWriterLockType.LockCountType lc)
{
uint readLockWaiterCount = lc.Counter.ReadLockWaiterCount;
Assert.True(readLockWaiterCount < ReaderWriterLock.ReadWriteLockWaiterCountMax);
lc.Counter.ReadLockWaiterCount = readLockWaiterCount + 1;
}
public static void DecReadLockWaiterCount(ref ReaderWriterLockType.LockCountType lc)
{
uint readLockWaiterCount = lc.Counter.ReadLockWaiterCount;
Assert.True(readLockWaiterCount > 0);
lc.Counter.ReadLockWaiterCount = readLockWaiterCount - 1;
}
public static void IncWriteLockWaiterCount(ref ReaderWriterLockType.LockCountType lc)
{
uint writeLockWaiterCount = lc.Counter.WriteLockWaiterCount;
Assert.True(writeLockWaiterCount < ReaderWriterLock.ReadWriteLockWaiterCountMax);
lc.Counter.WriteLockWaiterCount = writeLockWaiterCount + 1;
}
public static void DecWriteLockWaiterCount(ref ReaderWriterLockType.LockCountType lc)
{
uint writeLockWaiterCount = lc.Counter.WriteLockWaiterCount;
Assert.True(writeLockWaiterCount > 0);
lc.Counter.WriteLockWaiterCount = writeLockWaiterCount - 1;
}
public static void IncWriteLockCount(ref ReaderWriterLockType rwLock)
{
uint writeLockCount = rwLock.LockCount.WriteLockCount;
Assert.True(writeLockCount < ReaderWriterLock.ReaderWriterLockCountMax);
rwLock.LockCount.WriteLockCount = writeLockCount + 1;
}
public static void DecWriteLockCount(ref ReaderWriterLockType rwLock)
{
uint writeLockCount = rwLock.LockCount.WriteLockCount;
Assert.True(writeLockCount > 0);
rwLock.LockCount.WriteLockCount = writeLockCount - 1;
}
public static void SetWriteLocked(ref ReaderWriterLockType.LockCountType lc)
{
lc.Counter.WriteLocked = 1;
}
}
}

View file

@ -5,7 +5,8 @@ namespace LibHac.Os
{
public class OsState : IDisposable
{
private HorizonClient Hos { get; }
public OsStateImpl Impl => new OsStateImpl(this);
internal HorizonClient Hos { get; }
internal OsResourceManager ResourceManager { get; }
// Todo: Use configuration object if/when more options are added
@ -25,4 +26,13 @@ namespace LibHac.Os
ResourceManager.Dispose();
}
}
// Functions in the nn::os::detail namespace use this struct.
public readonly struct OsStateImpl
{
internal readonly OsState Os;
internal HorizonClient Hos => Os.Hos;
internal OsStateImpl(OsState parent) => Os = parent;
}
}

View file

@ -0,0 +1,195 @@
using System;
using LibHac.Diag;
using LibHac.Os.Impl;
namespace LibHac.Os
{
public static class ReaderWriterLockApi
{
public static void InitializeReaderWriterLock(this OsState os, ref ReaderWriterLockType rwLock)
{
// Create objects.
ReaderWriterLockImpl.GetLockCount(ref rwLock).Cs.Initialize();
rwLock.CvReadLockWaiter.Initialize();
rwLock.CvWriteLockWaiter.Initialize();
// Set member variables.
ReaderWriterLockImpl.ClearReadLockCount(ref ReaderWriterLockImpl.GetLockCount(ref rwLock));
ReaderWriterLockImpl.ClearWriteLocked(ref ReaderWriterLockImpl.GetLockCount(ref rwLock));
ReaderWriterLockImpl.ClearReadLockWaiterCount(ref ReaderWriterLockImpl.GetLockCount(ref rwLock));
ReaderWriterLockImpl.ClearWriteLockWaiterCount(ref ReaderWriterLockImpl.GetLockCount(ref rwLock));
ReaderWriterLockImpl.ClearWriteLockCount(ref rwLock);
rwLock.OwnerThread = 0;
// Mark initialized.
rwLock.LockState = ReaderWriterLockType.State.Initialized;
}
public static void FinalizeReaderWriterLock(this OsState os, ref ReaderWriterLockType rwLock)
{
Assert.True(rwLock.LockState == ReaderWriterLockType.State.Initialized);
// Don't allow finalizing a locked lock.
Assert.True(ReaderWriterLockImpl.GetReadLockCount(in ReaderWriterLockImpl.GetLockCount(ref rwLock)) == 0);
Assert.True(ReaderWriterLockImpl.GetWriteLocked(in ReaderWriterLockImpl.GetLockCount(ref rwLock)) == 0);
// Mark not initialized.
rwLock.LockState = ReaderWriterLockType.State.NotInitialized;
// Destroy objects.
ReaderWriterLockImpl.GetLockCount(ref rwLock).Cs.FinalizeObject();
}
public static void AcquireReadLock(this OsState os, ref ReaderWriterLockType rwLock)
{
Assert.True(rwLock.LockState == ReaderWriterLockType.State.Initialized);
os.Impl.AcquireReadLockImpl(ref rwLock);
}
public static bool TryAcquireReadLock(this OsState os, ref ReaderWriterLockType rwLock)
{
Assert.True(rwLock.LockState == ReaderWriterLockType.State.Initialized);
return os.Impl.TryAcquireReadLockImpl(ref rwLock);
}
public static void ReleaseReadLock(this OsState os, ref ReaderWriterLockType rwLock)
{
Assert.True(rwLock.LockState == ReaderWriterLockType.State.Initialized);
os.Impl.ReleaseReadLockImpl(ref rwLock);
}
public static void AcquireWriteLock(this OsState os, ref ReaderWriterLockType rwLock)
{
Assert.True(rwLock.LockState == ReaderWriterLockType.State.Initialized);
os.Impl.AcquireWriteLockImpl(ref rwLock);
}
public static bool TryAcquireWriteLock(this OsState os, ref ReaderWriterLockType rwLock)
{
Assert.True(rwLock.LockState == ReaderWriterLockType.State.Initialized);
return os.Impl.TryAcquireWriteLockImpl(ref rwLock);
}
public static void ReleaseWriteLock(this OsState os, ref ReaderWriterLockType rwLock)
{
Assert.True(rwLock.LockState == ReaderWriterLockType.State.Initialized);
os.Impl.ReleaseWriteLockImpl(ref rwLock);
}
public static bool IsReadLockHeld(this OsState os, in ReaderWriterLockType rwLock)
{
Assert.True(rwLock.LockState == ReaderWriterLockType.State.Initialized);
return ReaderWriterLockImpl.GetReadLockCount(in ReaderWriterLockImpl.GetLockCountRo(in rwLock)) != 0;
}
// Todo: Use Horizon thread APIs
public static bool IsWriteLockHeldByCurrentThread(this OsState os, in ReaderWriterLockType rwLock)
{
Assert.True(rwLock.LockState == ReaderWriterLockType.State.Initialized);
return rwLock.OwnerThread == Environment.CurrentManagedThreadId &&
ReaderWriterLockImpl.GetWriteLockCount(in rwLock) != 0;
}
public static bool IsReaderWriterLockOwnerThread(this OsState os, in ReaderWriterLockType rwLock)
{
Assert.True(rwLock.LockState == ReaderWriterLockType.State.Initialized);
return rwLock.OwnerThread == Environment.CurrentManagedThreadId;
}
}
public class ReaderWriterLock : ISharedMutex
{
public const int ReaderWriterLockCountMax = (1 << 15) - 1;
public const int ReadWriteLockWaiterCountMax = (1 << 8) - 1;
private readonly OsState _os;
private ReaderWriterLockType _rwLock;
public ReaderWriterLock(OsState os)
{
_os = os;
_os.InitializeReaderWriterLock(ref _rwLock);
}
public void AcquireReadLock()
{
_os.AcquireReadLock(ref _rwLock);
}
public bool TryAcquireReadLock()
{
return _os.TryAcquireReadLock(ref _rwLock);
}
public void ReleaseReadLock()
{
_os.ReleaseReadLock(ref _rwLock);
}
public void AcquireWriteLock()
{
_os.AcquireWriteLock(ref _rwLock);
}
public bool TryAcquireWriteLock()
{
return _os.TryAcquireWriteLock(ref _rwLock);
}
public void ReleaseWriteLock()
{
_os.ReleaseWriteLock(ref _rwLock);
}
public bool IsReadLockHeld()
{
return _os.IsReadLockHeld(in _rwLock);
}
public bool IsWriteLockHeldByCurrentThread()
{
return _os.IsWriteLockHeldByCurrentThread(in _rwLock);
}
public bool IsLockOwner()
{
return _os.IsReaderWriterLockOwnerThread(in _rwLock);
}
public void LockShared()
{
AcquireReadLock();
}
public bool TryLockShared()
{
return TryAcquireReadLock();
}
public void UnlockShared()
{
ReleaseReadLock();
}
public void Lock()
{
AcquireWriteLock();
}
public bool TryLock()
{
return TryAcquireWriteLock();
}
public void Unlock()
{
ReleaseWriteLock();
}
public ref ReaderWriterLockType GetBase()
{
return ref _rwLock;
}
}
}

View file

@ -0,0 +1,64 @@
using System.Runtime.CompilerServices;
using LibHac.Os.Impl;
namespace LibHac.Os
{
public struct ReaderWriterLockType
{
internal LockCountType LockCount;
internal State LockState;
internal int OwnerThread;
internal InternalConditionVariable CvReadLockWaiter;
internal InternalConditionVariable CvWriteLockWaiter;
public enum State
{
NotInitialized,
Initialized
}
public struct LockCountType
{
public InternalCriticalSection Cs;
public ReaderWriterLockCounter Counter;
public uint WriteLockCount;
}
public struct ReaderWriterLockCounter
{
private uint _counter;
public uint ReadLockCount
{
readonly get => GetBitsValue(_counter, 0, 15);
set => _counter = SetBitsValue(value, 0, 15);
}
public uint WriteLocked
{
readonly get => GetBitsValue(_counter, 15, 1);
set => _counter = SetBitsValue(value, 15, 1);
}
public uint ReadLockWaiterCount
{
readonly get => GetBitsValue(_counter, 16, 8);
set => _counter = SetBitsValue(value, 16, 8);
}
public uint WriteLockWaiterCount
{
readonly get => GetBitsValue(_counter, 24, 8);
set => _counter = SetBitsValue(value, 24, 8);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static uint GetBitsValue(uint value, int bitsOffset, int bitsCount) =>
(value >> bitsOffset) & ~(~default(uint) << bitsCount);
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static uint SetBitsValue(uint value, int bitsOffset, int bitsCount) =>
(value & ~(~default(uint) << bitsCount)) << bitsOffset;
}
}
}