Skip to content

Commit

Permalink
Fixed session renaming not working, fixed command handling
Browse files Browse the repository at this point in the history
  • Loading branch information
milutinke committed May 28, 2023
1 parent 1efa552 commit 95f6c57
Showing 1 changed file with 56 additions and 28 deletions.
84 changes: 56 additions & 28 deletions MinecraftClient/ChatBots/WebSocketBot.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,21 @@ public MessageReceivedEventArgs(string sessionId, string message)
}
}

internal class WebSocketSession
{
public string SessionId { get; set; }
public WebSocket WebSocket { get; set; }

public WebSocketSession(string sessionId, WebSocket webSocket)
{
SessionId = sessionId;
WebSocket = webSocket;
}
}

internal class WebSocketServer
{
public readonly ConcurrentDictionary<string, WebSocket> Sessions;
public readonly ConcurrentDictionary<string, WebSocketSession> Sessions;
public event EventHandler<SessionEventArgs>? NewSession;
public event EventHandler<SessionEventArgs>? SessionDropped;
public event EventHandler<MessageReceivedEventArgs>? MessageReceived;
Expand All @@ -52,7 +64,7 @@ internal class WebSocketServer

public WebSocketServer()
{
Sessions = new ConcurrentDictionary<string, WebSocket>();
Sessions = new ConcurrentDictionary<string, WebSocketSession>();
}

public async Task Start(string ipAddress, int port)
Expand All @@ -69,9 +81,11 @@ public async Task Start(string ipAddress, int port)
var sessionGuid = Guid.NewGuid().ToString();
var webSocketContext = await context.AcceptWebSocketAsync(null);
var webSocket = webSocketContext.WebSocket;
Sessions.TryAdd(sessionGuid, webSocket);
var webSocketSession = new WebSocketSession(sessionGuid, webSocket);

NewSession?.Invoke(this, new SessionEventArgs(sessionGuid));
_ = ProcessWebSocketSession(sessionGuid, webSocket);
Sessions.TryAdd(sessionGuid, webSocketSession);
_ = ProcessWebSocketSession(webSocketSession);
}
else
{
Expand All @@ -85,42 +99,45 @@ public async Task Stop()
{
foreach (var session in Sessions)
{
await session.Value.CloseAsync(WebSocketCloseStatus.NormalClosure, "Server shutting down",
await session.Value.WebSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Server shutting down",
CancellationToken.None);
}

Sessions.Clear();
listener?.Stop();
}

private async Task ProcessWebSocketSession(string sessionId, WebSocket webSocket)
private async Task ProcessWebSocketSession(WebSocketSession webSocketSession)
{
var buffer = new byte[1024];

try
{
while (webSocket.State == WebSocketState.Open)
while (webSocketSession.WebSocket.State == WebSocketState.Open)
{
var receiveResult =
await webSocket.ReceiveAsync(new ArraySegment<byte>(buffer), CancellationToken.None);
await webSocketSession.WebSocket.ReceiveAsync(new ArraySegment<byte>(buffer),
CancellationToken.None);

if (receiveResult.MessageType == WebSocketMessageType.Text)
{
var message = Encoding.UTF8.GetString(buffer, 0, receiveResult.Count);
MessageReceived?.Invoke(this, new MessageReceivedEventArgs(sessionId, message));
MessageReceived?.Invoke(this, new MessageReceivedEventArgs(webSocketSession.SessionId, message));
}
else if (receiveResult.MessageType == WebSocketMessageType.Close)
{
await webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Connection closed by the client",
await webSocketSession.WebSocket.CloseAsync(
WebSocketCloseStatus.NormalClosure,
"Connection closed by the client",
CancellationToken.None);
break;
}
}
}
finally
{
Sessions.TryRemove(sessionId, out _);
SessionDropped?.Invoke(this, new SessionEventArgs(sessionId));
Sessions.TryRemove(webSocketSession.SessionId, out _);
SessionDropped?.Invoke(this, new SessionEventArgs(webSocketSession.SessionId));
}
}

Expand All @@ -129,17 +146,18 @@ public bool RenameSession(string oldSessionId, string newSessionId)
if (!Sessions.ContainsKey(oldSessionId) || Sessions.ContainsKey(newSessionId))
return false;

if (!Sessions.TryRemove(oldSessionId, out var webSocket))
if (!Sessions.TryRemove(oldSessionId, out var webSocketSession))
return false;

if (Sessions.TryAdd(newSessionId, webSocket))
webSocketSession.SessionId = newSessionId;

if (Sessions.TryAdd(newSessionId, webSocketSession))
return true;

if (!Sessions.TryAdd(oldSessionId, webSocket))
{
// handle the rare case when adding back the old session fails
webSocketSession.SessionId = oldSessionId;

if (!Sessions.TryAdd(oldSessionId, webSocketSession))
throw new Exception("Failed to add back the old session after failed rename");
}

return false;
}
Expand All @@ -148,10 +166,11 @@ public async Task SendToSession(string sessionId, string message)
{
try
{
if (Sessions.TryGetValue(sessionId, out var webSocket))
if (Sessions.TryGetValue(sessionId, out var webSocketSession))
{
var buffer = Encoding.UTF8.GetBytes(message);
await webSocket.SendAsync(new ArraySegment<byte>(buffer), WebSocketMessageType.Text, true,
await webSocketSession.WebSocket.SendAsync(new ArraySegment<byte>(buffer), WebSocketMessageType.Text,
true,
CancellationToken.None);
}
}
Expand Down Expand Up @@ -302,15 +321,15 @@ public override void Initialize()
if (_server != null)
{
SendEvent("OnWsRestarting", "");
_server.Stop();
_server.Stop(); // If you await, this will freeze the task and the websocket won't work
_server = null;
}
try
{
LogToConsole(Translations.bot_WebSocketBot_starting);
_server = new();
_server.Start(_ip!, _port);
_server.Start(_ip!, _port); // If you await, this will freeze the task and the websocket won't work
LogToConsole(string.Format(Translations.bot_WebSocketBot_started, _ip, _port.ToString()));
Expand All @@ -323,18 +342,21 @@ public override void Initialize()
return;
}
_server.NewSession += (sender, session) =>
_server.NewSession += (_, session) =>
LogToConsole(string.Format(Translations.bot_WebSocketBot_new_session, session.SessionId));
_server.SessionDropped += (sender, session) =>
_server.SessionDropped += (_, session) =>
LogToConsole(string.Format(Translations.bot_WebSocketBot_session_disconnected, session.SessionId));
_server.MessageReceived += (sender, messageObject) =>
_server.MessageReceived += (_, messageObject) =>
{
if (!ProcessWebsocketCommand(messageObject.SessionId, _password!, messageObject.Message))
return;
var command = messageObject.Message;
command = command.StartsWith('/') ? command[1..] : $"send {command}";
CmdResult response = new();
PerformInternalCommand(messageObject.Message, ref response);
PerformInternalCommand(command, ref response);
SendSessionEvent(messageObject.SessionId, "OnMccCommandResponse", $"{{\"response\": \"{response}\"}}");
};
});
Expand Down Expand Up @@ -391,6 +413,13 @@ private bool ProcessWebsocketCommand(string sessionId, string password, string m
return false;
}

// If the session is authenticated, remove the old session id and add the new one
if (_authenticatedSessions.Contains(sessionId))
{
_authenticatedSessions.Remove(sessionId);
_authenticatedSessions.Add(newId);
}

responder.SendSuccessResponse(
responder.Quote("The session ID was successfully changed to: '" + newId + "'"), true);
LogToConsole(string.Format(Translations.bot_WebSocketBot_session_id_changed, sessionId, newId));
Expand Down Expand Up @@ -969,7 +998,7 @@ private bool ProcessWebsocketCommand(string sessionId, string password, string m
case "GetProtocolVersion":
responder.SendSuccessResponse(JsonConvert.SerializeObject(GetProtocolVersion()));
break;

default:
responder.SendErrorResponse(
responder.Quote($"Unknown command {cmd.Command} received!"));
Expand Down Expand Up @@ -1009,7 +1038,6 @@ private bool ProcessWebsocketCommand(string sessionId, string password, string m
}
}

SendText(message);
return true;
}

Expand Down

0 comments on commit 95f6c57

Please sign in to comment.