Skip to content

Instantly share code, notes, and snippets.

@kroymann
Last active August 13, 2019 15:12
Show Gist options
  • Select an option

  • Save kroymann/dd2145178f12be64610bb0e5fe8a4e54 to your computer and use it in GitHub Desktop.

Select an option

Save kroymann/dd2145178f12be64610bb0e5fe8a4e54 to your computer and use it in GitHub Desktop.
Async locking infrastructure
using System;
using System.Collections.Concurrent;
using System.Threading.Tasks;
namespace RecNet.Common.Synchronization
{
/// <summary>
/// Provides asynchronous mutual exclusion scoped to a key: callers that lock the same key are serialized,
/// while callers that lock different keys do not block one another.
/// </summary>
/// <typeparam name="TKey">The type of the keys that identify the objects being locked.</typeparam>
public class AsyncKeyLock<TKey> : IAsyncKeyLock<TKey>
{
    #region Constants

    // Maximum number of idle AsyncLock instances retained for reuse; surplus locks are disposed.
    private const int MaxPoolSize = 64;

    #endregion

    #region Fields

    // Per-key locks that are currently referenced by at least one caller (a holder or a waiter).
    private readonly RefCountedConcurrentDictionary<TKey, AsyncLock> _activeLocks;

    // Idle locks available for reuse by CreateLeasedLock.
    private readonly ConcurrentBag<AsyncLock> _pool;

    // Number of items in _pool, maintained with Interlocked. A slot is reserved before an item is added and
    // released only after an item is taken, so this never under-counts and the pool cannot exceed MaxPoolSize.
    // Tracked separately because ConcurrentBag<T>.Count must lock every per-thread list to compute its
    // result, which would serialize every final release across all keys.
    private int _poolCount;

    #endregion

    #region Constructor

    /// <summary>
    /// Initializes a new instance of the <see cref="AsyncKeyLock{TKey}"/> class.
    /// </summary>
    public AsyncKeyLock()
    {
        _pool = new ConcurrentBag<AsyncLock>();
        _activeLocks = new RefCountedConcurrentDictionary<TKey, AsyncLock>(CreateLeasedLock, ReturnLeasedLock);
    }

    #endregion

    #region APIs

    /// <summary>
    /// Asynchronously acquires the exclusive lock associated with <paramref name="key"/>.
    /// </summary>
    /// <param name="key">The key identifying the specific object to lock against.</param>
    /// <returns>
    /// A <see cref="Task{IDisposable}"/> that completes once the lock is held. Dispose its result to release the lock.
    /// </returns>
    public Task<IDisposable> LockAsync(TKey key)
    {
        // Get() adds a reference to the key's lock. That reference is dropped by the lock's OnRelease
        // callback (installed in CreateLeasedLock) when the IDisposable produced by LockAsync() is disposed.
        return _activeLocks.Get(key).LockAsync();
    }

    #endregion

    #region RefCountedConcurrentDictionary Callbacks

    /// <summary>
    /// Value factory for the active-lock dictionary: leases a pooled lock (or creates a new one) and binds
    /// its release callback to <paramref name="key"/>.
    /// </summary>
    private AsyncLock CreateLeasedLock(TKey key)
    {
        if (_pool.TryTake(out AsyncLock asyncLock))
        {
            System.Threading.Interlocked.Decrement(ref _poolCount);
        }
        else
        {
            asyncLock = new AsyncLock();
        }

        asyncLock.OnRelease = () => _activeLocks.Release(key);
        return asyncLock;
    }

    /// <summary>
    /// Value releaser for the active-lock dictionary: invoked once a lock's last reference is gone (or when a
    /// freshly created lock lost an insertion race). Returns the lock to the pool, or disposes it if the pool is full.
    /// </summary>
    private void ReturnLeasedLock(AsyncLock asyncLock)
    {
        // Drop the key-bound callback so an idle pooled lock does not keep its previous key alive.
        asyncLock.OnRelease = null;

        // Reserve a slot first so concurrent returns can never push the pool past MaxPoolSize.
        if (System.Threading.Interlocked.Increment(ref _poolCount) <= MaxPoolSize)
        {
            _pool.Add(asyncLock);
        }
        else
        {
            System.Threading.Interlocked.Decrement(ref _poolCount);
            asyncLock.Dispose();
        }
    }

    #endregion
}
}
using System;
using System.Collections.Concurrent;
using System.Threading.Tasks;
namespace RecNet.Common.Synchronization
{
/// <summary>
/// Provides asynchronous reader/writer locking scoped to a key. For a given key, readers may hold the lock
/// concurrently while a writer holds it exclusively; callers using different keys do not block one another.
/// </summary>
/// <typeparam name="TKey">The type of the keys that identify the objects being locked.</typeparam>
public class AsyncKeyRWLock<TKey> : IAsyncKeyRWLock<TKey>
{
    #region Constants

    // Maximum number of idle AsyncReaderWriterLock instances retained for reuse; surplus locks are dropped.
    private const int MaxPoolSize = 64;

    #endregion

    #region Fields

    // Per-key locks that are currently referenced by at least one caller (a holder or a waiter).
    private readonly RefCountedConcurrentDictionary<TKey, AsyncReaderWriterLock> _activeLocks;

    // Idle locks available for reuse by CreateLeasedLock.
    private readonly ConcurrentBag<AsyncReaderWriterLock> _pool;

    // Number of items in _pool, maintained with Interlocked. A slot is reserved before an item is added and
    // released only after an item is taken, so this never under-counts and the pool cannot exceed MaxPoolSize.
    // Tracked separately because ConcurrentBag<T>.Count must lock every per-thread list to compute its
    // result, which would serialize every final release across all keys.
    private int _poolCount;

    #endregion

    #region Constructors

    /// <summary>
    /// Initializes a new instance of the <see cref="AsyncKeyRWLock{TKey}"/> class.
    /// </summary>
    public AsyncKeyRWLock()
    {
        _pool = new ConcurrentBag<AsyncReaderWriterLock>();
        _activeLocks = new RefCountedConcurrentDictionary<TKey, AsyncReaderWriterLock>(CreateLeasedLock, ReturnLeasedLock);
    }

    #endregion

    #region APIs

    /// <summary>
    /// Asynchronously acquires the lock associated with <paramref name="key"/> in shared reader mode.
    /// </summary>
    /// <param name="key">The key identifying the specific object to lock against.</param>
    /// <returns>
    /// A <see cref="Task{IDisposable}"/> that completes once the reader lock is held. Dispose its result to release the lock.
    /// </returns>
    public Task<IDisposable> ReaderLockAsync(TKey key)
    {
        // Get() adds a reference to the key's lock; the lock's OnRelease callback drops it on release.
        return _activeLocks.Get(key).ReaderLockAsync();
    }

    /// <summary>
    /// Asynchronously acquires the lock associated with <paramref name="key"/> in exclusive writer mode.
    /// </summary>
    /// <param name="key">The key identifying the specific object to lock against.</param>
    /// <returns>
    /// A <see cref="Task{IDisposable}"/> that completes once the writer lock is held. Dispose its result to release the lock.
    /// </returns>
    public Task<IDisposable> WriterLockAsync(TKey key)
    {
        // Get() adds a reference to the key's lock; the lock's OnRelease callback drops it on release.
        return _activeLocks.Get(key).WriterLockAsync();
    }

    #endregion

    #region RefCountedConcurrentDictionary Callbacks

    /// <summary>
    /// Value factory for the active-lock dictionary: leases a pooled lock (or creates a new one) and binds
    /// its release callback to <paramref name="key"/>.
    /// </summary>
    private AsyncReaderWriterLock CreateLeasedLock(TKey key)
    {
        if (_pool.TryTake(out AsyncReaderWriterLock asyncLock))
        {
            System.Threading.Interlocked.Decrement(ref _poolCount);
        }
        else
        {
            asyncLock = new AsyncReaderWriterLock();
        }

        asyncLock.OnRelease = () => _activeLocks.Release(key);
        return asyncLock;
    }

    /// <summary>
    /// Value releaser for the active-lock dictionary: invoked once a lock's last reference is gone (or when a
    /// freshly created lock lost an insertion race). Returns the lock to the pool if there is room.
    /// </summary>
    private void ReturnLeasedLock(AsyncReaderWriterLock asyncLock)
    {
        // Drop the key-bound callback so an idle pooled lock does not keep its previous key alive.
        asyncLock.OnRelease = null;

        // Reserve a slot first so concurrent returns can never push the pool past MaxPoolSize.
        if (System.Threading.Interlocked.Increment(ref _poolCount) <= MaxPoolSize)
        {
            _pool.Add(asyncLock);
        }
        else
        {
            // Pool is full. AsyncReaderWriterLock is not IDisposable, so the surplus lock is simply dropped.
            System.Threading.Interlocked.Decrement(ref _poolCount);
        }
    }

    #endregion
}
}
using System;
using System.Threading;
using System.Threading.Tasks;
namespace RecNet.Common.Synchronization
{
/// <summary>
/// A mutual-exclusion lock that is acquired asynchronously and released by disposing the
/// <see cref="IDisposable"/> handed out when the lock is acquired.
/// </summary>
public class AsyncLock : IDisposable
{
    #region Types

    /// <summary>
    /// Handle given to lock holders; disposing it releases the owning <see cref="AsyncLock"/>.
    /// </summary>
    private sealed class LockHandle : IDisposable
    {
        private readonly AsyncLock _owner;

        internal LockHandle(AsyncLock owner)
        {
            _owner = owner;
        }

        public void Dispose()
        {
            if (_owner != null)
            {
                _owner.Release();
            }
        }
    }

    #endregion

    #region Fields

    // Single-slot semaphore that provides the actual mutual exclusion.
    private readonly SemaphoreSlim _gate = new SemaphoreSlim(1, 1);

    // A single handle, and an already-completed task wrapping it, are shared by every acquisition.
    private readonly IDisposable _handle;
    private readonly Task<IDisposable> _completedHandleTask;

    private bool _isDisposed;

    #endregion

    #region Properties

    /// <summary>
    /// Gets or sets a callback that runs after every release of the lock, even if releasing the
    /// underlying semaphore throws.
    /// </summary>
    public Action OnRelease { get; set; }

    #endregion

    #region Constructor

    /// <summary>
    /// Initializes a new, unlocked instance of the <see cref="AsyncLock"/> class.
    /// </summary>
    public AsyncLock()
    {
        _handle = new LockHandle(this);
        _completedHandleTask = Task.FromResult(_handle);
    }

    #endregion

    #region APIs

    /// <summary>
    /// Asynchronously acquires the lock. Dispose the resulting <see cref="IDisposable"/> to release it.
    /// </summary>
    /// <returns>
    /// A <see cref="Task{IDisposable}"/> that completes once the lock is held.
    /// </returns>
    public Task<IDisposable> LockAsync()
    {
        Task acquisition = _gate.WaitAsync();
        if (acquisition.Status != TaskStatus.RanToCompletion)
        {
            return WaitThenHandOutAsync(acquisition);
        }

        // Uncontended: the semaphore was taken synchronously, so the cached completed task is reused and nothing is allocated.
        return _completedHandleTask;
    }

    /// <summary>
    /// Slow path of <see cref="LockAsync"/>: waits for the semaphore, then yields the shared handle.
    /// </summary>
    private async Task<IDisposable> WaitThenHandOutAsync(Task acquisition)
    {
        await acquisition;
        return _handle;
    }

    /// <summary>
    /// Frees the semaphore slot, then notifies <see cref="OnRelease"/> regardless of whether that succeeded.
    /// </summary>
    private void Release()
    {
        try
        {
            _gate.Release();
        }
        finally
        {
            Action callback = OnRelease;
            callback?.Invoke();
        }
    }

    #endregion

    #region IDisposable

    /// <summary>
    /// Disposes the underlying semaphore. Subsequent calls have no effect.
    /// </summary>
    public void Dispose()
    {
        if (_isDisposed)
        {
            return;
        }

        _gate.Dispose();
        _isDisposed = true;
    }

    #endregion
}
}
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
namespace RecNet.Common.Synchronization
{
/// <summary>
/// An asynchronous locker that provides read and write locking policies. Any number of readers may hold the
/// lock at the same time, while a writer holds it exclusively. Once a writer is waiting, new readers queue
/// behind it, and a releasing writer hands the lock to the next waiting writer before any waiting readers.
/// <para>
/// This is based on the following blog post:
/// <see href="https://devblogs.microsoft.com/pfxteam/building-async-coordination-primitives-part-7-asyncreaderwriterlock/"/>
/// </para>
/// </summary>
public class AsyncReaderWriterLock
{
    #region Types

    /// <summary>
    /// Releases a single reader or writer hold on its owning lock when disposed.
    /// </summary>
    private sealed class Releaser : IDisposable
    {
        private readonly AsyncReaderWriterLock _toRelease;
        private readonly bool _writer;

        internal Releaser(AsyncReaderWriterLock toRelease, bool writer)
        {
            _toRelease = toRelease;
            _writer = writer;
        }

        public void Dispose()
        {
            if (_toRelease != null)
            {
                if (_writer)
                {
                    _toRelease.WriterRelease();
                }
                else
                {
                    _toRelease.ReaderRelease();
                }
            }
        }
    }

    #endregion

    #region Fields

    // Releasers, and already-completed tasks wrapping them, are shared by every acquisition so the
    // uncontended paths allocate nothing.
    private readonly IDisposable _writerReleaser;
    private readonly IDisposable _readerReleaser;
    private readonly Task<IDisposable> _writerReleaserTask;
    private readonly Task<IDisposable> _readerReleaserTask;

    // FIFO queue of waiting writers. This object is also the monitor that guards all mutable state below.
    private readonly Queue<TaskCompletionSource<IDisposable>> _waitingWriters;

    // Completion source shared by every currently waiting reader; null when no reader is waiting.
    private TaskCompletionSource<IDisposable> _waitingReader;

    // Number of readers waiting on _waitingReader.
    private int _readersWaiting;

    // Lock state: 0 = free, -1 = held by a writer, N > 0 = held by N readers.
    private int _status;

    #endregion

    #region Properties

    /// <summary>
    /// Gets or sets the callback that should be invoked whenever this lock is released.
    /// </summary>
    public Action OnRelease { get; set; }

    #endregion

    #region Constructor

    /// <summary>
    /// Initializes a new instance of the <see cref="AsyncReaderWriterLock"/> class.
    /// </summary>
    public AsyncReaderWriterLock()
    {
        _writerReleaser = new Releaser(this, true);
        _readerReleaser = new Releaser(this, false);
        _writerReleaserTask = Task.FromResult(_writerReleaser);
        _readerReleaserTask = Task.FromResult(_readerReleaser);
        _waitingWriters = new Queue<TaskCompletionSource<IDisposable>>();
        _waitingReader = null;
        _readersWaiting = 0;
        _status = 0;
    }

    #endregion

    #region APIs

    /// <summary>
    /// Asynchronously obtains the lock in shared reader mode. Dispose the returned <see cref="IDisposable"/>
    /// to release the lock.
    /// </summary>
    /// <returns>
    /// A <see cref="Task{IDisposable}"/> that completes once the reader lock is held.
    /// </returns>
    public Task<IDisposable> ReaderLockAsync()
    {
        lock (_waitingWriters)
        {
            // Readers may join only while no writer holds the lock and no writer is queued.
            if (_status >= 0 && _waitingWriters.Count == 0)
            {
                ++_status;
                return _readerReleaserTask;
            }
            else
            {
                ++_readersWaiting;
                if (_waitingReader == null)
                {
                    _waitingReader = CreateWaiter();
                }

                // Each waiting reader receives its own continuation of the shared completion source's task.
                return _waitingReader.Task.ContinueWith(t => t.Result, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default);
            }
        }
    }

    /// <summary>
    /// Asynchronously obtains the lock in exclusive writer mode. Dispose the returned <see cref="IDisposable"/>
    /// to release the lock.
    /// </summary>
    /// <returns>
    /// A <see cref="Task{IDisposable}"/> that completes once the writer lock is held.
    /// </returns>
    public Task<IDisposable> WriterLockAsync()
    {
        lock (_waitingWriters)
        {
            if (_status == 0)
            {
                _status = -1;
                return _writerReleaserTask;
            }
            else
            {
                var waiter = CreateWaiter();
                _waitingWriters.Enqueue(waiter);
                return waiter.Task;
            }
        }
    }

    /// <summary>
    /// Creates a completion source for a waiter. Continuations are forced to run asynchronously so that the
    /// SetResult() calls in the release methods never execute the woken waiter's code inline on the releasing
    /// thread. Otherwise a chain of waiters that each release synchronously would recurse once per waiter, and
    /// every releaser's OnRelease callback would be postponed until that whole chain unwound.
    /// </summary>
    private static TaskCompletionSource<IDisposable> CreateWaiter()
    {
        return new TaskCompletionSource<IDisposable>(TaskCreationOptions.RunContinuationsAsynchronously);
    }

    /// <summary>
    /// Drops one reader hold. When the last reader leaves and a writer is waiting, ownership passes directly
    /// to that writer. <see cref="OnRelease"/> is always invoked, even if waking the writer throws.
    /// </summary>
    private void ReaderRelease()
    {
        try
        {
            TaskCompletionSource<IDisposable> toWake = null;
            lock (_waitingWriters)
            {
                --_status;
                if (_status == 0 && _waitingWriters.Count > 0)
                {
                    // Mark the lock writer-held while still under the monitor so no new reader can slip in.
                    _status = -1;
                    toWake = _waitingWriters.Dequeue();
                }
            }

            // Complete the waiter outside the monitor.
            toWake?.SetResult(_writerReleaser);
        }
        finally
        {
            OnRelease?.Invoke();
        }
    }

    /// <summary>
    /// Drops the writer hold. Ownership passes to the next waiting writer if there is one; otherwise all
    /// waiting readers are admitted at once; otherwise the lock becomes free. <see cref="OnRelease"/> is
    /// always invoked, even if waking a waiter throws.
    /// </summary>
    private void WriterRelease()
    {
        try
        {
            TaskCompletionSource<IDisposable> toWake = null;
            bool toWakeIsWriter = false;
            lock (_waitingWriters)
            {
                if (_waitingWriters.Count > 0)
                {
                    // _status stays -1: ownership transfers straight to the next writer.
                    toWake = _waitingWriters.Dequeue();
                    toWakeIsWriter = true;
                }
                else if (_readersWaiting > 0)
                {
                    // Admit every waiting reader together through their shared completion source.
                    toWake = _waitingReader;
                    _status = _readersWaiting;
                    _readersWaiting = 0;
                    _waitingReader = null;
                }
                else
                {
                    _status = 0;
                }
            }

            // Complete the waiter(s) outside the monitor.
            toWake?.SetResult(toWakeIsWriter ? _writerReleaser : _readerReleaser);
        }
        finally
        {
            OnRelease?.Invoke();
        }
    }

    #endregion
}
}
using System;
using System.Threading.Tasks;
namespace RecNet.Common.Synchronization
{
/// <summary>
/// Provides asynchronous mutual exclusion scoped to a key: callers that lock the same key are serialized,
/// while callers that lock different keys do not block one another, which allows high throughput for
/// unrelated work.
/// </summary>
/// <typeparam name="TKey">The type of the keys that identify the objects being locked.</typeparam>
public interface IAsyncKeyLock<in TKey>
{
    /// <summary>
    /// Asynchronously acquires the exclusive lock associated with <paramref name="key"/>.
    /// </summary>
    /// <param name="key">The key identifying the specific object to lock against.</param>
    /// <returns>
    /// A <see cref="Task{IDisposable}"/> that completes once the lock is held. Dispose its result to release the lock.
    /// </returns>
    Task<IDisposable> LockAsync(TKey key);
}
}
using System;
using System.Threading.Tasks;
namespace RecNet.Common.Synchronization
{
/// <summary>
/// Provides asynchronous reader/writer locking scoped to a key. For a given key, any number of readers may
/// hold the lock together while a writer holds it exclusively; callers using different keys do not block
/// one another, which allows high throughput for unrelated work.
/// </summary>
/// <typeparam name="TKey">The type of the keys that identify the objects being locked.</typeparam>
public interface IAsyncKeyRWLock<in TKey>
{
    /// <summary>
    /// Asynchronously acquires the lock associated with <paramref name="key"/> in shared reader mode.
    /// </summary>
    /// <param name="key">The key identifying the specific object to lock against.</param>
    /// <returns>
    /// A <see cref="Task{IDisposable}"/> that completes once the reader lock is held. Dispose its result to release the lock.
    /// </returns>
    Task<IDisposable> ReaderLockAsync(TKey key);

    /// <summary>
    /// Asynchronously acquires the lock associated with <paramref name="key"/> in exclusive writer mode.
    /// </summary>
    /// <param name="key">The key identifying the specific object to lock against.</param>
    /// <returns>
    /// A <see cref="Task{IDisposable}"/> that completes once the writer lock is held. Dispose its result to release the lock.
    /// </returns>
    Task<IDisposable> WriterLockAsync(TKey key);
}
}
using System;
using System.Collections;
using System.Collections.Concurrent;
using System.Collections.Generic;
namespace RecNet.Common.Synchronization
{
/// <summary>
/// Represents a thread-safe collection of reference-counted key/value pairs that can be accessed by multiple
/// threads concurrently. Values that don't yet exist are automatically created using a caller-supplied
/// value factory method, and when their final refcount is released they are removed from the dictionary.
/// </summary>
/// <typeparam name="TKey">The type of the keys in the dictionary.</typeparam>
/// <typeparam name="TValue">The type of the reference-counted values.</typeparam>
public class RefCountedConcurrentDictionary<TKey, TValue> where TValue : class
{
    #region Types

    /// <summary>
    /// Simple immutable tuple that combines a <typeparamref name="TValue"/> instance with a ref count integer.
    /// Every ref count change stores a new instance. The value-based equality below is what the
    /// compare-and-swap style updates in <see cref="Get(TKey)"/> and <see cref="Release(TKey)"/> rely on
    /// to detect concurrent modification of an entry.
    /// </summary>
    private class RefCountedValue : IEquatable<RefCountedValue>
    {
        public readonly TValue Value;
        public readonly int RefCount;

        public RefCountedValue(TValue value, int refCount)
        {
            Value = value;
            RefCount = refCount;
        }

        public bool Equals(RefCountedValue other) =>
            other != null && (RefCount == other.RefCount) && EqualityComparer<TValue>.Default.Equals(Value, other.Value);

        public override bool Equals(object obj) => (obj is RefCountedValue other) && Equals(other);

        // Uses the same comparer as Equals, which also tolerates a null Value instead of throwing.
        public override int GetHashCode() => ((RefCount << 5) + RefCount) ^ EqualityComparer<TValue>.Default.GetHashCode(Value);
    }

    #endregion

    #region Fields

    private readonly ConcurrentDictionary<TKey, RefCountedValue> _dictionary;
    private readonly Func<TKey, TValue> _valueFactory;
    private readonly Action<TValue> _valueReleaser;

    #endregion

    #region Constructors

    /// <summary>
    /// Initializes a new instance of the <see cref="RefCountedConcurrentDictionary{TKey, TValue}"/> class that is empty,
    /// has the default concurrency level, has the default initial capacity, and uses the default comparer for the key type.
    /// </summary>
    /// <param name="valueFactory">Factory method that generates a new <typeparamref name="TValue"/> for a given <typeparamref name="TKey"/>.</param>
    /// <param name="valueReleaser">Optional callback that is used to clean up <typeparamref name="TValue"/>s after their final ref count is released.</param>
    /// <exception cref="ArgumentNullException"><paramref name="valueFactory"/> is null.</exception>
    public RefCountedConcurrentDictionary(Func<TKey, TValue> valueFactory, Action<TValue> valueReleaser = null)
        : this(new ConcurrentDictionary<TKey, RefCountedValue>(), valueFactory, valueReleaser) { }

    /// <summary>
    /// Initializes a new instance of the <see cref="RefCountedConcurrentDictionary{TKey, TValue}"/> class that is empty,
    /// has the default concurrency level and capacity, and uses the specified <see cref="IEqualityComparer{TKey}"/>.
    /// </summary>
    /// <param name="comparer">The <see cref="IEqualityComparer{TKey}"/> implementation to use when comparing keys.</param>
    /// <param name="valueFactory">Factory method that generates a new <typeparamref name="TValue"/> for a given <typeparamref name="TKey"/>.</param>
    /// <param name="valueReleaser">Optional callback that is used to clean up <typeparamref name="TValue"/>s after their final ref count is released.</param>
    /// <exception cref="ArgumentNullException"><paramref name="valueFactory"/> is null.</exception>
    public RefCountedConcurrentDictionary(IEqualityComparer<TKey> comparer, Func<TKey, TValue> valueFactory, Action<TValue> valueReleaser = null)
        : this(new ConcurrentDictionary<TKey, RefCountedValue>(comparer), valueFactory, valueReleaser) { }

    /// <summary>
    /// Initializes a new instance of the <see cref="RefCountedConcurrentDictionary{TKey, TValue}"/> class that is empty,
    /// has the specified concurrency level and capacity, and uses the default comparer for the key type.
    /// </summary>
    /// <param name="concurrencyLevel">The estimated number of threads that will access the <see cref="RefCountedConcurrentDictionary{TKey, TValue}"/> concurrently.</param>
    /// <param name="capacity">The initial number of elements that the <see cref="RefCountedConcurrentDictionary{TKey, TValue}"/> can contain.</param>
    /// <param name="valueFactory">Factory method that generates a new <typeparamref name="TValue"/> for a given <typeparamref name="TKey"/>.</param>
    /// <param name="valueReleaser">Optional callback that is used to clean up <typeparamref name="TValue"/>s after their final ref count is released.</param>
    /// <exception cref="ArgumentNullException"><paramref name="valueFactory"/> is null.</exception>
    public RefCountedConcurrentDictionary(int concurrencyLevel, int capacity, Func<TKey, TValue> valueFactory, Action<TValue> valueReleaser = null)
        : this(new ConcurrentDictionary<TKey, RefCountedValue>(concurrencyLevel, capacity), valueFactory, valueReleaser) { }

    /// <summary>
    /// Initializes a new instance of the <see cref="RefCountedConcurrentDictionary{TKey, TValue}"/> class that is empty,
    /// has the specified concurrency level, has the specified initial capacity, and uses the specified
    /// <see cref="IEqualityComparer{TKey}"/>.
    /// </summary>
    /// <param name="concurrencyLevel">The estimated number of threads that will access the <see cref="RefCountedConcurrentDictionary{TKey, TValue}"/> concurrently.</param>
    /// <param name="capacity">The initial number of elements that the <see cref="RefCountedConcurrentDictionary{TKey, TValue}"/> can contain.</param>
    /// <param name="comparer">The <see cref="IEqualityComparer{TKey}"/> implementation to use when comparing keys.</param>
    /// <param name="valueFactory">Factory method that generates a new <typeparamref name="TValue"/> for a given <typeparamref name="TKey"/>.</param>
    /// <param name="valueReleaser">Optional callback that is used to clean up <typeparamref name="TValue"/>s after their final ref count is released.</param>
    /// <exception cref="ArgumentNullException"><paramref name="valueFactory"/> is null.</exception>
    public RefCountedConcurrentDictionary(int concurrencyLevel, int capacity, IEqualityComparer<TKey> comparer, Func<TKey, TValue> valueFactory, Action<TValue> valueReleaser = null)
        : this(new ConcurrentDictionary<TKey, RefCountedValue>(concurrencyLevel, capacity, comparer), valueFactory, valueReleaser) { }

    private RefCountedConcurrentDictionary(ConcurrentDictionary<TKey, RefCountedValue> dictionary, Func<TKey, TValue> valueFactory, Action<TValue> valueReleaser)
    {
        _dictionary = dictionary;
        _valueFactory = valueFactory ?? throw new ArgumentNullException(nameof(valueFactory));
        _valueReleaser = valueReleaser;
    }

    #endregion

    #region APIs

    /// <summary>
    /// Obtains a reference to the value corresponding to the specified key. If no such value exists in the
    /// dictionary, then a new value is generated using the value factory method supplied in the constructor.
    /// To prevent leaks, this reference MUST be released via <see cref="Release(TKey)"/>.
    /// </summary>
    /// <remarks>
    /// Under contention the value factory may run and its result lose the race to insert; such a value is
    /// passed straight to the value releaser (if any) and never returned to a caller.
    /// </remarks>
    /// <param name="key">The key of the element to add a reference to.</param>
    /// <returns>The referenced object.</returns>
    public TValue Get(TKey key)
    {
        while (true)
        {
            if (_dictionary.TryGetValue(key, out var refCountedValue))
            {
                // Increment ref count; fails (and retries) if another thread changed the entry meanwhile.
                if (_dictionary.TryUpdate(key, new RefCountedValue(refCountedValue.Value, refCountedValue.RefCount + 1), refCountedValue))
                {
                    return refCountedValue.Value;
                }
            }
            else
            {
                // Add new value to dictionary
                TValue value = _valueFactory(key);
                if (_dictionary.TryAdd(key, new RefCountedValue(value, 1)))
                {
                    return value;
                }
                else
                {
                    // Another thread inserted first; discard our value and retry against theirs.
                    _valueReleaser?.Invoke(value);
                }
            }
        }
    }

    /// <summary>
    /// Releases a reference to the value corresponding to the specified key. If this reference was the last
    /// remaining reference to the value, then the value is removed from the dictionary, and the optional value
    /// releaser callback is invoked.
    /// </summary>
    /// <param name="key">The key of the element to release.</param>
    /// <exception cref="InvalidOperationException">No value exists for <paramref name="key"/> (an over-release).</exception>
    public void Release(TKey key)
    {
        while (true)
        {
            if (!_dictionary.TryGetValue(key, out var refCountedValue))
            {
                // This is BAD. It indicates a ref counting problem where someone is either double-releasing,
                // or they're releasing a key that they never obtained in the first place!!
                throw new InvalidOperationException($"Tried to release value that doesn't exist in the dictionary ({key})!");
            }

            // If we're releasing the last reference, then try to remove the value from the dictionary.
            // Otherwise, try to decrement the reference count.
            if (refCountedValue.RefCount == 1)
            {
                // Remove from dictionary. We use the ICollection<>.Remove() method instead of the ConcurrentDictionary.TryRemove()
                // because this specific API will only succeed if the value hasn't been changed by another thread.
                if (((ICollection<KeyValuePair<TKey, RefCountedValue>>)_dictionary).Remove(new KeyValuePair<TKey, RefCountedValue>(key, refCountedValue)))
                {
                    _valueReleaser?.Invoke(refCountedValue.Value);
                    return;
                }
            }
            else
            {
                // Decrement ref count; fails (and retries) if another thread changed the entry meanwhile.
                if (_dictionary.TryUpdate(key, new RefCountedValue(refCountedValue.Value, refCountedValue.RefCount - 1), refCountedValue))
                {
                    return;
                }
            }
        }
    }

    #endregion

    #region Testing/Debug Hooks

    /// <summary>
    /// Get an enumeration over the contents of the dictionary for testing/debugging purposes.
    /// </summary>
    internal IEnumerable<(TKey key, TValue value, int refCount)> DebugGetContents()
    {
        return new RefCountedDictionaryEnumerable(this);
    }

    private class RefCountedDictionaryEnumerable : IEnumerable<(TKey key, TValue value, int refCount)>
    {
        private readonly RefCountedConcurrentDictionary<TKey, TValue> _dictionary;

        internal RefCountedDictionaryEnumerable(RefCountedConcurrentDictionary<TKey, TValue> dictionary) => _dictionary = dictionary;

        public IEnumerator<(TKey key, TValue value, int refCount)> GetEnumerator() => new RefCountedDictionaryEnumerator(_dictionary);

        IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
    }

    private class RefCountedDictionaryEnumerator : IEnumerator<(TKey key, TValue value, int refCount)>
    {
        private readonly IEnumerator<KeyValuePair<TKey, RefCountedValue>> enumerator;

        public RefCountedDictionaryEnumerator(RefCountedConcurrentDictionary<TKey, TValue> dictionary) => enumerator = dictionary._dictionary.GetEnumerator();

        public void Dispose() => enumerator.Dispose();

        public bool MoveNext() => enumerator.MoveNext();

        public void Reset() => enumerator.Reset();

        object IEnumerator.Current => Current;

        public (TKey key, TValue value, int refCount) Current
        {
            get
            {
                var keyValuePair = enumerator.Current;
                return (keyValuePair.Key, keyValuePair.Value.Value, keyValuePair.Value.RefCount);
            }
        }
    }

    #endregion
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment