-
-
Save ChristopherHaws/b1c54b95838f1513bfb74fa1c8e408f3 to your computer and use it in GitHub Desktop.
| using System.Collections.Generic; | |
namespace System.Threading.Tasks
{
    /// <summary>
    /// Helpers that run an asynchronous delegate to completion synchronously on the calling thread.
    /// </summary>
    /// <remarks>
    /// Each call installs a private <see cref="SynchronizationContext"/> on the calling thread and pumps
    /// its posted callbacks on that thread until the delegate's task has finished. The caller's previous
    /// context is restored afterwards, including when the delegate fails.
    /// </remarks>
    public static class AsyncUtilities
    {
        /// <summary>
        /// Executes an async method that returns a non-generic <see cref="Task"/> synchronously.
        /// </summary>
        /// <param name="task">Delegate that starts the asynchronous operation.</param>
        /// <exception cref="ArgumentNullException"><paramref name="task"/> is null.</exception>
        /// <exception cref="AggregateException">
        /// The operation threw; the original exception is the <see cref="Exception.InnerException"/>.
        /// </exception>
        public static void RunSync(Func<Task> task)
        {
            if (task == null)
            {
                throw new ArgumentNullException(nameof(task));
            }

            var oldContext = SynchronizationContext.Current;
            var sync = new ExclusiveSynchronizationContext();
            SynchronizationContext.SetSynchronizationContext(sync);
            try
            {
                sync.Post(async _ =>
                {
                    try
                    {
                        await task();
                    }
                    catch (Exception e)
                    {
                        // Record the failure; BeginMessageLoop rethrows it wrapped in an AggregateException,
                        // so rethrowing from this async-void lambda is unnecessary.
                        sync.InnerException = e;
                    }
                    finally
                    {
                        sync.EndMessageLoop();
                    }
                }, null);
                sync.BeginMessageLoop();
            }
            finally
            {
                // Restore even when BeginMessageLoop throws; otherwise the thread would keep a
                // context whose message loop is no longer being pumped.
                SynchronizationContext.SetSynchronizationContext(oldContext);
            }
        }

        /// <summary>
        /// Executes an async method that returns a <see cref="Task{TResult}"/> synchronously and returns its result.
        /// </summary>
        /// <typeparam name="T">Result type of the task.</typeparam>
        /// <param name="task">Delegate that starts the asynchronous operation.</param>
        /// <returns>The value produced by the awaited task.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="task"/> is null.</exception>
        /// <exception cref="AggregateException">
        /// The operation threw; the original exception is the <see cref="Exception.InnerException"/>.
        /// </exception>
        public static T RunSync<T>(Func<Task<T>> task)
        {
            if (task == null)
            {
                throw new ArgumentNullException(nameof(task));
            }

            T result = default;

            // Reuse the non-generic pump; the result is captured once the awaited task completes.
            RunSync(async () => { result = await task(); });
            return result;
        }

        /// <summary>
        /// Synchronization context whose posted callbacks are queued and executed in FIFO order by
        /// <see cref="BeginMessageLoop"/> on the thread that calls it.
        /// </summary>
        private class ExclusiveSynchronizationContext : SynchronizationContext, IDisposable
        {
            // Signalled by Post so an idle message loop wakes up.
            private readonly AutoResetEvent workItemsWaiting = new AutoResetEvent(false);

            // Pending callbacks with their state; access is guarded by locking the queue itself.
            private readonly Queue<Tuple<SendOrPostCallback, Object>> items = new Queue<Tuple<SendOrPostCallback, Object>>();

            // Set by the callback that EndMessageLoop posts; read by BeginMessageLoop.
            private bool done;

            /// <summary>
            /// Exception recorded by the running operation; BeginMessageLoop rethrows it wrapped in an AggregateException.
            /// </summary>
            public Exception InnerException { get; set; }

            /// <summary>Releases the wait handle used to wake the message loop.</summary>
            public void Dispose()
            {
                this.workItemsWaiting?.Dispose();
            }

            /// <summary>Synchronous dispatch is not supported by this context.</summary>
            /// <exception cref="NotSupportedException">Always.</exception>
            public override void Send(SendOrPostCallback d, Object state)
            {
                throw new NotSupportedException("We cannot send to our same thread");
            }

            /// <summary>Queues <paramref name="d"/> to be run by the message loop and wakes the loop.</summary>
            public override void Post(SendOrPostCallback d, Object state)
            {
                lock (this.items)
                {
                    this.items.Enqueue(Tuple.Create(d, state));
                }

                this.workItemsWaiting.Set();
            }

            /// <summary>Queues a callback that stops the message loop once it is reached.</summary>
            public void EndMessageLoop()
            {
                this.Post(_ => this.done = true, null);
            }

            /// <summary>
            /// Runs queued callbacks on the calling thread until the callback queued by
            /// <see cref="EndMessageLoop"/> executes, blocking while the queue is empty.
            /// </summary>
            /// <exception cref="AggregateException">A callback left <see cref="InnerException"/> set.</exception>
            public void BeginMessageLoop()
            {
                while (!this.done)
                {
                    Tuple<SendOrPostCallback, object> task = null;
                    lock (this.items)
                    {
                        if (this.items.Count > 0)
                        {
                            task = this.items.Dequeue();
                        }
                    }

                    if (task != null)
                    {
                        task.Item1(task.Item2);
                        if (this.InnerException != null) // the method threw an exception
                        {
                            throw new AggregateException("AsyncHelpers.Run method threw an exception.", this.InnerException);
                        }
                    }
                    else
                    {
                        this.workItemsWaiting.WaitOne();
                    }
                }
            }

            /// <summary>Returns this same instance, so copies share the one queue and loop.</summary>
            public override SynchronizationContext CreateCopy()
            {
                return this;
            }
        }
    }
}
| using System.Data.Common; | |
| using System.Data.SqlClient; | |
| using System.Threading.Tasks; | |
| using Microsoft.Azure.Services.AppAuthentication; | |
| using Microsoft.EntityFrameworkCore.SqlServer.Storage.Internal; | |
| using Microsoft.EntityFrameworkCore.Storage; | |
namespace Microsoft.EntityFrameworkCore
{
    /// <summary>
    /// Extension methods that opt a context into access-token authentication for SQL Server connections.
    /// </summary>
    public static class AzureSqlServerConnectionExtensions
    {
        /// <summary>
        /// Replaces EF Core's <see cref="ISqlServerConnection"/> service with <see cref="AzureSqlServerConnection"/>.
        /// With this service, every connection the context creates carries an access token.
        /// </summary>
        /// <param name="options">The options builder of the context being configured.</param>
        /// <exception cref="System.ArgumentNullException"><paramref name="options"/> is null.</exception>
        public static void UseAzureAccessToken(this DbContextOptionsBuilder options)
        {
            if (options == null)
            {
                throw new System.ArgumentNullException(nameof(options));
            }

            options.ReplaceService<ISqlServerConnection, AzureSqlServerConnection>();
        }
    }

    /// <summary>
    /// A <see cref="SqlServerConnection"/> that sets <see cref="SqlConnection.AccessToken"/> on every
    /// <see cref="SqlConnection"/> it creates, using a token from <see cref="AzureServiceTokenProvider"/>.
    /// </summary>
    public class AzureSqlServerConnection : SqlServerConnection
    {
        // Compensate for slow SQL Server database creation
        private const int DefaultMasterConnectionCommandTimeout = 60;

        // Resource identifier that access tokens are requested for.
        private const string AzureSqlResource = "https://database.windows.net/";

        // One provider shared by all connection instances.
        private static readonly AzureServiceTokenProvider TokenProvider = new AzureServiceTokenProvider();

        /// <summary>Creates the connection with the dependencies supplied by EF Core.</summary>
        public AzureSqlServerConnection(RelationalConnectionDependencies dependencies)
            : base(dependencies)
        {
        }

        /// <summary>
        /// Creates a <see cref="SqlConnection"/> for <c>ConnectionString</c> with an access token attached.
        /// </summary>
        /// <remarks>
        /// The token request is asynchronous; it is run synchronously via AsyncUtilities.RunSync,
        /// so the calling thread blocks until the token is available.
        /// </remarks>
        protected override DbConnection CreateDbConnection() => new SqlConnection(this.ConnectionString)
        {
            // AzureServiceTokenProvider handles caching the token and refreshing it before it expires
            AccessToken = AsyncUtilities.RunSync(() => TokenProvider.GetAccessTokenAsync(AzureSqlResource))
        };

        /// <summary>
        /// Creates a connection to the <c>master</c> database with the same settings as this one.
        /// </summary>
        /// <returns>
        /// An <see cref="AzureSqlServerConnection"/>, so the master connection also authenticates with an access token.
        /// </returns>
        public override ISqlServerConnection CreateMasterConnection()
        {
            var connectionStringBuilder = new SqlConnectionStringBuilder(this.ConnectionString)
            {
                InitialCatalog = "master"
            };

            // A database file attachment is not valid when connecting to master.
            connectionStringBuilder.Remove("AttachDBFilename");

            var contextOptions = new DbContextOptionsBuilder()
                .UseSqlServer(
                    connectionStringBuilder.ConnectionString,
                    b => b.CommandTimeout(this.CommandTimeout ?? DefaultMasterConnectionCommandTimeout))
                .Options;

            return new AzureSqlServerConnection(this.Dependencies.With(contextOptions));
        }
    }
}
/// <summary>
/// Application startup: registers services and configures the HTTP request pipeline.
/// </summary>
public class Startup
{
    private readonly IConfiguration _configuration;
    private readonly IHostingEnvironment _environment;

    public Startup(IConfiguration configuration, IHostingEnvironment env)
    {
        _configuration = configuration;
        _environment = env;
    }

    /// <summary>
    /// Called by the runtime to add services to the container.
    /// Registers a pooled <c>ApplicationContext</c> backed by SQL Server.
    /// </summary>
    public void ConfigureServices(IServiceCollection services)
    {
        services.AddDbContextPool<ApplicationContext>(options =>
        {
            string connectionString = _configuration.GetConnectionString("DefaultConnection");
            options.UseSqlServer(connectionString);

            // In Development, only the connection string is used. In every other environment,
            // UseAzureAccessToken swaps in the access-token based connection.
            if (_environment.IsDevelopment())
            {
                return;
            }

            options.UseAzureAccessToken();
        });

        // Removed unrelated code...
    }

    /// <summary>Called by the runtime to configure the HTTP request pipeline.</summary>
    public void Configure(IApplicationBuilder app)
    {
        // Removed unrelated code...
    }
}
@sebader I can't remember where I got it. I think I got it from a Microsoft repo at some point, but I'm not sure. I've updated the gist with the version I'm using. Glad it works for you!
For the sake of completeness 😊: the NuGet package Microsoft.Azure.Services.AppAuthentication is needed.
Why is it needed to override CreateMasterConnection()?
@OskarKlintrot Because it returns AzureSqlServerConnection instead of SqlServerConnection.
I missed that one, thanks for the clarification!
FYI, to anyone interested, I moved TokenProvider to be a static readonly field so that the caching of tokens works properly.
Did it not work properly before? The token is already cached in a static field.
@OskarKlintrot I was not aware of that, thanks for the info. I suppose all that my update does then is remove a small allocation. ;)
If I remember correctly, that saves ~28µs on my machine; I used BenchmarkDotNet to measure how long it took to create a new instance :) I ended up using an extension (IDbContextOptionsExtension) so that I could use EF's DI instead and mock it for unit testing. It's probably a lot slower, though.
Thanks a lot for this!
Only one thing: you are using `AsyncUtilities` here without saying where you got it from. I found this and it works. Was that the one you are using?