diff --git a/src/Lavalink4NET.DSharpPlus.Nightly/DSharpPlusUtilities.cs b/src/Lavalink4NET.DSharpPlus.Nightly/DSharpPlusUtilities.cs deleted file mode 100644 index d4085169..00000000 --- a/src/Lavalink4NET.DSharpPlus.Nightly/DSharpPlusUtilities.cs +++ /dev/null @@ -1,93 +0,0 @@ -namespace Lavalink4NET.DSharpPlus; - -using System; -using System.Collections.Concurrent; -using System.Reflection; -using global::DSharpPlus; -using global::DSharpPlus.Clients; -using global::DSharpPlus.AsyncEvents; -using System.Threading.Tasks; - -/// -/// An utility for getting internal / private fields from DSharpPlus WebSocket Gateway Payloads. -/// -public static partial class DSharpPlusUtilities -{ - /// - /// The internal "events" property info in . - /// - // https://github.com/DSharpPlus/DSharpPlus/blob/master/DSharpPlus/Clients/DiscordClient.cs#L37 - private static readonly FieldInfo eventsField = - typeof(DiscordClient).GetField("events", BindingFlags.NonPublic | BindingFlags.Instance)!; - - /// - /// Gets the internal "events" property value of the specified . - /// - /// the instance - /// the "events" value - public static ConcurrentDictionary GetEvents(this DiscordClient client) - => (ConcurrentDictionary)eventsField.GetValue(client)!; - - /// - /// The internal "errorHandler" property info in . - /// - // https://github.com/DSharpPlus/DSharpPlus/blob/master/DSharpPlus/Clients/DiscordClient.cs#L41 - private static readonly FieldInfo errorHandlerField = - typeof(DiscordClient).GetField("errorHandler", BindingFlags.NonPublic | BindingFlags.Instance)!; - - /// - /// Gets the internal "errorHandler" property value of the specified . - /// - /// the instance - /// the "errorHandler" value - public static IClientErrorHandler GetErrorHandler(this DiscordClient client) - => (IClientErrorHandler)errorHandlerField.GetValue(client)!; - - /// - /// The internal "Register" method info in . - /// - // https://github.com/DSharpPlus/DSharpPlus/blob/master/DSharpPlus/AsyncEvents/AsyncEvent.cs#L14 - private static readonly MethodInfo asyncEventRegisterMethod = - typeof(AsyncEvent).GetMethod("Register", BindingFlags.NonPublic | BindingFlags.Instance, [typeof(Delegate)])!; - - /// - /// Calls the internal "Register" method of the spedificed - /// - /// the instance - /// the event to register - public static void Register(this AsyncEvent asyncEvent, Delegate @delegate) => asyncEventRegisterMethod.Invoke(asyncEvent, [@delegate]); - - /// - /// The internal "orchestrator" property info in . - /// - // https://github.com/DSharpPlus/DSharpPlus/blob/master/DSharpPlus/Clients/DiscordClient.cs#L47 - private static readonly FieldInfo orchestratorField = - typeof(DiscordClient).GetField("orchestrator", BindingFlags.NonPublic | BindingFlags.Instance)!; - - /// - /// The internal "shardCount" property info in . - /// - // https://github.com/DSharpPlus/DSharpPlus/blob/master/DSharpPlus/Clients/DiscordClient.cs#L47 - private static readonly FieldInfo shardCountField = - typeof(MultiShardOrchestrator).GetField("shardCount", BindingFlags.NonPublic | BindingFlags.Instance)!; - - /// - /// Gets the number of connected shards or this client - /// - public static async ValueTask GetShardCountAsync(this DiscordClient client) - { - var orchestrator = (IShardOrchestrator)orchestratorField.GetValue(client)!; - - if (orchestrator is SingleShardOrchestrator) - return 1; - - if (orchestrator is MultiShardOrchestrator multiShardOrchestrator) - return (int)(uint)shardCountField.GetValue(multiShardOrchestrator)!; - - // If the orchestrator is neither a Single nor Multi sharded orchestrator, that means it - // is using a custom solution implemented by the end user. There is no way to directly access - // the shard count in this case, so instead we estimate it by using Discord's recommended - // shard amount. - return (await client.GetGatewayInfoAsync()).ShardCount; - } -} diff --git a/src/Lavalink4NET.DSharpPlus.Nightly/DiscordClientWrapper.cs b/src/Lavalink4NET.DSharpPlus.Nightly/DiscordClientWrapper.cs index 767f8393..60a9fb7d 100644 --- a/src/Lavalink4NET.DSharpPlus.Nightly/DiscordClientWrapper.cs +++ b/src/Lavalink4NET.DSharpPlus.Nightly/DiscordClientWrapper.cs @@ -5,16 +5,15 @@ namespace Lavalink4NET.DSharpPlus; using System.Threading; using System.Threading.Tasks; using global::DSharpPlus; -using global::DSharpPlus.AsyncEvents; +using global::DSharpPlus.Clients; using global::DSharpPlus.Entities; using global::DSharpPlus.EventArgs; using global::DSharpPlus.Exceptions; using global::DSharpPlus.Net.Abstractions; using Lavalink4NET.Clients; -using L4N = Clients.Events; +using Lavalink4NET = Clients.Events; using Lavalink4NET.Events; using Microsoft.Extensions.Logging; -using System.Collections.Concurrent; /// /// Wraps a instance. @@ -22,12 +21,13 @@ namespace Lavalink4NET.DSharpPlus; public sealed class DiscordClientWrapper : IDiscordClientWrapper { /// - public event AsyncEventHandler? VoiceServerUpdated; + public event AsyncEventHandler? VoiceServerUpdated; /// - public event AsyncEventHandler? VoiceStateUpdated; + public event AsyncEventHandler? VoiceStateUpdated; - private readonly DiscordClient _client; // sharded clients are now also managed by the same DiscordClient type + private readonly DiscordClient _client; + private readonly IShardOrchestrator _shardOrchestrator; private readonly ILogger _logger; private readonly TaskCompletionSource _readyTaskCompletionSource; @@ -35,35 +35,18 @@ public sealed class DiscordClientWrapper : IDiscordClientWrapper /// Creates a new instance of . /// /// The Discord Client to wrap. - /// a logger associated with this wrapper. - public DiscordClientWrapper(DiscordClient discordClient, ILogger logger) + /// The Discord shard orchestrator associated with this client. + /// A logger associated with this wrapper. + public DiscordClientWrapper(DiscordClient discordClient, IShardOrchestrator shardOrchestrator, ILogger logger) { ArgumentNullException.ThrowIfNull(discordClient); + ArgumentNullException.ThrowIfNull(shardOrchestrator); ArgumentNullException.ThrowIfNull(logger); _client = discordClient; + _shardOrchestrator = shardOrchestrator; _logger = logger; - _readyTaskCompletionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - void AddEventHandler(Type eventArgsType, Delegate eventHandler) - { - IClientErrorHandler errorHandler = discordClient.GetErrorHandler(); - ConcurrentDictionary events = discordClient.GetEvents(); - - Type asyncEventType = typeof(AsyncEvent<,>).MakeGenericType(discordClient.GetType(), eventArgsType); - AsyncEvent asyncEvent = events.GetOrAdd(eventArgsType, _ => (AsyncEvent)Activator.CreateInstance - ( - type: asyncEventType, - args: [errorHandler] - )!); - - asyncEvent.Register(eventHandler); - } - - AddEventHandler(typeof(VoiceStateUpdatedEventArgs), new AsyncEventHandler(OnVoiceStateUpdated)); - AddEventHandler(typeof(VoiceServerUpdatedEventArgs), new AsyncEventHandler(OnVoiceServerUpdated)); - AddEventHandler(typeof(GuildDownloadCompletedEventArgs), new AsyncEventHandler(OnGuildDownloadCompleted)); } /// @@ -88,6 +71,7 @@ public async ValueTask> GetChannelUsersAsync( return ImmutableArray.Empty; } } + catch (DiscordException exception) { _logger.LogWarning( @@ -152,7 +136,7 @@ public ValueTask WaitForReadyAsync(CancellationToken cancella return new(_readyTaskCompletionSource.Task.WaitAsync(cancellationToken)); } - private async Task OnGuildDownloadCompleted(DiscordClient discordClient, GuildDownloadCompletedEventArgs eventArgs) + internal Task OnGuildDownloadCompleted(DiscordClient discordClient, GuildDownloadCompletedEventArgs eventArgs) { ArgumentNullException.ThrowIfNull(discordClient); ArgumentNullException.ThrowIfNull(eventArgs); @@ -160,12 +144,13 @@ private async Task OnGuildDownloadCompleted(DiscordClient discordClient, GuildDo var clientInformation = new ClientInformation( Label: "DSharpPlus", CurrentUserId: discordClient.CurrentUser.Id, - ShardCount: await discordClient.GetShardCountAsync()); + ShardCount: _shardOrchestrator.ConnectedShardCount); _readyTaskCompletionSource.TrySetResult(clientInformation); + return Task.CompletedTask; } - private async Task OnVoiceServerUpdated(DiscordClient discordClient, VoiceServerUpdatedEventArgs voiceServerUpdateEventArgs) + internal async Task OnVoiceServerUpdated(DiscordClient discordClient, VoiceServerUpdatedEventArgs voiceServerUpdateEventArgs) { ArgumentNullException.ThrowIfNull(discordClient); ArgumentNullException.ThrowIfNull(voiceServerUpdateEventArgs); @@ -174,7 +159,7 @@ private async Task OnVoiceServerUpdated(DiscordClient discordClient, VoiceServer Token: voiceServerUpdateEventArgs.VoiceToken, Endpoint: voiceServerUpdateEventArgs.Endpoint); - var eventArgs = new L4N.VoiceServerUpdatedEventArgs( + var eventArgs = new Lavalink4NET.VoiceServerUpdatedEventArgs( guildId: voiceServerUpdateEventArgs.Guild.Id, voiceServer: server); @@ -183,7 +168,7 @@ await VoiceServerUpdated .ConfigureAwait(false); } - private async Task OnVoiceStateUpdated(DiscordClient discordClient, VoiceStateUpdatedEventArgs voiceStateUpdateEventArgs) + internal async Task OnVoiceStateUpdated(DiscordClient discordClient, VoiceStateUpdatedEventArgs voiceStateUpdateEventArgs) { ArgumentNullException.ThrowIfNull(discordClient); ArgumentNullException.ThrowIfNull(voiceStateUpdateEventArgs); @@ -202,7 +187,7 @@ private async Task OnVoiceStateUpdated(DiscordClient discordClient, VoiceStateUp SessionId: sessionId); // invoke event - var eventArgs = new L4N.VoiceStateUpdatedEventArgs( + var eventArgs = new Lavalink4NET.VoiceStateUpdatedEventArgs( guildId: voiceStateUpdateEventArgs.Guild.Id, userId: voiceStateUpdateEventArgs.User.Id, isCurrentUser: voiceStateUpdateEventArgs.User.Id == discordClient.CurrentUser.Id, diff --git a/src/Lavalink4NET.DSharpPlus.Nightly/Lavalink4NET.DSharpPlus.Nightly.csproj b/src/Lavalink4NET.DSharpPlus.Nightly/Lavalink4NET.DSharpPlus.Nightly.csproj index 80045932..4c213398 100644 --- a/src/Lavalink4NET.DSharpPlus.Nightly/Lavalink4NET.DSharpPlus.Nightly.csproj +++ b/src/Lavalink4NET.DSharpPlus.Nightly/Lavalink4NET.DSharpPlus.Nightly.csproj @@ -9,9 +9,10 @@ true False + 1.0.4 - + diff --git a/src/Lavalink4NET.DSharpPlus.Nightly/Lavalink4NETInvokeHandlers.cs b/src/Lavalink4NET.DSharpPlus.Nightly/Lavalink4NETInvokeHandlers.cs new file mode 100644 index 00000000..c1c24838 --- /dev/null +++ b/src/Lavalink4NET.DSharpPlus.Nightly/Lavalink4NETInvokeHandlers.cs @@ -0,0 +1,26 @@ +namespace Lavalink4NET.DSharpPlus; + +using System.Threading.Tasks; +using global::DSharpPlus; +using global::DSharpPlus.EventArgs; +using Lavalink4NET.Clients; + +/// +/// Forwards event triggers to the Lavalink4NET client wrapper +/// +internal sealed class Lavalink4NETInvokeHandlers(IDiscordClientWrapper wrapper) : + IEventHandler, + IEventHandler, + IEventHandler +{ + private readonly DiscordClientWrapper wrapper = (DiscordClientWrapper)wrapper; + + public async Task HandleEventAsync(DiscordClient sender, GuildDownloadCompletedEventArgs eventArgs) + => await wrapper.OnGuildDownloadCompleted(sender, eventArgs); + + public async Task HandleEventAsync(DiscordClient sender, VoiceServerUpdatedEventArgs eventArgs) + => await wrapper.OnVoiceServerUpdated(sender, eventArgs); + + public async Task HandleEventAsync(DiscordClient sender, VoiceStateUpdatedEventArgs eventArgs) + => await wrapper.OnVoiceStateUpdated(sender, eventArgs); +} diff --git a/src/Lavalink4NET.DSharpPlus.Nightly/ServiceCollectionExtensions.cs b/src/Lavalink4NET.DSharpPlus.Nightly/ServiceCollectionExtensions.cs index f127068f..392827c9 100644 --- a/src/Lavalink4NET.DSharpPlus.Nightly/ServiceCollectionExtensions.cs +++ b/src/Lavalink4NET.DSharpPlus.Nightly/ServiceCollectionExtensions.cs @@ -1,6 +1,7 @@ namespace Lavalink4NET.Extensions; using System; +using global::DSharpPlus.Extensions; using Lavalink4NET.DSharpPlus; using Microsoft.Extensions.DependencyInjection; @@ -17,6 +18,10 @@ public static class ServiceCollectionExtensions public static IServiceCollection AddLavalink(this IServiceCollection services) { ArgumentNullException.ThrowIfNull(services); - return services.AddLavalink(); + + services.AddLavalink(); + services.ConfigureEventHandlers(events => events.AddEventHandlers(ServiceLifetime.Transient)); + + return services; } }