From b52e265e767d4d32eb5d2bb1bbc34d3ee6daeea0 Mon Sep 17 00:00:00 2001 From: NycroV <83246959+NycroV@users.noreply.github.com> Date: Tue, 13 Aug 2024 08:06:39 -0400 Subject: [PATCH] Prepare for incoming DSharpPlus IShardOrchestrator changes --- .../DSharpPlusUtilities.cs | 35 +++++++++++++++++++ .../DiscordClientWrapper.cs | 30 ++-------------- 2 files changed, 38 insertions(+), 27 deletions(-) create mode 100644 src/Lavalink4NET.DSharpPlus.Nightly/DSharpPlusUtilities.cs diff --git a/src/Lavalink4NET.DSharpPlus.Nightly/DSharpPlusUtilities.cs b/src/Lavalink4NET.DSharpPlus.Nightly/DSharpPlusUtilities.cs new file mode 100644 index 00000000..afbfea71 --- /dev/null +++ b/src/Lavalink4NET.DSharpPlus.Nightly/DSharpPlusUtilities.cs @@ -0,0 +1,35 @@ +namespace Lavalink4NET.DSharpPlus; + +using System.Reflection; +using global::DSharpPlus; +using global::DSharpPlus.Clients; + +/// +/// An utility for getting internal / private fields from DSharpPlus WebSocket Gateway Payloads. +/// +public static partial class DSharpPlusUtilities +{ + /// + /// The internal "orchestrator" property info in . + /// + private static readonly FieldInfo orchestratorField = + typeof(DiscordClient).GetField("orchestrator", BindingFlags.NonPublic | BindingFlags.Instance)!; + + /// + /// Gets the amount of shards handled by this client's orchestrator. + /// + public static int GetConnectedShardCount(this DiscordClient client) + { + var orchestrator = (IShardOrchestrator)orchestratorField.GetValue(client)!; + return orchestrator.ConnectedShardCount; + } + + /// + /// Gets the total amount of shards connected to this bot. + /// + public static int GetTotalShardCount(this DiscordClient client) + { + var orchestrator = (IShardOrchestrator)orchestratorField.GetValue(client)!; + return orchestrator.TotalShardCount; + } +} diff --git a/src/Lavalink4NET.DSharpPlus.Nightly/DiscordClientWrapper.cs b/src/Lavalink4NET.DSharpPlus.Nightly/DiscordClientWrapper.cs index 23b2051d..dd095220 100644 --- a/src/Lavalink4NET.DSharpPlus.Nightly/DiscordClientWrapper.cs +++ b/src/Lavalink4NET.DSharpPlus.Nightly/DiscordClientWrapper.cs @@ -31,12 +31,6 @@ public sealed class DiscordClientWrapper : IDiscordClientWrapper private readonly ILogger _logger; private readonly TaskCompletionSource _readyTaskCompletionSource; - /// - /// Re-assign this delegate in the client configuration to change the way the connected shard count is retrived.
- /// You will only need to do this if you are using a custom IShardOrchestrator for your client. - ///
- public Func> GetShardCount { get; set; } - /// /// Creates a new instance of . /// @@ -50,25 +44,6 @@ public DiscordClientWrapper(DiscordClient discordClient, ILogger(TaskCreationOptions.RunContinuationsAsynchronously); - - FieldInfo orchestratorField = typeof(DiscordClient).GetField("orchestrator", BindingFlags.NonPublic | BindingFlags.Instance)!; - var orchestrator = (IShardOrchestrator)orchestratorField.GetValue(discordClient)!; - - if (orchestrator is SingleShardOrchestrator) - GetShardCount = () => Task.FromResult(1); - - else if (orchestrator is MultiShardOrchestrator multiShardOrchestrator) - { - FieldInfo shardCountField = typeof(MultiShardOrchestrator).GetField("shardCount", BindingFlags.NonPublic | BindingFlags.Instance)!; - GetShardCount = () => Task.Run(() => (int)(uint)shardCountField.GetValue(multiShardOrchestrator)!); - } - - else - { - GetShardCount = () => Task.Run(async () => (await discordClient.GetGatewayInfoAsync()).ShardCount); - _logger.LogInformation("The DiscordClient is configured to use a non-default Shard Orchestrator - " + - "make sure that this wrapper's GetShardCount property is configured to properly retrieve the shard count"); - } } /// @@ -158,7 +133,7 @@ public ValueTask WaitForReadyAsync(CancellationToken cancella return new(_readyTaskCompletionSource.Task.WaitAsync(cancellationToken)); } - internal async Task OnGuildDownloadCompleted(DiscordClient discordClient, GuildDownloadCompletedEventArgs eventArgs) + internal Task OnGuildDownloadCompleted(DiscordClient discordClient, GuildDownloadCompletedEventArgs eventArgs) { ArgumentNullException.ThrowIfNull(discordClient); ArgumentNullException.ThrowIfNull(eventArgs); @@ -166,9 +141,10 @@ internal async Task OnGuildDownloadCompleted(DiscordClient discordClient, GuildD var clientInformation = new ClientInformation( Label: "DSharpPlus", CurrentUserId: discordClient.CurrentUser.Id, - ShardCount: await GetShardCount()); + ShardCount: discordClient.GetConnectedShardCount()); _readyTaskCompletionSource.TrySetResult(clientInformation); + return Task.CompletedTask; } internal async Task OnVoiceServerUpdated(DiscordClient discordClient, VoiceServerUpdatedEventArgs voiceServerUpdateEventArgs)