Skip to content

Commit

Permalink
Fix #5: Add missing ability to subscribe for specific session's events
Browse files Browse the repository at this point in the history
  • Loading branch information
seclerp committed May 28, 2024
1 parent 9f3cdd6 commit 24cad09
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 17 deletions.
4 changes: 2 additions & 2 deletions src/ChromeProtocol.Runtime/Messaging/IProtocolClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ public interface IProtocolClient : IDisposable
void ListenEvent<TEvent>(AsyncDomainEventHandler<TEvent> handler)
where TEvent : IEvent;

IDisposable SubscribeAsync<TEvent>(AsyncDomainEventHandler<TEvent> handler)
IDisposable SubscribeAsync<TEvent>(AsyncDomainEventHandler<TEvent> handler, string? sessionId = default)
where TEvent : IEvent;

IDisposable SubscribeSync<TEvent>(SyncDomainEventHandler<TEvent> handler)
IDisposable SubscribeSync<TEvent>(SyncDomainEventHandler<TEvent> handler, string? sessionId = default)
where TEvent : IEvent;

Task<TResponse> SendCommandAsync<TResponse>(ICommand<TResponse> command, string? sessionId = default, CancellationToken? token = default)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ public void ListenEvent<TEvent>(AsyncDomainEventHandler<TEvent> handler) where T
SubscribeAsync(handler);

public IDisposable SubscribeAsync<TEvent>(AsyncDomainEventHandler<TEvent> handler) where TEvent : IEvent =>
_mainClient.SubscribeAsync(handler);
_mainClient.SubscribeAsync(handler, SessionId);

public IDisposable SubscribeSync<TEvent>(SyncDomainEventHandler<TEvent> handler) where TEvent : IEvent =>
_mainClient.SubscribeSync(handler);
_mainClient.SubscribeSync(handler, SessionId);

public Task<TResponse> SendCommandAsync<TResponse>(ICommand<TResponse> command, CancellationToken? token = default)
where TResponse : IType => _mainClient.SendCommandAsync(command, SessionId, token);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public class WebSocketProtocolClient<TNativeClient> : IProtocolClient
{
private readonly Uri _wsUri;
private readonly ILogger _logger;
private readonly ConcurrentDictionary<string, Func<ProtocolEvent<JObject>, Task>> _eventHandlers = new ();
private readonly ConcurrentDictionary<(string? sessionId, string eventName), Func<ProtocolEvent<JObject>, Task>> _eventHandlers = new ();
private readonly ConcurrentDictionary<int, TaskCompletionSource<JObject>> _responseResolvers = new ();
private readonly TNativeClient _nativeClient;
private CancellationTokenSource _connectionCancellation;
Expand Down Expand Up @@ -67,7 +67,7 @@ public void ListenEvent<TEvent>(AsyncDomainEventHandler<TEvent> handler) where T
SubscribeAsync(handler);
}

public IDisposable SubscribeAsync<TEvent>(AsyncDomainEventHandler<TEvent> handler) where TEvent : IEvent
public IDisposable SubscribeAsync<TEvent>(AsyncDomainEventHandler<TEvent> handler, string? sessionId = default) where TEvent : IEvent
{
Func<ProtocolEvent<JObject>, Task> HandleProtocolEvent(AsyncDomainEventHandler<TEvent> eventHandler) =>
async rawEvent =>
Expand All @@ -76,10 +76,10 @@ Func<ProtocolEvent<JObject>, Task> HandleProtocolEvent(AsyncDomainEventHandler<T
await eventHandler(eventItself).ConfigureAwait(false);
};

return SubscribeInternal<TEvent>(HandleProtocolEvent(handler));
return SubscribeInternal<TEvent>(HandleProtocolEvent(handler), sessionId);
}

public IDisposable SubscribeSync<TEvent>(SyncDomainEventHandler<TEvent> handler) where TEvent : IEvent
public IDisposable SubscribeSync<TEvent>(SyncDomainEventHandler<TEvent> handler, string? sessionId = default) where TEvent : IEvent
{
Func<ProtocolEvent<JObject>, Task> HandleProtocolEvent(SyncDomainEventHandler<TEvent> eventHandler) =>
rawEvent =>
Expand All @@ -88,7 +88,7 @@ Func<ProtocolEvent<JObject>, Task> HandleProtocolEvent(SyncDomainEventHandler<TE
return Task.Run(() => eventHandler(eventItself));
};

return SubscribeInternal<TEvent>(HandleProtocolEvent(handler));
return SubscribeInternal<TEvent>(HandleProtocolEvent(handler), sessionId);
}

public async Task<TResponse> SendCommandAsync<TResponse>(ICommand<TResponse> command,
Expand Down Expand Up @@ -140,11 +140,11 @@ private async Task FireInternalAsync(int id, string methodName, ICommand command
}
}

private IDisposable SubscribeInternal<TEvent>(Func<ProtocolEvent<JObject>, Task> rawHandler) where TEvent : IEvent
private IDisposable SubscribeInternal<TEvent>(Func<ProtocolEvent<JObject>, Task> rawHandler, string? sessionId) where TEvent : IEvent
{
var eventName = GetMethodName(typeof(TEvent));
var subscription = new ProtocolSubscription<TNativeClient>(eventName, rawHandler, this);
_eventHandlers.AddOrUpdate(GetMethodName(typeof(TEvent)), rawHandler, (_, existing) => existing + rawHandler);
var subscription = new ProtocolSubscription<TNativeClient>(sessionId, eventName, rawHandler, this);
_eventHandlers.AddOrUpdate((sessionId, eventName), rawHandler, (_, existing) => existing + rawHandler);
return subscription;
}

Expand Down Expand Up @@ -232,7 +232,7 @@ private Task ProcessIncoming(string message) =>
private async Task ProcessIncomingEvent(ProtocolEvent<JObject> @event)
{
OnEventReceived?.Invoke(this, @event);
if (_eventHandlers.TryGetValue(@event.Method, out var handler))
if (_eventHandlers.TryGetValue((@event.SessionId, @event.Method), out var handler))
await handler.Invoke(@event).ConfigureAwait(false);
}

Expand Down Expand Up @@ -262,29 +262,32 @@ private static string GetMethodName(MemberInfo type) =>

private class ProtocolSubscription<T> : IDisposable where T : WebSocket
{
private readonly string? _sessionId;
private readonly string _eventName;
private readonly Func<ProtocolEvent<JObject>, Task>? _wrappedHandler;
private readonly WebSocketProtocolClient<T> _client;

public ProtocolSubscription(string eventName, Func<ProtocolEvent<JObject>,Task>? wrappedHandler, WebSocketProtocolClient<T> client)
public ProtocolSubscription(string? sessionId, string eventName, Func<ProtocolEvent<JObject>,Task>? wrappedHandler, WebSocketProtocolClient<T> client)
{
_sessionId = sessionId;
_eventName = eventName;
_wrappedHandler = wrappedHandler;
_client = client;
}

public void Dispose()
{
if (_client._eventHandlers.TryGetValue(_eventName, out var aggregatedHandlers))
var eventKey = (_sessionId, _eventName);
if (_client._eventHandlers.TryGetValue(eventKey, out var aggregatedHandlers))
{
var updatedHandlers = aggregatedHandlers - _wrappedHandler;
if (updatedHandlers is null)
{
_client._eventHandlers.TryRemove(_eventName, out _);
_client._eventHandlers.TryRemove(eventKey, out _);
}
else
{
_client._eventHandlers[_eventName] = updatedHandlers;
_client._eventHandlers[eventKey] = updatedHandlers;
}
}
}
Expand Down

0 comments on commit 24cad09

Please sign in to comment.