diff --git a/src/NServiceBus.AcceptanceTests/NServiceBus.AcceptanceTests.csproj b/src/NServiceBus.AcceptanceTests/NServiceBus.AcceptanceTests.csproj index 5f1bb5d62fe..560da4242ae 100644 --- a/src/NServiceBus.AcceptanceTests/NServiceBus.AcceptanceTests.csproj +++ b/src/NServiceBus.AcceptanceTests/NServiceBus.AcceptanceTests.csproj @@ -92,6 +92,7 @@ + diff --git a/src/NServiceBus.AcceptanceTests/Sagas/When_saga_started_concurrently.cs b/src/NServiceBus.AcceptanceTests/Sagas/When_saga_started_concurrently.cs new file mode 100644 index 00000000000..06e41f873b6 --- /dev/null +++ b/src/NServiceBus.AcceptanceTests/Sagas/When_saga_started_concurrently.cs @@ -0,0 +1,150 @@ +namespace NServiceBus.AcceptanceTests.Sagas +{ + using System; + using System.Threading.Tasks; + using AcceptanceTesting; + using EndpointTemplates; + using NUnit.Framework; + + public class When_saga_started_concurrently : NServiceBusAcceptanceTest + { + [Test] + public async Task Should_start_single_saga() + { + var context = await Scenario.Define(c => { c.SomeId = Guid.NewGuid().ToString(); }) + .WithEndpoint(b => + { + b.When((session, ctx) => + { + var t1 = session.SendLocal(new StartMessageOne + { + SomeId = ctx.SomeId + }); + var t2 = session.SendLocal(new StartMessageTwo + { + SomeId = ctx.SomeId + }); + return Task.WhenAll(t1, t2); + }); + }) + .Done(c => c.PlacedSagaId != Guid.Empty && c.BilledSagaId != Guid.Empty) + .Run(); + + Assert.AreNotEqual(Guid.Empty, context.PlacedSagaId); + Assert.AreNotEqual(Guid.Empty, context.BilledSagaId); + Assert.AreEqual(context.PlacedSagaId, context.BilledSagaId, "Both messages should have been handled by the same saga, but SagaIds don't match."); + } + + class Context : ScenarioContext + { + public string SomeId { get; set; } + public Guid PlacedSagaId { get; set; } + public Guid BilledSagaId { get; set; } + public bool SagaCompleted { get; set; } + } + + class ConcurrentHandlerEndpoint : EndpointConfigurationBuilder + { + public ConcurrentHandlerEndpoint() + { + EndpointSetup(b => + { + b.LimitMessageProcessingConcurrencyTo(2); + b.Recoverability().Immediate(immediate => immediate.NumberOfRetries(3)); + }); + } + + class ConcurrentlyStartedSaga : Saga, + IAmStartedByMessages, + IAmStartedByMessages + { + public Context Context { get; set; } + + public async Task Handle(StartMessageOne message, IMessageHandlerContext context) + { + Data.Placed = true; + await context.SendLocal(new SuccessfulProcessing + { + SagaId = Data.Id, + Type = nameof(StartMessageOne) + }); + CheckForCompletion(context); + } + + public async Task Handle(StartMessageTwo message, IMessageHandlerContext context) + { + Data.Billed = true; + await context.SendLocal(new SuccessfulProcessing + { + SagaId = Data.Id, + Type = nameof(StartMessageTwo) + }); + CheckForCompletion(context); + } + + protected override void ConfigureHowToFindSaga(SagaPropertyMapper mapper) + { + mapper.ConfigureMapping(msg => msg.SomeId).ToSaga(saga => saga.OrderId); + mapper.ConfigureMapping(msg => msg.SomeId).ToSaga(saga => saga.OrderId); + } + + void CheckForCompletion(IMessageHandlerContext context) + { + if (!Data.Billed || !Data.Placed) + { + return; + } + MarkAsComplete(); + Context.SagaCompleted = true; + } + } + + class ConcurrentlyStartedSagaData : ContainSagaData + { + public virtual string OrderId { get; set; } + public virtual bool Placed { get; set; } + public virtual bool Billed { get; set; } + } + + // Intercepts the messages sent out by the saga + class LogSuccessfulHandler : IHandleMessages + { + public Context Context { get; set; } + + public Task Handle(SuccessfulProcessing message, IMessageHandlerContext context) + { + if (message.Type == nameof(StartMessageOne)) + { + Context.PlacedSagaId = message.SagaId; + } + else if (message.Type == nameof(StartMessageTwo)) + { + Context.BilledSagaId = message.SagaId; + } + else + { + throw new Exception("Unknown type"); + } + + return Task.FromResult(0); + } + } + } + + class StartMessageOne : ICommand + { + public string SomeId { get; set; } + } + + class StartMessageTwo : ICommand + { + public string SomeId { get; set; } + } + + class SuccessfulProcessing : ICommand + { + public string Type { get; set; } + public Guid SagaId { get; set; } + } + } +} \ No newline at end of file diff --git a/src/NServiceBus.Core/Persistence/InMemory/SagaPersister/InMemorySagaPersister.cs b/src/NServiceBus.Core/Persistence/InMemory/SagaPersister/InMemorySagaPersister.cs index 5b9aad08bc8..552b2d5f926 100644 --- a/src/NServiceBus.Core/Persistence/InMemory/SagaPersister/InMemorySagaPersister.cs +++ b/src/NServiceBus.Core/Persistence/InMemory/SagaPersister/InMemorySagaPersister.cs @@ -17,7 +17,11 @@ public Task Complete(IContainSagaData sagaData, SynchronizedStorageSession sessi inMemSession.Enlist(() => { VersionedSagaEntity value; - data.TryRemove(sagaData.Id, out value); + if (data.TryRemove(sagaData.Id, out value)) + { + object lockToken; + lockers.TryRemove(value.LockTokenKey, out lockToken); + } }); return TaskEx.CompletedTask; } @@ -66,14 +70,19 @@ public Task Save(IContainSagaData sagaData, SagaCorrelationProperty correlationP var inMemSession = (InMemorySynchronizedStorageSession) session; inMemSession.Enlist(() => { - if (correlationProperty != SagaCorrelationProperty.None) + var lockenTokenKey = $"{sagaData.GetType().FullName}.{correlationProperty?.Name ?? "None"}.{correlationProperty?.Value ?? "None"}"; + var lockToken = lockers.GetOrAdd(lockenTokenKey, key => new object()); + lock (lockToken) { - ValidateUniqueProperties(correlationProperty, sagaData); + if (correlationProperty != SagaCorrelationProperty.None) + { + ValidateUniqueProperties(correlationProperty, sagaData); + } + + data.AddOrUpdate(sagaData.Id, + id => new VersionedSagaEntity(sagaData, lockenTokenKey), + (id, original) => new VersionedSagaEntity(sagaData, lockenTokenKey, original)); // we can never end up here. } - - data.AddOrUpdate(sagaData.Id, - id => new VersionedSagaEntity(sagaData), - (id, original) => new VersionedSagaEntity(sagaData, original)); }); return TaskEx.CompletedTask; } @@ -84,8 +93,8 @@ public Task Update(IContainSagaData sagaData, SynchronizedStorageSession session inMemSession.Enlist(() => { data.AddOrUpdate(sagaData.Id, - id => new VersionedSagaEntity(sagaData), - (id, original) => new VersionedSagaEntity(sagaData, original)); + id => new VersionedSagaEntity(sagaData, $"{sagaData.GetType().FullName}.None.None"), // we can never end up here. + (id, original) => new VersionedSagaEntity(sagaData, original.LockTokenKey, original)); }); return TaskEx.CompletedTask; } @@ -94,6 +103,7 @@ void ValidateUniqueProperties(SagaCorrelationProperty correlationProperty, ICont { var sagaType = saga.GetType(); var existingSagas = new List(); + // ReSharper disable once LoopCanBeConvertedToQuery foreach (var s in data) { if (s.Value.SagaData.GetType() == sagaType && (s.Key != saga.Id)) @@ -121,11 +131,13 @@ void ValidateUniqueProperties(SagaCorrelationProperty correlationProperty, ICont } ConcurrentDictionary data = new ConcurrentDictionary(); + ConcurrentDictionary lockers = new ConcurrentDictionary(); class VersionedSagaEntity { - public VersionedSagaEntity(IContainSagaData sagaData, VersionedSagaEntity original = null) + public VersionedSagaEntity(IContainSagaData sagaData, string lockTokenKey, VersionedSagaEntity original = null) { + LockTokenKey = lockTokenKey; SagaData = DeepClone(sagaData); if (original != null) { @@ -171,6 +183,7 @@ static IContainSagaData DeepClone(IContainSagaData source) } public IContainSagaData SagaData; + public string LockTokenKey; ConditionalWeakTable versionCache;