Fixed session renaming not working, fixed command handling

This commit is contained in:
Anon 2023-05-28 15:15:43 +02:00
parent 1efa55206f
commit 95f6c5768d

View file

@ -41,9 +41,21 @@ internal class MessageReceivedEventArgs : EventArgs
} }
} }
internal class WebSocketSession
{
public string SessionId { get; set; }
public WebSocket WebSocket { get; set; }
public WebSocketSession(string sessionId, WebSocket webSocket)
{
SessionId = sessionId;
WebSocket = webSocket;
}
}
internal class WebSocketServer internal class WebSocketServer
{ {
public readonly ConcurrentDictionary<string, WebSocket> Sessions; public readonly ConcurrentDictionary<string, WebSocketSession> Sessions;
public event EventHandler<SessionEventArgs>? NewSession; public event EventHandler<SessionEventArgs>? NewSession;
public event EventHandler<SessionEventArgs>? SessionDropped; public event EventHandler<SessionEventArgs>? SessionDropped;
public event EventHandler<MessageReceivedEventArgs>? MessageReceived; public event EventHandler<MessageReceivedEventArgs>? MessageReceived;
@ -52,7 +64,7 @@ internal class WebSocketServer
public WebSocketServer() public WebSocketServer()
{ {
Sessions = new ConcurrentDictionary<string, WebSocket>(); Sessions = new ConcurrentDictionary<string, WebSocketSession>();
} }
public async Task Start(string ipAddress, int port) public async Task Start(string ipAddress, int port)
@ -69,9 +81,11 @@ internal class WebSocketServer
var sessionGuid = Guid.NewGuid().ToString(); var sessionGuid = Guid.NewGuid().ToString();
var webSocketContext = await context.AcceptWebSocketAsync(null); var webSocketContext = await context.AcceptWebSocketAsync(null);
var webSocket = webSocketContext.WebSocket; var webSocket = webSocketContext.WebSocket;
Sessions.TryAdd(sessionGuid, webSocket); var webSocketSession = new WebSocketSession(sessionGuid, webSocket);
NewSession?.Invoke(this, new SessionEventArgs(sessionGuid)); NewSession?.Invoke(this, new SessionEventArgs(sessionGuid));
_ = ProcessWebSocketSession(sessionGuid, webSocket); Sessions.TryAdd(sessionGuid, webSocketSession);
_ = ProcessWebSocketSession(webSocketSession);
} }
else else
{ {
@ -85,7 +99,7 @@ internal class WebSocketServer
{ {
foreach (var session in Sessions) foreach (var session in Sessions)
{ {
await session.Value.CloseAsync(WebSocketCloseStatus.NormalClosure, "Server shutting down", await session.Value.WebSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Server shutting down",
CancellationToken.None); CancellationToken.None);
} }
@ -93,25 +107,28 @@ internal class WebSocketServer
listener?.Stop(); listener?.Stop();
} }
private async Task ProcessWebSocketSession(string sessionId, WebSocket webSocket) private async Task ProcessWebSocketSession(WebSocketSession webSocketSession)
{ {
var buffer = new byte[1024]; var buffer = new byte[1024];
try try
{ {
while (webSocket.State == WebSocketState.Open) while (webSocketSession.WebSocket.State == WebSocketState.Open)
{ {
var receiveResult = var receiveResult =
await webSocket.ReceiveAsync(new ArraySegment<byte>(buffer), CancellationToken.None); await webSocketSession.WebSocket.ReceiveAsync(new ArraySegment<byte>(buffer),
CancellationToken.None);
if (receiveResult.MessageType == WebSocketMessageType.Text) if (receiveResult.MessageType == WebSocketMessageType.Text)
{ {
var message = Encoding.UTF8.GetString(buffer, 0, receiveResult.Count); var message = Encoding.UTF8.GetString(buffer, 0, receiveResult.Count);
MessageReceived?.Invoke(this, new MessageReceivedEventArgs(sessionId, message)); MessageReceived?.Invoke(this, new MessageReceivedEventArgs(webSocketSession.SessionId, message));
} }
else if (receiveResult.MessageType == WebSocketMessageType.Close) else if (receiveResult.MessageType == WebSocketMessageType.Close)
{ {
await webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Connection closed by the client", await webSocketSession.WebSocket.CloseAsync(
WebSocketCloseStatus.NormalClosure,
"Connection closed by the client",
CancellationToken.None); CancellationToken.None);
break; break;
} }
@ -119,8 +136,8 @@ internal class WebSocketServer
} }
finally finally
{ {
Sessions.TryRemove(sessionId, out _); Sessions.TryRemove(webSocketSession.SessionId, out _);
SessionDropped?.Invoke(this, new SessionEventArgs(sessionId)); SessionDropped?.Invoke(this, new SessionEventArgs(webSocketSession.SessionId));
} }
} }
@ -129,17 +146,18 @@ internal class WebSocketServer
if (!Sessions.ContainsKey(oldSessionId) || Sessions.ContainsKey(newSessionId)) if (!Sessions.ContainsKey(oldSessionId) || Sessions.ContainsKey(newSessionId))
return false; return false;
if (!Sessions.TryRemove(oldSessionId, out var webSocket)) if (!Sessions.TryRemove(oldSessionId, out var webSocketSession))
return false; return false;
if (Sessions.TryAdd(newSessionId, webSocket)) webSocketSession.SessionId = newSessionId;
if (Sessions.TryAdd(newSessionId, webSocketSession))
return true; return true;
if (!Sessions.TryAdd(oldSessionId, webSocket)) webSocketSession.SessionId = oldSessionId;
{
// handle the rare case when adding back the old session fails if (!Sessions.TryAdd(oldSessionId, webSocketSession))
throw new Exception("Failed to add back the old session after failed rename"); throw new Exception("Failed to add back the old session after failed rename");
}
return false; return false;
} }
@ -148,10 +166,11 @@ internal class WebSocketServer
{ {
try try
{ {
if (Sessions.TryGetValue(sessionId, out var webSocket)) if (Sessions.TryGetValue(sessionId, out var webSocketSession))
{ {
var buffer = Encoding.UTF8.GetBytes(message); var buffer = Encoding.UTF8.GetBytes(message);
await webSocket.SendAsync(new ArraySegment<byte>(buffer), WebSocketMessageType.Text, true, await webSocketSession.WebSocket.SendAsync(new ArraySegment<byte>(buffer), WebSocketMessageType.Text,
true,
CancellationToken.None); CancellationToken.None);
} }
} }
@ -302,7 +321,7 @@ public class WebSocketBot : ChatBot
if (_server != null) if (_server != null)
{ {
SendEvent("OnWsRestarting", ""); SendEvent("OnWsRestarting", "");
_server.Stop(); _server.Stop(); // If you await, this will freeze the task and the websocket won't work
_server = null; _server = null;
} }
@ -310,7 +329,7 @@ public class WebSocketBot : ChatBot
{ {
LogToConsole(Translations.bot_WebSocketBot_starting); LogToConsole(Translations.bot_WebSocketBot_starting);
_server = new(); _server = new();
_server.Start(_ip!, _port); _server.Start(_ip!, _port); // If you await, this will freeze the task and the websocket won't work
LogToConsole(string.Format(Translations.bot_WebSocketBot_started, _ip, _port.ToString())); LogToConsole(string.Format(Translations.bot_WebSocketBot_started, _ip, _port.ToString()));
@ -323,18 +342,21 @@ public class WebSocketBot : ChatBot
return; return;
} }
_server.NewSession += (sender, session) => _server.NewSession += (_, session) =>
LogToConsole(string.Format(Translations.bot_WebSocketBot_new_session, session.SessionId)); LogToConsole(string.Format(Translations.bot_WebSocketBot_new_session, session.SessionId));
_server.SessionDropped += (sender, session) => _server.SessionDropped += (_, session) =>
LogToConsole(string.Format(Translations.bot_WebSocketBot_session_disconnected, session.SessionId)); LogToConsole(string.Format(Translations.bot_WebSocketBot_session_disconnected, session.SessionId));
_server.MessageReceived += (sender, messageObject) => _server.MessageReceived += (_, messageObject) =>
{ {
if (!ProcessWebsocketCommand(messageObject.SessionId, _password!, messageObject.Message)) if (!ProcessWebsocketCommand(messageObject.SessionId, _password!, messageObject.Message))
return; return;
var command = messageObject.Message;
command = command.StartsWith('/') ? command[1..] : $"send {command}";
CmdResult response = new(); CmdResult response = new();
PerformInternalCommand(messageObject.Message, ref response); PerformInternalCommand(command, ref response);
SendSessionEvent(messageObject.SessionId, "OnMccCommandResponse", $"{{\"response\": \"{response}\"}}"); SendSessionEvent(messageObject.SessionId, "OnMccCommandResponse", $"{{\"response\": \"{response}\"}}");
}; };
}); });
@ -391,6 +413,13 @@ public class WebSocketBot : ChatBot
return false; return false;
} }
// If the session is authenticated, remove the old session id and add the new one
if (_authenticatedSessions.Contains(sessionId))
{
_authenticatedSessions.Remove(sessionId);
_authenticatedSessions.Add(newId);
}
responder.SendSuccessResponse( responder.SendSuccessResponse(
responder.Quote("The session ID was successfully changed to: '" + newId + "'"), true); responder.Quote("The session ID was successfully changed to: '" + newId + "'"), true);
LogToConsole(string.Format(Translations.bot_WebSocketBot_session_id_changed, sessionId, newId)); LogToConsole(string.Format(Translations.bot_WebSocketBot_session_id_changed, sessionId, newId));
@ -969,7 +998,7 @@ public class WebSocketBot : ChatBot
case "GetProtocolVersion": case "GetProtocolVersion":
responder.SendSuccessResponse(JsonConvert.SerializeObject(GetProtocolVersion())); responder.SendSuccessResponse(JsonConvert.SerializeObject(GetProtocolVersion()));
break; break;
default: default:
responder.SendErrorResponse( responder.SendErrorResponse(
responder.Quote($"Unknown command {cmd.Command} received!")); responder.Quote($"Unknown command {cmd.Command} received!"));
@ -1009,7 +1038,6 @@ public class WebSocketBot : ChatBot
} }
} }
SendText(message);
return true; return true;
} }