diff --git a/MinecraftClient/ChatBots/WebSocketBot.cs b/MinecraftClient/ChatBots/WebSocketBot.cs index b4e7df6435..cb19e0e803 100644 --- a/MinecraftClient/ChatBots/WebSocketBot.cs +++ b/MinecraftClient/ChatBots/WebSocketBot.cs @@ -41,9 +41,21 @@ public MessageReceivedEventArgs(string sessionId, string message) } } +internal class WebSocketSession +{ + public string SessionId { get; set; } + public WebSocket WebSocket { get; set; } + + public WebSocketSession(string sessionId, WebSocket webSocket) + { + SessionId = sessionId; + WebSocket = webSocket; + } +} + internal class WebSocketServer { - public readonly ConcurrentDictionary Sessions; + public readonly ConcurrentDictionary Sessions; public event EventHandler? NewSession; public event EventHandler? SessionDropped; public event EventHandler? MessageReceived; @@ -52,7 +64,7 @@ internal class WebSocketServer public WebSocketServer() { - Sessions = new ConcurrentDictionary(); + Sessions = new ConcurrentDictionary(); } public async Task Start(string ipAddress, int port) @@ -69,9 +81,11 @@ public async Task Start(string ipAddress, int port) var sessionGuid = Guid.NewGuid().ToString(); var webSocketContext = await context.AcceptWebSocketAsync(null); var webSocket = webSocketContext.WebSocket; - Sessions.TryAdd(sessionGuid, webSocket); + var webSocketSession = new WebSocketSession(sessionGuid, webSocket); + NewSession?.Invoke(this, new SessionEventArgs(sessionGuid)); - _ = ProcessWebSocketSession(sessionGuid, webSocket); + Sessions.TryAdd(sessionGuid, webSocketSession); + _ = ProcessWebSocketSession(webSocketSession); } else { @@ -85,7 +99,7 @@ public async Task Stop() { foreach (var session in Sessions) { - await session.Value.CloseAsync(WebSocketCloseStatus.NormalClosure, "Server shutting down", + await session.Value.WebSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Server shutting down", CancellationToken.None); } @@ -93,25 +107,28 @@ await session.Value.CloseAsync(WebSocketCloseStatus.NormalClosure, "Server shutt listener?.Stop(); } - private async Task ProcessWebSocketSession(string sessionId, WebSocket webSocket) + private async Task ProcessWebSocketSession(WebSocketSession webSocketSession) { var buffer = new byte[1024]; try { - while (webSocket.State == WebSocketState.Open) + while (webSocketSession.WebSocket.State == WebSocketState.Open) { var receiveResult = - await webSocket.ReceiveAsync(new ArraySegment(buffer), CancellationToken.None); + await webSocketSession.WebSocket.ReceiveAsync(new ArraySegment(buffer), + CancellationToken.None); if (receiveResult.MessageType == WebSocketMessageType.Text) { var message = Encoding.UTF8.GetString(buffer, 0, receiveResult.Count); - MessageReceived?.Invoke(this, new MessageReceivedEventArgs(sessionId, message)); + MessageReceived?.Invoke(this, new MessageReceivedEventArgs(webSocketSession.SessionId, message)); } else if (receiveResult.MessageType == WebSocketMessageType.Close) { - await webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Connection closed by the client", + await webSocketSession.WebSocket.CloseAsync( + WebSocketCloseStatus.NormalClosure, + "Connection closed by the client", CancellationToken.None); break; } @@ -119,8 +136,8 @@ await webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Connection close } finally { - Sessions.TryRemove(sessionId, out _); - SessionDropped?.Invoke(this, new SessionEventArgs(sessionId)); + Sessions.TryRemove(webSocketSession.SessionId, out _); + SessionDropped?.Invoke(this, new SessionEventArgs(webSocketSession.SessionId)); } } @@ -129,17 +146,18 @@ public bool RenameSession(string oldSessionId, string newSessionId) if (!Sessions.ContainsKey(oldSessionId) || Sessions.ContainsKey(newSessionId)) return false; - if (!Sessions.TryRemove(oldSessionId, out var webSocket)) + if (!Sessions.TryRemove(oldSessionId, out var webSocketSession)) return false; - if (Sessions.TryAdd(newSessionId, webSocket)) + webSocketSession.SessionId = newSessionId; + + if (Sessions.TryAdd(newSessionId, webSocketSession)) return true; - if (!Sessions.TryAdd(oldSessionId, webSocket)) - { - // handle the rare case when adding back the old session fails + webSocketSession.SessionId = oldSessionId; + + if (!Sessions.TryAdd(oldSessionId, webSocketSession)) throw new Exception("Failed to add back the old session after failed rename"); - } return false; } @@ -148,10 +166,11 @@ public async Task SendToSession(string sessionId, string message) { try { - if (Sessions.TryGetValue(sessionId, out var webSocket)) + if (Sessions.TryGetValue(sessionId, out var webSocketSession)) { var buffer = Encoding.UTF8.GetBytes(message); - await webSocket.SendAsync(new ArraySegment(buffer), WebSocketMessageType.Text, true, + await webSocketSession.WebSocket.SendAsync(new ArraySegment(buffer), WebSocketMessageType.Text, + true, CancellationToken.None); } } @@ -302,7 +321,7 @@ public override void Initialize() if (_server != null) { SendEvent("OnWsRestarting", ""); - _server.Stop(); + _server.Stop(); // If you await, this will freeze the task and the websocket won't work _server = null; } @@ -310,7 +329,7 @@ public override void Initialize() { LogToConsole(Translations.bot_WebSocketBot_starting); _server = new(); - _server.Start(_ip!, _port); + _server.Start(_ip!, _port); // If you await, this will freeze the task and the websocket won't work LogToConsole(string.Format(Translations.bot_WebSocketBot_started, _ip, _port.ToString())); @@ -323,18 +342,21 @@ public override void Initialize() return; } - _server.NewSession += (sender, session) => + _server.NewSession += (_, session) => LogToConsole(string.Format(Translations.bot_WebSocketBot_new_session, session.SessionId)); - _server.SessionDropped += (sender, session) => + _server.SessionDropped += (_, session) => LogToConsole(string.Format(Translations.bot_WebSocketBot_session_disconnected, session.SessionId)); - _server.MessageReceived += (sender, messageObject) => + _server.MessageReceived += (_, messageObject) => { if (!ProcessWebsocketCommand(messageObject.SessionId, _password!, messageObject.Message)) return; + var command = messageObject.Message; + command = command.StartsWith('/') ? command[1..] : $"send {command}"; + CmdResult response = new(); - PerformInternalCommand(messageObject.Message, ref response); + PerformInternalCommand(command, ref response); SendSessionEvent(messageObject.SessionId, "OnMccCommandResponse", $"{{\"response\": \"{response}\"}}"); }; }); @@ -391,6 +413,13 @@ private bool ProcessWebsocketCommand(string sessionId, string password, string m return false; } + // If the session is authenticated, remove the old session id and add the new one + if (_authenticatedSessions.Contains(sessionId)) + { + _authenticatedSessions.Remove(sessionId); + _authenticatedSessions.Add(newId); + } + responder.SendSuccessResponse( responder.Quote("The session ID was successfully changed to: '" + newId + "'"), true); LogToConsole(string.Format(Translations.bot_WebSocketBot_session_id_changed, sessionId, newId)); @@ -969,7 +998,7 @@ private bool ProcessWebsocketCommand(string sessionId, string password, string m case "GetProtocolVersion": responder.SendSuccessResponse(JsonConvert.SerializeObject(GetProtocolVersion())); break; - + default: responder.SendErrorResponse( responder.Quote($"Unknown command {cmd.Command} received!")); @@ -1009,7 +1038,6 @@ private bool ProcessWebsocketCommand(string sessionId, string password, string m } } - SendText(message); return true; }