diff --git a/src/ChromeProtocol.Runtime/Messaging/IProtocolClient.cs b/src/ChromeProtocol.Runtime/Messaging/IProtocolClient.cs index 6a9884a..4957d5d 100644 --- a/src/ChromeProtocol.Runtime/Messaging/IProtocolClient.cs +++ b/src/ChromeProtocol.Runtime/Messaging/IProtocolClient.cs @@ -20,10 +20,10 @@ public interface IProtocolClient : IDisposable void ListenEvent(AsyncDomainEventHandler handler) where TEvent : IEvent; - IDisposable SubscribeAsync(AsyncDomainEventHandler handler) + IDisposable SubscribeAsync(AsyncDomainEventHandler handler, string? sessionId = default) where TEvent : IEvent; - IDisposable SubscribeSync(SyncDomainEventHandler handler) + IDisposable SubscribeSync(SyncDomainEventHandler handler, string? sessionId = default) where TEvent : IEvent; Task SendCommandAsync(ICommand command, string? sessionId = default, CancellationToken? token = default) diff --git a/src/ChromeProtocol.Runtime/Messaging/WebSockets/ScopedProtocolClient.cs b/src/ChromeProtocol.Runtime/Messaging/WebSockets/ScopedProtocolClient.cs index eaac0a5..f4ba56c 100644 --- a/src/ChromeProtocol.Runtime/Messaging/WebSockets/ScopedProtocolClient.cs +++ b/src/ChromeProtocol.Runtime/Messaging/WebSockets/ScopedProtocolClient.cs @@ -18,10 +18,10 @@ public void ListenEvent(AsyncDomainEventHandler handler) where T SubscribeAsync(handler); public IDisposable SubscribeAsync(AsyncDomainEventHandler handler) where TEvent : IEvent => - _mainClient.SubscribeAsync(handler); + _mainClient.SubscribeAsync(handler, SessionId); public IDisposable SubscribeSync(SyncDomainEventHandler handler) where TEvent : IEvent => - _mainClient.SubscribeSync(handler); + _mainClient.SubscribeSync(handler, SessionId); public Task SendCommandAsync(ICommand command, CancellationToken? token = default) where TResponse : IType => _mainClient.SendCommandAsync(command, SessionId, token); diff --git a/src/ChromeProtocol.Runtime/Messaging/WebSockets/WebSocketProtocolClient.cs b/src/ChromeProtocol.Runtime/Messaging/WebSockets/WebSocketProtocolClient.cs index a9a827d..fea0378 100644 --- a/src/ChromeProtocol.Runtime/Messaging/WebSockets/WebSocketProtocolClient.cs +++ b/src/ChromeProtocol.Runtime/Messaging/WebSockets/WebSocketProtocolClient.cs @@ -15,7 +15,7 @@ public class WebSocketProtocolClient : IProtocolClient { private readonly Uri _wsUri; private readonly ILogger _logger; - private readonly ConcurrentDictionary, Task>> _eventHandlers = new (); + private readonly ConcurrentDictionary<(string? sessionId, string eventName), Func, Task>> _eventHandlers = new (); private readonly ConcurrentDictionary> _responseResolvers = new (); private readonly TNativeClient _nativeClient; private CancellationTokenSource _connectionCancellation; @@ -67,7 +67,7 @@ public void ListenEvent(AsyncDomainEventHandler handler) where T SubscribeAsync(handler); } - public IDisposable SubscribeAsync(AsyncDomainEventHandler handler) where TEvent : IEvent + public IDisposable SubscribeAsync(AsyncDomainEventHandler handler, string? sessionId = default) where TEvent : IEvent { Func, Task> HandleProtocolEvent(AsyncDomainEventHandler eventHandler) => async rawEvent => @@ -76,10 +76,10 @@ Func, Task> HandleProtocolEvent(AsyncDomainEventHandler(HandleProtocolEvent(handler)); + return SubscribeInternal(HandleProtocolEvent(handler), sessionId); } - public IDisposable SubscribeSync(SyncDomainEventHandler handler) where TEvent : IEvent + public IDisposable SubscribeSync(SyncDomainEventHandler handler, string? sessionId = default) where TEvent : IEvent { Func, Task> HandleProtocolEvent(SyncDomainEventHandler eventHandler) => rawEvent => @@ -88,7 +88,7 @@ Func, Task> HandleProtocolEvent(SyncDomainEventHandler eventHandler(eventItself)); }; - return SubscribeInternal(HandleProtocolEvent(handler)); + return SubscribeInternal(HandleProtocolEvent(handler), sessionId); } public async Task SendCommandAsync(ICommand command, @@ -140,11 +140,11 @@ private async Task FireInternalAsync(int id, string methodName, ICommand command } } - private IDisposable SubscribeInternal(Func, Task> rawHandler) where TEvent : IEvent + private IDisposable SubscribeInternal(Func, Task> rawHandler, string? sessionId) where TEvent : IEvent { var eventName = GetMethodName(typeof(TEvent)); - var subscription = new ProtocolSubscription(eventName, rawHandler, this); - _eventHandlers.AddOrUpdate(GetMethodName(typeof(TEvent)), rawHandler, (_, existing) => existing + rawHandler); + var subscription = new ProtocolSubscription(sessionId, eventName, rawHandler, this); + _eventHandlers.AddOrUpdate((sessionId, eventName), rawHandler, (_, existing) => existing + rawHandler); return subscription; } @@ -232,7 +232,7 @@ private Task ProcessIncoming(string message) => private async Task ProcessIncomingEvent(ProtocolEvent @event) { OnEventReceived?.Invoke(this, @event); - if (_eventHandlers.TryGetValue(@event.Method, out var handler)) + if (_eventHandlers.TryGetValue((@event.SessionId, @event.Method), out var handler)) await handler.Invoke(@event).ConfigureAwait(false); } @@ -262,12 +262,14 @@ private static string GetMethodName(MemberInfo type) => private class ProtocolSubscription : IDisposable where T : WebSocket { + private readonly string? _sessionId; private readonly string _eventName; private readonly Func, Task>? _wrappedHandler; private readonly WebSocketProtocolClient _client; - public ProtocolSubscription(string eventName, Func,Task>? wrappedHandler, WebSocketProtocolClient client) + public ProtocolSubscription(string? sessionId, string eventName, Func,Task>? wrappedHandler, WebSocketProtocolClient client) { + _sessionId = sessionId; _eventName = eventName; _wrappedHandler = wrappedHandler; _client = client; @@ -275,16 +277,17 @@ public ProtocolSubscription(string eventName, Func,Task>? public void Dispose() { - if (_client._eventHandlers.TryGetValue(_eventName, out var aggregatedHandlers)) + var eventKey = (_sessionId, _eventName); + if (_client._eventHandlers.TryGetValue(eventKey, out var aggregatedHandlers)) { var updatedHandlers = aggregatedHandlers - _wrappedHandler; if (updatedHandlers is null) { - _client._eventHandlers.TryRemove(_eventName, out _); + _client._eventHandlers.TryRemove(eventKey, out _); } else { - _client._eventHandlers[_eventName] = updatedHandlers; + _client._eventHandlers[eventKey] = updatedHandlers; } } }