diff --git a/src/NServiceBus.AcceptanceTests/NServiceBus.AcceptanceTests.csproj b/src/NServiceBus.AcceptanceTests/NServiceBus.AcceptanceTests.csproj
index 5f1bb5d62fe..560da4242ae 100644
--- a/src/NServiceBus.AcceptanceTests/NServiceBus.AcceptanceTests.csproj
+++ b/src/NServiceBus.AcceptanceTests/NServiceBus.AcceptanceTests.csproj
@@ -92,6 +92,7 @@
+
diff --git a/src/NServiceBus.AcceptanceTests/Sagas/When_saga_started_concurrently.cs b/src/NServiceBus.AcceptanceTests/Sagas/When_saga_started_concurrently.cs
new file mode 100644
index 00000000000..06e41f873b6
--- /dev/null
+++ b/src/NServiceBus.AcceptanceTests/Sagas/When_saga_started_concurrently.cs
@@ -0,0 +1,150 @@
+namespace NServiceBus.AcceptanceTests.Sagas
+{
+ using System;
+ using System.Threading.Tasks;
+ using AcceptanceTesting;
+ using EndpointTemplates;
+ using NUnit.Framework;
+
+ public class When_saga_started_concurrently : NServiceBusAcceptanceTest
+ {
+ [Test]
+ public async Task Should_start_single_saga()
+ {
+ var context = await Scenario.Define(c => { c.SomeId = Guid.NewGuid().ToString(); })
+ .WithEndpoint(b =>
+ {
+ b.When((session, ctx) =>
+ {
+ var t1 = session.SendLocal(new StartMessageOne
+ {
+ SomeId = ctx.SomeId
+ });
+ var t2 = session.SendLocal(new StartMessageTwo
+ {
+ SomeId = ctx.SomeId
+ });
+ return Task.WhenAll(t1, t2);
+ });
+ })
+ .Done(c => c.PlacedSagaId != Guid.Empty && c.BilledSagaId != Guid.Empty)
+ .Run();
+
+ Assert.AreNotEqual(Guid.Empty, context.PlacedSagaId);
+ Assert.AreNotEqual(Guid.Empty, context.BilledSagaId);
+ Assert.AreEqual(context.PlacedSagaId, context.BilledSagaId, "Both messages should have been handled by the same saga, but SagaIds don't match.");
+ }
+
+ class Context : ScenarioContext
+ {
+ public string SomeId { get; set; }
+ public Guid PlacedSagaId { get; set; }
+ public Guid BilledSagaId { get; set; }
+ public bool SagaCompleted { get; set; }
+ }
+
+ class ConcurrentHandlerEndpoint : EndpointConfigurationBuilder
+ {
+ public ConcurrentHandlerEndpoint()
+ {
+ EndpointSetup(b =>
+ {
+ b.LimitMessageProcessingConcurrencyTo(2);
+ b.Recoverability().Immediate(immediate => immediate.NumberOfRetries(3));
+ });
+ }
+
+ class ConcurrentlyStartedSaga : Saga,
+ IAmStartedByMessages,
+ IAmStartedByMessages
+ {
+ public Context Context { get; set; }
+
+ public async Task Handle(StartMessageOne message, IMessageHandlerContext context)
+ {
+ Data.Placed = true;
+ await context.SendLocal(new SuccessfulProcessing
+ {
+ SagaId = Data.Id,
+ Type = nameof(StartMessageOne)
+ });
+ CheckForCompletion(context);
+ }
+
+ public async Task Handle(StartMessageTwo message, IMessageHandlerContext context)
+ {
+ Data.Billed = true;
+ await context.SendLocal(new SuccessfulProcessing
+ {
+ SagaId = Data.Id,
+ Type = nameof(StartMessageTwo)
+ });
+ CheckForCompletion(context);
+ }
+
+ protected override void ConfigureHowToFindSaga(SagaPropertyMapper mapper)
+ {
+ mapper.ConfigureMapping(msg => msg.SomeId).ToSaga(saga => saga.OrderId);
+ mapper.ConfigureMapping(msg => msg.SomeId).ToSaga(saga => saga.OrderId);
+ }
+
+ void CheckForCompletion(IMessageHandlerContext context)
+ {
+ if (!Data.Billed || !Data.Placed)
+ {
+ return;
+ }
+ MarkAsComplete();
+ Context.SagaCompleted = true;
+ }
+ }
+
+ class ConcurrentlyStartedSagaData : ContainSagaData
+ {
+ public virtual string OrderId { get; set; }
+ public virtual bool Placed { get; set; }
+ public virtual bool Billed { get; set; }
+ }
+
+ // Intercepts the messages sent out by the saga
+ class LogSuccessfulHandler : IHandleMessages
+ {
+ public Context Context { get; set; }
+
+ public Task Handle(SuccessfulProcessing message, IMessageHandlerContext context)
+ {
+ if (message.Type == nameof(StartMessageOne))
+ {
+ Context.PlacedSagaId = message.SagaId;
+ }
+ else if (message.Type == nameof(StartMessageTwo))
+ {
+ Context.BilledSagaId = message.SagaId;
+ }
+ else
+ {
+ throw new Exception("Unknown type");
+ }
+
+ return Task.FromResult(0);
+ }
+ }
+ }
+
+ class StartMessageOne : ICommand
+ {
+ public string SomeId { get; set; }
+ }
+
+ class StartMessageTwo : ICommand
+ {
+ public string SomeId { get; set; }
+ }
+
+ class SuccessfulProcessing : ICommand
+ {
+ public string Type { get; set; }
+ public Guid SagaId { get; set; }
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/NServiceBus.Core/Persistence/InMemory/SagaPersister/InMemorySagaPersister.cs b/src/NServiceBus.Core/Persistence/InMemory/SagaPersister/InMemorySagaPersister.cs
index 5b9aad08bc8..552b2d5f926 100644
--- a/src/NServiceBus.Core/Persistence/InMemory/SagaPersister/InMemorySagaPersister.cs
+++ b/src/NServiceBus.Core/Persistence/InMemory/SagaPersister/InMemorySagaPersister.cs
@@ -17,7 +17,11 @@ public Task Complete(IContainSagaData sagaData, SynchronizedStorageSession sessi
inMemSession.Enlist(() =>
{
VersionedSagaEntity value;
- data.TryRemove(sagaData.Id, out value);
+ if (data.TryRemove(sagaData.Id, out value))
+ {
+ object lockToken;
+ lockers.TryRemove(value.LockTokenKey, out lockToken);
+ }
});
return TaskEx.CompletedTask;
}
@@ -66,14 +70,19 @@ public Task Save(IContainSagaData sagaData, SagaCorrelationProperty correlationP
var inMemSession = (InMemorySynchronizedStorageSession) session;
inMemSession.Enlist(() =>
{
- if (correlationProperty != SagaCorrelationProperty.None)
+ var lockenTokenKey = $"{sagaData.GetType().FullName}.{correlationProperty?.Name ?? "None"}.{correlationProperty?.Value ?? "None"}";
+ var lockToken = lockers.GetOrAdd(lockenTokenKey, key => new object());
+ lock (lockToken)
{
- ValidateUniqueProperties(correlationProperty, sagaData);
+ if (correlationProperty != SagaCorrelationProperty.None)
+ {
+ ValidateUniqueProperties(correlationProperty, sagaData);
+ }
+
+ data.AddOrUpdate(sagaData.Id,
+ id => new VersionedSagaEntity(sagaData, lockenTokenKey),
+ (id, original) => new VersionedSagaEntity(sagaData, lockenTokenKey, original)); // we can never end up here.
}
-
- data.AddOrUpdate(sagaData.Id,
- id => new VersionedSagaEntity(sagaData),
- (id, original) => new VersionedSagaEntity(sagaData, original));
});
return TaskEx.CompletedTask;
}
@@ -84,8 +93,8 @@ public Task Update(IContainSagaData sagaData, SynchronizedStorageSession session
inMemSession.Enlist(() =>
{
data.AddOrUpdate(sagaData.Id,
- id => new VersionedSagaEntity(sagaData),
- (id, original) => new VersionedSagaEntity(sagaData, original));
+ id => new VersionedSagaEntity(sagaData, $"{sagaData.GetType().FullName}.None.None"), // we can never end up here.
+ (id, original) => new VersionedSagaEntity(sagaData, original.LockTokenKey, original));
});
return TaskEx.CompletedTask;
}
@@ -94,6 +103,7 @@ void ValidateUniqueProperties(SagaCorrelationProperty correlationProperty, ICont
{
var sagaType = saga.GetType();
var existingSagas = new List();
+ // ReSharper disable once LoopCanBeConvertedToQuery
foreach (var s in data)
{
if (s.Value.SagaData.GetType() == sagaType && (s.Key != saga.Id))
@@ -121,11 +131,13 @@ void ValidateUniqueProperties(SagaCorrelationProperty correlationProperty, ICont
}
ConcurrentDictionary data = new ConcurrentDictionary();
+ ConcurrentDictionary lockers = new ConcurrentDictionary();
class VersionedSagaEntity
{
- public VersionedSagaEntity(IContainSagaData sagaData, VersionedSagaEntity original = null)
+ public VersionedSagaEntity(IContainSagaData sagaData, string lockTokenKey, VersionedSagaEntity original = null)
{
+ LockTokenKey = lockTokenKey;
SagaData = DeepClone(sagaData);
if (original != null)
{
@@ -171,6 +183,7 @@ static IContainSagaData DeepClone(IContainSagaData source)
}
public IContainSagaData SagaData;
+ public string LockTokenKey;
ConditionalWeakTable versionCache;