diff --git a/src/NServiceBus.AcceptanceTests/Msmq/When_publishing_with_authorizer.cs b/src/NServiceBus.AcceptanceTests/Msmq/When_publishing_with_authorizer.cs new file mode 100644 index 00000000000..025b09ae4c4 --- /dev/null +++ b/src/NServiceBus.AcceptanceTests/Msmq/When_publishing_with_authorizer.cs @@ -0,0 +1,177 @@ +namespace NServiceBus.AcceptanceTests.Msmq +{ + using System.Collections.Generic; + using NServiceBus.AcceptanceTesting; + using NServiceBus.AcceptanceTests.EndpointTemplates; + using NServiceBus.AcceptanceTests.ScenarioDescriptors; + using NServiceBus.Features; + using NServiceBus.MessageMutator; + using NServiceBus.Persistence.InMemory; + using NUnit.Framework; + + public class When_publishing_with_authorizer : NServiceBusAcceptanceTest + { + [Test] + public void Should_only_deliver_to_authorized() + { + Scenario.Define() + .WithEndpoint(b => + b.When(c => c.Subscriber1Subscribed && c.Subscriber2Subscribed, (bus, c) => bus.Publish(new MyEvent())) + ) + .WithEndpoint(b => b.When(bus => + { + bus.Subscribe(); + })) + .WithEndpoint(b => b.When(bus => + { + bus.Subscribe(); + })) + .Done(c => + c.Subscriber1GotTheEvent && + c.DeclinedSubscriber2) + .Repeat(r => r.For(Transports.Msmq)) + .Should(c => + { + Assert.True(c.Subscriber1GotTheEvent); + Assert.False(c.Subscriber2GotTheEvent); + }) + .Run(); + } + + public class TestContext : ScenarioContext + { + public bool Subscriber1GotTheEvent { get; set; } + public bool Subscriber2GotTheEvent { get; set; } + public bool Subscriber1Subscribed { get; set; } + public bool Subscriber2Subscribed { get; set; } + public bool DeclinedSubscriber2 { get; set; } + } + + class Publisher : EndpointConfigurationBuilder + { + public Publisher() + { + EndpointSetup(b => + { + // InMemoryPersistence.UseAsDefault(); + Configure.Features.Disable(); + Configure.Component(DependencyLifecycle.InstancePerCall); + }); + } + + class MyTransportMessageMutator : IMutateIncomingTransportMessages + { + TestContext context; + + public MyTransportMessageMutator(TestContext context) + { + this.context = context; + } + + public void MutateIncoming(TransportMessage transportMessage) + { + if (transportMessage.MessageIntent == MessageIntentEnum.Subscribe) + { + var originatingEndpoint = transportMessage.Headers[Headers.OriginatingEndpoint]; + if (originatingEndpoint.Contains("Subscriber1")) + { + context.Subscriber1Subscribed = true; + } + if (originatingEndpoint.Contains("Subscriber2")) + { + context.Subscriber2Subscribed = true; + } + } + } + } + + public class SubscriptionAuthorizer : IAuthorizeSubscriptions + { + TestContext context; + + public SubscriptionAuthorizer(TestContext context) + { + this.context = context; + } + + public bool AuthorizeSubscribe(string messageType, string clientEndpoint, IDictionary headers) + { + var isFromSubscriber1 = headers[Headers.OriginatingEndpoint] + .Contains("Subscriber1"); + if (!isFromSubscriber1) + { + context.DeclinedSubscriber2 = true; + } + return isFromSubscriber1; + } + + public bool AuthorizeUnsubscribe(string messageType, string clientEndpoint, IDictionary headers) + { + return true; + } + } + } + + public class Subscriber1 : EndpointConfigurationBuilder + { + public Subscriber1() + { + EndpointSetup(c => + { + InMemoryPersistence.UseAsDefault(); + Configure.Features.Disable(); + }) + .AddMapping(typeof(Publisher)); + } + + public class MyEventHandler : IHandleMessages + { + TestContext context; + + public MyEventHandler(TestContext context) + { + this.context = context; + } + + public void Handle(MyEvent message) + { + context.Subscriber1GotTheEvent = true; + } + } + + + } + + public class Subscriber2 : EndpointConfigurationBuilder + { + public Subscriber2() + { + EndpointSetup(c => + { + InMemoryPersistence.UseAsDefault(); + Configure.Features.Disable(); + }) + .AddMapping(typeof(Publisher)); + } + + public class MyEventHandler : IHandleMessages + { + TestContext context; + + public MyEventHandler(TestContext context) + { + this.context = context; + } + + public void Handle(MyEvent messageThatIsEnlisted) + { + context.Subscriber2GotTheEvent = true; + } + } + } + + public class MyEvent : IEvent + { + } + } +} \ No newline at end of file diff --git a/src/NServiceBus.AcceptanceTests/Msmq/When_unsubscribing_with_authorizer.cs b/src/NServiceBus.AcceptanceTests/Msmq/When_unsubscribing_with_authorizer.cs new file mode 100644 index 00000000000..ffdc8a2f18c --- /dev/null +++ b/src/NServiceBus.AcceptanceTests/Msmq/When_unsubscribing_with_authorizer.cs @@ -0,0 +1,157 @@ +namespace NServiceBus.AcceptanceTests.Msmq +{ + using System.Collections.Generic; + using NServiceBus.AcceptanceTesting; + using NServiceBus.AcceptanceTests.EndpointTemplates; + using NServiceBus.AcceptanceTests.ScenarioDescriptors; + using NServiceBus.Features; + using NServiceBus.MessageMutator; + using NServiceBus.Persistence.InMemory; + using NUnit.Framework; + + public class When_unsubscribing_with_authorizer : NServiceBusAcceptanceTest + { + + [Test] + public void Should_ignore_unsubscribe() + { + Scenario.Define() + .WithEndpoint(b => + b.When(c => c.Subscribed, (bus, c) => + { + bus.Publish(new MyEvent()); + }).When(c => c.SubscriberEventCount == 1, (bus, c) => + { + bus.Publish(new MyEvent()); + }) + ) + .WithEndpoint(b => b.When(c => c.PublisherStarted, bus => + { + bus.Subscribe(); + })) + .Done(c => + c.SubscriberEventCount == 2 && + c.DeclinedUnSubscribe) + .Repeat(r => r.For(Transports.Msmq)) + .Run(); + } + + public class TestContext : ScenarioContext + { + public int SubscriberEventCount { get; set; } + public bool UnsubscribeAttempted { get; set; } + public bool DeclinedUnSubscribe { get; set; } + public bool Subscribed { get; set; } + public bool PublisherStarted { get; set; } + } + + class Publisher : EndpointConfigurationBuilder + { + public Publisher() + { + EndpointSetup(b => + { + InMemoryPersistence.UseAsDefault(); + Configure.Features.Disable(); + Configure.Component(DependencyLifecycle.InstancePerCall); + }); + } + class MyTransportMessageMutator : IMutateIncomingTransportMessages + { + TestContext context; + + public MyTransportMessageMutator(TestContext context) + { + this.context = context; + } + + public void MutateIncoming(TransportMessage transportMessage) + { + if (transportMessage.MessageIntent == MessageIntentEnum.Subscribe) + { + var originatingEndpoint = transportMessage.Headers[Headers.OriginatingEndpoint]; + if (originatingEndpoint.Contains("Subscriber")) + { + context.Subscribed = true; + } + } + } + } + + public class CaptureStarted : IWantToRunWhenBusStartsAndStops + { + TestContext context; + + public CaptureStarted(TestContext context) + { + this.context = context; + } + + public void Start() + { + context.PublisherStarted = true; + } + + public void Stop() + { + } + } + + public class SubscriptionAuthorizer : IAuthorizeSubscriptions + { + TestContext context; + + public SubscriptionAuthorizer(TestContext context) + { + this.context = context; + } + + public bool AuthorizeSubscribe(string messageType, string clientEndpoint, IDictionary headers) + { + return true; + } + + public bool AuthorizeUnsubscribe(string messageType, string clientEndpoint, IDictionary headers) + { + context.DeclinedUnSubscribe = true; + return false; + } + } + } + + public class Subscriber : EndpointConfigurationBuilder + { + public Subscriber() + { + EndpointSetup(c => + { + InMemoryPersistence.UseAsDefault(); + Configure.Features.Disable(); + }) + .AddMapping(typeof(Publisher)); + } + + public class MyEventHandler : IHandleMessages + { + IBus bus; + TestContext context; + + public MyEventHandler(IBus bus, TestContext context) + { + this.bus = bus; + this.context = context; + } + + public void Handle(MyEvent message) + { + context.SubscriberEventCount++; + bus.Unsubscribe(); + } + } + } + + public class MyEvent : IEvent + { + } + } +} \ No newline at end of file diff --git a/src/NServiceBus.AcceptanceTests/NServiceBus.AcceptanceTests.csproj b/src/NServiceBus.AcceptanceTests/NServiceBus.AcceptanceTests.csproj index f8460315d4c..ae74570100c 100644 --- a/src/NServiceBus.AcceptanceTests/NServiceBus.AcceptanceTests.csproj +++ b/src/NServiceBus.AcceptanceTests/NServiceBus.AcceptanceTests.csproj @@ -108,6 +108,8 @@ + + diff --git a/src/NServiceBus.Core.Tests/NServiceBus.Core.Tests.csproj b/src/NServiceBus.Core.Tests/NServiceBus.Core.Tests.csproj index ec04f110887..638ab4cb171 100644 --- a/src/NServiceBus.Core.Tests/NServiceBus.Core.Tests.csproj +++ b/src/NServiceBus.Core.Tests/NServiceBus.Core.Tests.csproj @@ -136,6 +136,7 @@ + diff --git a/src/NServiceBus.Core.Tests/UnicastBusAuthorizationTests.cs b/src/NServiceBus.Core.Tests/UnicastBusAuthorizationTests.cs new file mode 100644 index 00000000000..031d4933cd2 --- /dev/null +++ b/src/NServiceBus.Core.Tests/UnicastBusAuthorizationTests.cs @@ -0,0 +1,66 @@ +namespace NServiceBus.Unicast.Tests +{ + using System; + using System.Collections.Generic; + using NServiceBus.Unicast.Config; + using NUnit.Framework; + + [TestFixture] + public class UnicastBusAuthorizationTests + { + + [Test] + public void Should_use_noop_for_no_authorizer() + { + var authorizationType = ConfigUnicastBus.FindAuthorizationType(new List()); + Assert.AreEqual("NoopSubscriptionAuthorizer", authorizationType.Name); + } + + [Test] + public void Should_use_single_for_one_authorizer() + { + var authorizationType = ConfigUnicastBus.FindAuthorizationType(new List() + { + typeof(Authorizer1) + }); + Assert.AreEqual(typeof(Authorizer1), authorizationType); + } + + [Test] + public void Should_throw_for_multiple_authorizer() + { + var exception = Assert.Throws(() => ConfigUnicastBus.FindAuthorizationType(new List() + { + typeof(Authorizer1), + typeof(Authorizer2) + })); + Assert.AreEqual("Only one instance of IAuthorizeSubscriptions is allowed. Found the following: 'NServiceBus.Unicast.Tests.UnicastBusAuthorizationTests+Authorizer1', 'NServiceBus.Unicast.Tests.UnicastBusAuthorizationTests+Authorizer2'.", exception.Message); + } + + class Authorizer1 : IAuthorizeSubscriptions + { + public bool AuthorizeSubscribe(string messageType, string clientEndpoint, IDictionary headers) + { + throw new NotImplementedException(); + } + + public bool AuthorizeUnsubscribe(string messageType, string clientEndpoint, IDictionary headers) + { + throw new NotImplementedException(); + } + } + + class Authorizer2 : IAuthorizeSubscriptions + { + public bool AuthorizeSubscribe(string messageType, string clientEndpoint, IDictionary headers) + { + throw new NotImplementedException(); + } + + public bool AuthorizeUnsubscribe(string messageType, string clientEndpoint, IDictionary headers) + { + throw new NotImplementedException(); + } + } + } +} \ No newline at end of file diff --git a/src/NServiceBus.Core/Unicast/Config/ConfigUnicastBus.cs b/src/NServiceBus.Core/Unicast/Config/ConfigUnicastBus.cs index 2a9c930b8b5..2b9874cb5c0 100644 --- a/src/NServiceBus.Core/Unicast/Config/ConfigUnicastBus.cs +++ b/src/NServiceBus.Core/Unicast/Config/ConfigUnicastBus.cs @@ -8,6 +8,7 @@ namespace NServiceBus.Unicast.Config using Features; using Logging; using Messages; + using NServiceBus.Unicast.Subscriptions.MessageDrivenSubscriptions; using ObjectBuilder; using Pipeline; using Pipeline.Contexts; @@ -58,12 +59,31 @@ void RegisterMessageModules() void ConfigureSubscriptionAuthorization() { - var authType = TypesToScan.FirstOrDefault(t => typeof(IAuthorizeSubscriptions).IsAssignableFrom(t) && !t.IsInterface); + var typeToUse = FindAuthorizationType(TypesToScan); + Configurer.ConfigureComponent(typeToUse, DependencyLifecycle.SingleInstance); + } + - if (authType != null) - Configurer.ConfigureComponent(authType, DependencyLifecycle.SingleInstance); + internal static Type FindAuthorizationType(IEnumerable availableTypes) + { + var authType = typeof(IAuthorizeSubscriptions); + var noopType = typeof(NoopSubscriptionAuthorizer); + var foundAuthTypes = availableTypes + .Where(t => t != noopType && authType.IsAssignableFrom(t) && !t.IsInterface && !t.IsAbstract) + .ToList(); + if (foundAuthTypes.Count > 1) + { + var fullNames = foundAuthTypes.Select(type => type.FullName); + var error = string.Format("Only one instance of IAuthorizeSubscriptions is allowed. Found the following: '{0}'.", string.Join("', '", fullNames)); + throw new Exception(error); + } + if (foundAuthTypes.Count == 0) + { + return noopType; + } + return foundAuthTypes.Single(); } - + /// /// Used to configure the bus. ///