using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
namespace TaskSchedulerBug
{
class Program
{
    private static CustomTaskScheduler _scheduler;
    private static TaskFactory _factory;

    /// <summary>
    /// Starts <see cref="DoWork"/> on a <see cref="CustomTaskScheduler"/> and pumps that
    /// scheduler from this thread until the work has finished.
    /// </summary>
    /// <param name="args">Unused.</param>
    static void Main(string[] args)
    {
        _scheduler = new CustomTaskScheduler();
        _factory = new TaskFactory(_scheduler);

        // StartNew with an async lambda yields Task<Task>: the outer task only covers the lambda up
        // to its first pending await. Unwrap() returns a task that completes when the whole async
        // body does -- including when it faults -- so the pump loop below cannot spin forever.
        Task work = _factory.StartNew(async () =>
        {
            await DoWork();
        }).Unwrap();

        // Tasks queued to the custom scheduler only execute when Run() is called,
        // so this thread has to keep pumping until the work is done.
        while (!work.IsCompleted)
        {
            _scheduler.Run();
            Thread.Sleep(1);
        }

        // Surface any exception thrown by DoWork instead of losing it inside the task.
        work.GetAwaiter().GetResult();
    }

    /// <summary>
    /// Prints the current <see cref="TaskScheduler"/> before and after awaiting
    /// <see cref="GetName"/>, then prints a greeting with the returned name.
    /// </summary>
    static async Task DoWork()
    {
        Console.WriteLine("Before await: {0}", TaskScheduler.Current);
        var name = await GetName();
        Console.WriteLine("After await: {0}", TaskScheduler.Current);
        Console.WriteLine("hello {0}", name);
    }

    /// <summary>Returns a fixed name.</summary>
    static async Task<string> GetName()
    {
        // NOTE(review): there is no await in this method (compiler warning CS1998), so the task it
        // returns is already complete and the await in DoWork continues synchronously -- no
        // continuation is ever posted to a scheduler by that await.
        return "Brian";
    }
}
/// <summary>
/// A <see cref="TaskScheduler"/> with no threads of its own: queued tasks wait in a FIFO queue
/// until some thread calls <see cref="Run"/> to execute them.
/// </summary>
class CustomTaskScheduler : TaskScheduler
{
    // The scheduler (if any) whose Run() the current thread is inside. Used so that inline
    // execution only happens on a thread that is actively pumping this scheduler.
    [ThreadStatic]
    private static CustomTaskScheduler t_pumpingScheduler;

    // Pending work, in the order it was queued.
    private readonly ConcurrentQueue<Task> _tasks;

    /// <summary>Creates a scheduler with an empty queue.</summary>
    public CustomTaskScheduler()
    {
        _tasks = new ConcurrentQueue<Task>();
    }

    /// <summary>
    /// Executes queued tasks on the calling thread. At most as many tasks as were queued when the
    /// call started are executed, so work queued during the call (e.g. continuations) cannot keep
    /// a single call running indefinitely; it is picked up by a later call.
    /// </summary>
    public void Run()
    {
        int remaining = _tasks.Count;
        if (remaining == 0)
        {
            return;
        }

        // Save any outer pump so nested Run() calls restore it correctly.
        CustomTaskScheduler outer = t_pumpingScheduler;
        t_pumpingScheduler = this;
        try
        {
            Task task;
            while (remaining-- > 0 && _tasks.TryDequeue(out task))
            {
                TryExecuteTask(task);
            }
        }
        finally
        {
            t_pumpingScheduler = outer;
        }
    }

    /// <summary>Returns a snapshot of the tasks still waiting in the queue (debugger support).</summary>
    protected override IEnumerable<Task> GetScheduledTasks()
    {
        return _tasks.ToArray();
    }

    /// <summary>Appends <paramref name="task"/> to the queue; it executes during a later <see cref="Run"/>.</summary>
    protected override void QueueTask(Task task)
    {
        _tasks.Enqueue(task);
    }

    /// <summary>
    /// Runs <paramref name="task"/> synchronously, but only on a thread currently inside this
    /// scheduler's <see cref="Run"/>; other threads get <c>false</c> and the task is queued instead.
    /// </summary>
    protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued)
    {
        // Without this check, e.g. an ExecuteSynchronously continuation completed from a foreign
        // thread would run on that thread rather than on the thread pumping this scheduler.
        if (t_pumpingScheduler != this)
        {
            return false;
        }

        // A ConcurrentQueue cannot remove an arbitrary item. TryDequeue is not overridden here,
        // so an already-queued task is left for Run() to execute.
        if (taskWasPreviouslyQueued && !TryDequeue(task))
        {
            return false;
        }

        return TryExecuteTask(task);
    }
}
}