using System.Net.WebSockets;
using System.Text;
using System.Text.Json;
using TermRemoteCtl.Agent.Sessions;
using TermRemoteCtl.Agent.Terminal;
using TermRemoteCtl.Agent.Terminal.Screen;

namespace TermRemoteCtl.Agent.Realtime;

/// <summary>
/// Maps the <c>/ws/terminal</c> WebSocket endpoint and bridges one terminal session to one socket:
/// host output is pushed to the client as JSON, and client <c>input</c> / <c>resize</c> messages
/// are forwarded to the session host and recorded in the session registry.
/// </summary>
public static class TerminalWebSocketHandler
{
    /// <summary>
    /// Maximum UTF-8 byte length of a WebSocket close reason: a control frame payload is limited to
    /// 125 bytes and the close status code takes 2 of them.
    /// </summary>
    private const int MaxCloseDescriptionBytes = 123;

    /// <summary>Size of the scratch buffer used to read a single WebSocket frame.</summary>
    private const int ReceiveBufferSize = 4096;

    private static readonly JsonSerializerOptions JsonOptions = new(JsonSerializerDefaults.Web);

    /// <summary>
    /// Enables WebSocket support on <paramref name="app"/> and maps <c>/ws/terminal</c> to the terminal handler.
    /// </summary>
    /// <param name="app">The application to configure.</param>
    /// <returns>The same <paramref name="app"/> instance, for chaining.</returns>
    public static WebApplication MapTerminalSocket(this WebApplication app)
    {
        app.UseWebSockets();
        app.Map("/ws/terminal", HandleTerminalSocketAsync);
        return app;
    }

    /// <summary>
    /// Handles one <c>/ws/terminal</c> request. Responds 400 when the request is not a WebSocket upgrade
    /// or has no <c>sessionId</c> query value, and 404 when the registry does not know the session.
    /// Otherwise accepts the socket, starts the session host, sends the <c>attached</c> and
    /// <c>restore</c> messages, and then relays traffic until the receive loop ends.
    /// </summary>
    private static async Task HandleTerminalSocketAsync(HttpContext context)
    {
        if (!context.WebSockets.IsWebSocketRequest)
        {
            context.Response.StatusCode = StatusCodes.Status400BadRequest;
            return;
        }

        var sessionId = context.Request.Query["sessionId"].ToString();
        if (string.IsNullOrWhiteSpace(sessionId))
        {
            context.Response.StatusCode = StatusCodes.Status400BadRequest;
            return;
        }

        var registry = context.RequestServices.GetRequiredService<SessionRegistry>();
        if (!registry.TryGet(sessionId, out _))
        {
            context.Response.StatusCode = StatusCodes.Status404NotFound;
            return;
        }

        var host = context.RequestServices.GetRequiredService<ISessionHost>();
        var diagnostics = context.RequestServices.GetRequiredService<ITerminalDiagnosticsSink>();

        // Captured once so the output event handler below never touches HttpContext from the
        // thread that raises the host event.
        var requestAborted = context.RequestAborted;

        using var socket = await context.WebSockets.AcceptWebSocketAsync().ConfigureAwait(false);

        try
        {
            await host.StartAsync(sessionId, requestAborted).ConfigureAwait(false);
        }
        catch (Exception ex)
        {
            if (socket.State == WebSocketState.Open)
            {
                // The close reason is size-limited; an over-long exception message would make
                // CloseAsync itself throw, so it is truncated first.
                await socket.CloseAsync(
                    WebSocketCloseStatus.InternalServerError,
                    TruncateCloseDescription(ex.Message),
                    requestAborted).ConfigureAwait(false);
            }

            return;
        }

        // All sends on this socket go through this gate so they never overlap.
        using var sendGate = new SemaphoreSlim(1, 1);

        void HandleOutput(object? sender, TerminalOutputEventArgs args)
        {
            if (!string.Equals(args.SessionId, sessionId, StringComparison.Ordinal))
            {
                return;
            }

            // Event handlers cannot await; the send is started here and its outcome is observed
            // inside SendBestEffortAsync.
            _ = SendBestEffortAsync(
                socket,
                new TerminalOutputResponse(args.SessionId, args.Sequence, args.Chunk),
                sendGate,
                requestAborted);
        }

        try
        {
            await registry.RecordAttachAsync(sessionId, requestAborted).ConfigureAwait(false);
            var restore = registry.GetRestoreSnapshot(sessionId);

            await SendJsonAsync(socket, new TerminalAttachResponse(sessionId), sendGate, requestAborted)
                .ConfigureAwait(false);

            await SendJsonAsync(
                socket,
                new TerminalRestoreResponse(
                    restore.SessionId,
                    restore.Sequence,
                    restore.ScreenText,
                    restore.PendingInput,
                    restore.CursorRow,
                    restore.CursorColumn,
                    MapScreenSnapshot(restore.ScreenSnapshot)),
                sendGate,
                requestAborted).ConfigureAwait(false);

            // NOTE(review): the snapshot is taken before this subscription, so output raised in
            // between is not forwarded by this handler; confirm the client reconciles it (for
            // example via the sequence numbers) before changing the order.
            host.OutputReceived += HandleOutput;

            await ReceiveLoopAsync(context, socket, host, registry, diagnostics, sessionId, sendGate)
                .ConfigureAwait(false);
        }
        finally
        {
            host.OutputReceived -= HandleOutput;

            try
            {
                await registry.RecordDetachAsync(sessionId, CancellationToken.None).ConfigureAwait(false);
            }
            catch
            {
                // Deliberately best-effort: a failure to record the detach must not replace
                // whatever exception (if any) ended the session.
            }
        }
    }

    /// <summary>
    /// Reads client messages until the socket leaves the open state, the request is aborted, or a
    /// close frame arrives. Each complete text message is passed to
    /// <see cref="HandleClientMessageAsync"/>; non-text messages are skipped.
    /// </summary>
    private static async Task ReceiveLoopAsync(
        HttpContext context,
        WebSocket socket,
        ISessionHost host,
        SessionRegistry registry,
        ITerminalDiagnosticsSink diagnostics,
        string sessionId,
        SemaphoreSlim sendGate)
    {
        var requestAborted = context.RequestAborted;
        var buffer = new byte[ReceiveBufferSize];

        while (socket.State == WebSocketState.Open && !requestAborted.IsCancellationRequested)
        {
            // A message may arrive in several frames; accumulate them until EndOfMessage.
            using var message = new MemoryStream();
            WebSocketReceiveResult receiveResult;

            do
            {
                receiveResult = await socket.ReceiveAsync(buffer, requestAborted).ConfigureAwait(false);
                if (receiveResult.MessageType == WebSocketMessageType.Close)
                {
                    return;
                }

                message.Write(buffer, 0, receiveResult.Count);
            }
            while (!receiveResult.EndOfMessage);

            if (receiveResult.MessageType != WebSocketMessageType.Text)
            {
                continue;
            }

            // Decode straight from the stream's backing buffer instead of copying it first.
            var payload = Encoding.UTF8.GetString(message.GetBuffer(), 0, (int)message.Length);

            await HandleClientMessageAsync(
                payload,
                socket,
                registry,
                host,
                diagnostics,
                sessionId,
                sendGate,
                requestAborted).ConfigureAwait(false);
        }
    }

    /// <summary>
    /// Parses one client JSON message and dispatches it. Payloads that fail to parse and messages
    /// whose type is neither <c>input</c> nor <c>resize</c> (case-insensitive) are ignored.
    /// </summary>
    private static async Task HandleClientMessageAsync(
        string payload,
        WebSocket socket,
        SessionRegistry registry,
        ISessionHost host,
        ITerminalDiagnosticsSink diagnostics,
        string sessionId,
        SemaphoreSlim sendGate,
        CancellationToken cancellationToken)
    {
        TerminalClientMessage? message;
        try
        {
            message = JsonSerializer.Deserialize<TerminalClientMessage>(payload, JsonOptions);
        }
        catch (JsonException)
        {
            return;
        }

        if (message is null)
        {
            return;
        }

        if (string.Equals(message.Type, "input", StringComparison.OrdinalIgnoreCase))
        {
            await HandleInputAsync(
                message,
                socket,
                registry,
                host,
                diagnostics,
                sessionId,
                sendGate,
                cancellationToken).ConfigureAwait(false);
            return;
        }

        if (!string.Equals(message.Type, "resize", StringComparison.OrdinalIgnoreCase))
        {
            // Every other type, including "attach", needs no action here.
            return;
        }

        // A resize is applied only when both dimensions are present and positive.
        if ((message.Columns is int columns and > 0) && (message.Rows is int rows and > 0))
        {
            await registry.RecordResizeAsync(sessionId, columns, rows, cancellationToken).ConfigureAwait(false);
            await host.ResizeAsync(sessionId, columns, rows, cancellationToken).ConfigureAwait(false);
        }
    }

    /// <summary>
    /// Handles an <c>input</c> message: writes the text to the session host, records it in the
    /// registry and, when the message carries an input id, acknowledges it to the client.
    /// </summary>
    /// <remarks>
    /// When the registry reports that a receipt for the same input id already exists, the input is
    /// not written again; an acknowledgement is sent only if that earlier receipt completes as true.
    /// Failures while writing or recording mark the receipt as failed and are rethrown.
    /// </remarks>
    private static async Task HandleInputAsync(
        TerminalClientMessage message,
        WebSocket socket,
        SessionRegistry registry,
        ISessionHost host,
        ITerminalDiagnosticsSink diagnostics,
        string sessionId,
        SemaphoreSlim sendGate,
        CancellationToken cancellationToken)
    {
        var input = message.Input;
        if (string.IsNullOrEmpty(input))
        {
            return;
        }

        var inputId = message.InputId;

        if (!string.IsNullOrWhiteSpace(inputId)
            && !registry.TryBeginInputReceipt(sessionId, inputId, out var existingReceipt))
        {
            if (existingReceipt is not null && await existingReceipt.ConfigureAwait(false))
            {
                await SendJsonAsync(
                    socket,
                    new TerminalInputAckResponse(sessionId, inputId),
                    sendGate,
                    cancellationToken).ConfigureAwait(false);
            }

            return;
        }

        try
        {
            diagnostics.Record("backend.input.received", sessionId, SanitizeDiagnosticText(input));
            await host.WriteInputAsync(sessionId, input, cancellationToken).ConfigureAwait(false);
            await registry.RecordInputAsync(sessionId, input, cancellationToken).ConfigureAwait(false);

            if (!string.IsNullOrWhiteSpace(inputId))
            {
                registry.CompleteInputReceipt(sessionId, inputId, succeeded: true);
                await SendJsonAsync(
                    socket,
                    new TerminalInputAckResponse(sessionId, inputId),
                    sendGate,
                    cancellationToken).ConfigureAwait(false);
            }
        }
        catch
        {
            if (!string.IsNullOrWhiteSpace(inputId))
            {
                registry.CompleteInputReceipt(sessionId, inputId, succeeded: false);
            }

            throw;
        }
    }

    /// <summary>Replaces carriage returns and line feeds with the literal text <c>\r</c> and <c>\n</c>.</summary>
    private static string SanitizeDiagnosticText(string input)
    {
        return input
            .Replace("\r", "\\r", StringComparison.Ordinal)
            .Replace("\n", "\\n", StringComparison.Ordinal);
    }

    /// <summary>
    /// Shortens <paramref name="description"/> so that its UTF-8 encoding fits in
    /// <see cref="MaxCloseDescriptionBytes"/>, cutting only between code points so a surrogate pair
    /// is never split.
    /// </summary>
    private static string TruncateCloseDescription(string description)
    {
        if (string.IsNullOrEmpty(description))
        {
            return string.Empty;
        }

        if (Encoding.UTF8.GetByteCount(description) <= MaxCloseDescriptionBytes)
        {
            return description;
        }

        var byteCount = 0;
        var length = 0;
        while (length < description.Length)
        {
            var isPair = char.IsHighSurrogate(description[length])
                && length + 1 < description.Length
                && char.IsLowSurrogate(description[length + 1]);
            var charCount = isPair ? 2 : 1;
            var nextBytes = Encoding.UTF8.GetByteCount(description.AsSpan(length, charCount));

            if (byteCount + nextBytes > MaxCloseDescriptionBytes)
            {
                break;
            }

            byteCount += nextBytes;
            length += charCount;
        }

        return description[..length];
    }

    /// <summary>Converts a screen snapshot to its wire representation; <c>null</c> maps to <c>null</c>.</summary>
    private static TerminalScreenSnapshotResponse? MapScreenSnapshot(TerminalScreenSnapshot? snapshot)
    {
        if (snapshot is null)
        {
            return null;
        }

        return new TerminalScreenSnapshotResponse(
            snapshot.ScreenVersion,
            snapshot.SourceSequence,
            snapshot.Rows,
            snapshot.Columns,
            snapshot.CursorRow,
            snapshot.CursorColumn,
            snapshot.CursorVisible,
            snapshot.ActiveBuffer,
            MapScreenBuffer(snapshot.PrimaryBuffer),
            snapshot.AlternateBuffer is null ? null : MapScreenBuffer(snapshot.AlternateBuffer));
    }

    /// <summary>Converts a screen buffer's viewport lines to their wire representation.</summary>
    private static TerminalScreenBufferSnapshotResponse MapScreenBuffer(TerminalScreenBufferSnapshot buffer)
    {
        return new TerminalScreenBufferSnapshotResponse(
            buffer.Viewport
                .Select(static line => new TerminalScreenLineSnapshotResponse(line.Index, line.Text))
                .ToArray());
    }

    /// <summary>
    /// Sends <paramref name="response"/> without letting cancellation or disposal surface as an
    /// unobserved task exception. Used for sends started from the output event handler, which may
    /// still be in flight when the connection is torn down.
    /// </summary>
    private static async Task SendBestEffortAsync(
        WebSocket socket,
        object response,
        SemaphoreSlim sendGate,
        CancellationToken cancellationToken)
    {
        try
        {
            await SendJsonAsync(socket, response, sendGate, cancellationToken).ConfigureAwait(false);
        }
        catch (OperationCanceledException)
        {
            // The request was aborted while the send was pending.
        }
        catch (ObjectDisposedException)
        {
            // The socket or the send gate was disposed because the handler already finished.
        }
    }

    /// <summary>Serializes <paramref name="response"/> with the web JSON options and sends it as one text message.</summary>
    private static async Task SendJsonAsync(
        WebSocket socket,
        object response,
        SemaphoreSlim sendGate,
        CancellationToken cancellationToken)
    {
        var json = JsonSerializer.SerializeToUtf8Bytes(response, JsonOptions);
        await SendAsync(socket, json, sendGate, cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Sends <paramref name="payload"/> as a single text message while holding <paramref name="sendGate"/>.
    /// Nothing is sent when the socket is no longer open, and a <see cref="WebSocketException"/>
    /// raised by the send is swallowed.
    /// </summary>
    private static async Task SendAsync(
        WebSocket socket,
        byte[] payload,
        SemaphoreSlim sendGate,
        CancellationToken cancellationToken)
    {
        await sendGate.WaitAsync(cancellationToken).ConfigureAwait(false);
        try
        {
            if (socket.State == WebSocketState.Open)
            {
                await socket.SendAsync(payload, WebSocketMessageType.Text, true, cancellationToken)
                    .ConfigureAwait(false);
            }
        }
        catch (WebSocketException)
        {
            // The peer went away mid-send; the receive loop will observe the closed socket.
        }
        finally
        {
            sendGate.Release();
        }
    }

    /// <summary>Sent once after the socket is attached to the session.</summary>
    private sealed record TerminalAttachResponse(string SessionId, string Type = "attached");

    /// <summary>Carries the registry's restore snapshot for the session.</summary>
    private sealed record TerminalRestoreResponse(
        string SessionId,
        long Sequence,
        string ScreenText,
        string PendingInput,
        int? CursorRow,
        int? CursorColumn,
        TerminalScreenSnapshotResponse? ScreenSnapshot,
        string Type = "restore");

    /// <summary>Carries one chunk of host output together with its sequence number.</summary>
    private sealed record TerminalOutputResponse(
        string SessionId,
        long Sequence,
        string Chunk,
        string Type = "output");

    /// <summary>Acknowledges a client input message that carried an input id.</summary>
    private sealed record TerminalInputAckResponse(
        string SessionId,
        string InputId,
        string Type = "inputAck");

    /// <summary>Wire form of a terminal screen snapshot.</summary>
    private sealed record TerminalScreenSnapshotResponse(
        long ScreenVersion,
        long SourceSequence,
        int Rows,
        int Columns,
        int CursorRow,
        int CursorColumn,
        bool CursorVisible,
        string ActiveBuffer,
        TerminalScreenBufferSnapshotResponse PrimaryBuffer,
        TerminalScreenBufferSnapshotResponse? AlternateBuffer);

    /// <summary>Wire form of one screen buffer: its viewport lines.</summary>
    private sealed record TerminalScreenBufferSnapshotResponse(
        IReadOnlyList<TerminalScreenLineSnapshotResponse> Viewport);

    /// <summary>Wire form of one viewport line.</summary>
    private sealed record TerminalScreenLineSnapshotResponse(int Index, string Text);

    /// <summary>
    /// Shape of a message received from the client. Every field is optional on the wire, so all of
    /// them, including <c>Type</c>, may be <c>null</c> after deserialization.
    /// </summary>
    private sealed record TerminalClientMessage(
        string? Type,
        string? SessionId,
        string? Input,
        string? InputId,
        int? Columns,
        int? Rows);
}