TermRemoteCtl/apps/windows_agent/src/TermRemoteCtl.Agent/Realtime/TerminalWebSocketHandler.cs

370 lines
12 KiB
C#

using System.Net.WebSockets;
using System.Text;
using System.Text.Json;
using TermRemoteCtl.Agent.Sessions;
using TermRemoteCtl.Agent.Terminal.Screen;
using TermRemoteCtl.Agent.Terminal;
namespace TermRemoteCtl.Agent.Realtime;
public static class TerminalWebSocketHandler
{
private static readonly JsonSerializerOptions JsonOptions = new(JsonSerializerDefaults.Web);
/// <summary>
/// Enables WebSocket support on the application and maps the terminal socket endpoint.
/// </summary>
/// <param name="app">The application being configured.</param>
/// <returns>The same <paramref name="app"/> instance, so further calls can be chained.</returns>
public static WebApplication MapTerminalSocket(this WebApplication app)
{
    const string terminalSocketPath = "/ws/terminal";

    app.UseWebSockets();
    app.Map(terminalSocketPath, HandleTerminalSocketAsync);

    return app;
}
/// <summary>
/// Handles one terminal socket request: validates it, accepts the WebSocket, starts the session
/// host, sends the attach and restore messages, then forwards terminal output to the client
/// while processing client messages until the receive loop ends.
/// </summary>
/// <param name="context">
/// The HTTP context of the request. The <c>sessionId</c> query parameter selects the session.
/// </param>
/// <remarks>
/// Responds with 400 when the request is not a WebSocket request or carries no session id, and
/// with 404 when the session registry does not know the session. A detach is recorded on every
/// exit path that got as far as the attach attempt.
/// </remarks>
private static async Task HandleTerminalSocketAsync(HttpContext context)
{
    if (!context.WebSockets.IsWebSocketRequest)
    {
        context.Response.StatusCode = StatusCodes.Status400BadRequest;
        return;
    }

    var sessionId = context.Request.Query["sessionId"].ToString();
    if (string.IsNullOrWhiteSpace(sessionId))
    {
        context.Response.StatusCode = StatusCodes.Status400BadRequest;
        return;
    }

    var registry = context.RequestServices.GetRequiredService<SessionRegistry>();
    if (!registry.TryGet(sessionId, out _))
    {
        context.Response.StatusCode = StatusCodes.Status404NotFound;
        return;
    }

    var host = context.RequestServices.GetRequiredService<ISessionHost>();
    var diagnostics = context.RequestServices.GetRequiredService<ITerminalDiagnosticsSink>();
    using var socket = await context.WebSockets.AcceptWebSocketAsync().ConfigureAwait(false);

    try
    {
        await host.StartAsync(sessionId, context.RequestAborted).ConfigureAwait(false);
    }
    catch (Exception ex)
    {
        await CloseAfterStartFailureAsync(socket, ex, context.RequestAborted).ConfigureAwait(false);
        return;
    }

    // Serializes all sends on this socket; output events and replies to client messages can
    // otherwise try to send concurrently.
    using var sendGate = new SemaphoreSlim(1, 1);

    // Output is forwarded without being awaited by the event raiser, so every failure that
    // simply means "the connection is gone" is observed here instead of surfacing later as an
    // unobserved task exception.
    async Task ForwardOutputAsync(TerminalOutputEventArgs args)
    {
        try
        {
            await SendJsonAsync(
                socket,
                new TerminalOutputResponse(args.SessionId, args.Sequence, args.Chunk),
                sendGate,
                context.RequestAborted).ConfigureAwait(false);
        }
        catch (Exception ex) when (ex is OperationCanceledException or ObjectDisposedException or WebSocketException)
        {
            // The request was aborted, or the socket/gate was disposed while this send was queued.
        }
    }

    void HandleOutput(object? sender, TerminalOutputEventArgs args)
    {
        // The host raises output for every session; only this connection's session is forwarded.
        if (string.Equals(args.SessionId, sessionId, StringComparison.Ordinal))
        {
            _ = ForwardOutputAsync(args);
        }
    }

    try
    {
        await registry.RecordAttachAsync(sessionId, context.RequestAborted).ConfigureAwait(false);
        var restore = registry.GetRestoreSnapshot(sessionId);
        await SendJsonAsync(socket, new TerminalAttachResponse(sessionId), sendGate, context.RequestAborted).ConfigureAwait(false);
        await SendJsonAsync(
            socket,
            new TerminalRestoreResponse(
                restore.SessionId,
                restore.Sequence,
                restore.ScreenText,
                restore.PendingInput,
                restore.CursorRow,
                restore.CursorColumn,
                MapScreenSnapshot(restore.ScreenSnapshot)),
            sendGate,
            context.RequestAborted).ConfigureAwait(false);

        // NOTE(review): output raised between GetRestoreSnapshot above and this subscription is
        // not forwarded by this handler - confirm clients can recover it (e.g. via Sequence).
        host.OutputReceived += HandleOutput;
        await ReceiveLoopAsync(
            context,
            socket,
            host,
            registry,
            diagnostics,
            sessionId,
            sendGate).ConfigureAwait(false);
    }
    finally
    {
        host.OutputReceived -= HandleOutput;
        try
        {
            await registry.RecordDetachAsync(sessionId, CancellationToken.None).ConfigureAwait(false);
        }
        catch
        {
            // Best effort: a failure to record the detach must not replace whatever exception
            // (if any) ended the connection.
        }
    }
}

/// <summary>
/// Closes the socket with an internal-server-error status after the session host failed to
/// start. Closing is best effort: if the peer is gone or the request was aborted, the failure
/// to close is ignored.
/// </summary>
private static async Task CloseAfterStartFailureAsync(
    WebSocket socket,
    Exception failure,
    CancellationToken cancellationToken)
{
    if (socket.State != WebSocketState.Open)
    {
        return;
    }

    try
    {
        await socket.CloseAsync(
            WebSocketCloseStatus.InternalServerError,
            TruncateCloseReason(failure.Message),
            cancellationToken).ConfigureAwait(false);
    }
    catch (Exception ex) when (ex is WebSocketException or OperationCanceledException)
    {
        // Nothing left to report to a peer that can no longer be reached.
    }
}

/// <summary>
/// Shortens a close reason so its UTF-8 encoding fits in a WebSocket close frame
/// (at most 123 bytes), cutting only on code point boundaries.
/// </summary>
private static string TruncateCloseReason(string reason)
{
    const int maxReasonBytes = 123;
    if (Encoding.UTF8.GetByteCount(reason) <= maxReasonBytes)
    {
        return reason;
    }

    var keptChars = 0;
    var keptBytes = 0;
    foreach (var rune in reason.EnumerateRunes())
    {
        if (keptBytes + rune.Utf8ByteCount > maxReasonBytes)
        {
            break;
        }

        keptBytes += rune.Utf8ByteCount;
        keptChars += rune.Utf16SequenceLength;
    }

    return reason[..keptChars];
}
/// <summary>
/// Reads messages from the client until the socket leaves the open state, the request is
/// aborted, or the client sends a close frame. Each complete text message is passed to
/// <c>HandleClientMessageAsync</c>; complete non-text messages are skipped.
/// </summary>
/// <param name="context">The HTTP context; its <c>RequestAborted</c> token cancels the loop.</param>
/// <param name="socket">The accepted WebSocket to read from.</param>
/// <param name="host">The session host, passed through to message handling.</param>
/// <param name="registry">The session registry, passed through to message handling.</param>
/// <param name="diagnostics">The diagnostics sink, passed through to message handling.</param>
/// <param name="sessionId">The session this connection is attached to.</param>
/// <param name="sendGate">The gate that serializes sends on <paramref name="socket"/>.</param>
private static async Task ReceiveLoopAsync(
    HttpContext context,
    WebSocket socket,
    ISessionHost host,
    SessionRegistry registry,
    ITerminalDiagnosticsSink diagnostics,
    string sessionId,
    SemaphoreSlim sendGate)
{
    var cancellationToken = context.RequestAborted;
    var buffer = new byte[4096];

    while (socket.State == WebSocketState.Open && !cancellationToken.IsCancellationRequested)
    {
        // A message can arrive in several frames; fragments are accumulated until EndOfMessage.
        // NOTE(review): there is no cap on the accumulated size of a single client message.
        using var message = new MemoryStream();
        WebSocketReceiveResult receiveResult;
        do
        {
            receiveResult = await socket.ReceiveAsync(buffer, cancellationToken).ConfigureAwait(false);
            if (receiveResult.MessageType == WebSocketMessageType.Close)
            {
                // Reply to the client's close frame so the closing handshake completes instead
                // of the connection simply being dropped when the socket is disposed.
                await AcknowledgeCloseAsync(socket, sendGate, cancellationToken).ConfigureAwait(false);
                return;
            }

            message.Write(buffer, 0, receiveResult.Count);
        }
        while (!receiveResult.EndOfMessage);

        if (receiveResult.MessageType != WebSocketMessageType.Text)
        {
            continue;
        }

        // Decode straight from the stream's buffer; no intermediate copy is needed.
        var payload = Encoding.UTF8.GetString(message.GetBuffer(), 0, (int)message.Length);
        await HandleClientMessageAsync(
            payload,
            socket,
            registry,
            host,
            diagnostics,
            sessionId,
            sendGate,
            cancellationToken).ConfigureAwait(false);
    }
}

/// <summary>
/// Sends the closing reply after the client's close frame was received. The reply goes through
/// the send gate because it must not overlap another send on the same socket. Best effort:
/// cancellation and socket errors are ignored, since the connection is ending either way.
/// </summary>
private static async Task AcknowledgeCloseAsync(
    WebSocket socket,
    SemaphoreSlim sendGate,
    CancellationToken cancellationToken)
{
    var gateHeld = false;
    try
    {
        await sendGate.WaitAsync(cancellationToken).ConfigureAwait(false);
        gateHeld = true;

        if (socket.State == WebSocketState.CloseReceived)
        {
            await socket.CloseOutputAsync(
                WebSocketCloseStatus.NormalClosure,
                null,
                cancellationToken).ConfigureAwait(false);
        }
    }
    catch (Exception ex) when (ex is WebSocketException or OperationCanceledException)
    {
        // The peer is already gone or the request was aborted.
    }
    finally
    {
        if (gateHeld)
        {
            sendGate.Release();
        }
    }
}
/// <summary>
/// Parses one client text message and applies it to the session. Messages of type
/// <c>input</c> are written to the session host and recorded; messages of type <c>resize</c>
/// with positive dimensions are recorded and applied. Payloads that are not valid JSON and
/// every other message type (including <c>attach</c>) are ignored.
/// </summary>
/// <param name="payload">The raw JSON text received from the client.</param>
/// <param name="socket">The socket used for acknowledgements.</param>
/// <param name="registry">The session registry that records input, resizes and input receipts.</param>
/// <param name="host">The session host that receives input and resize requests.</param>
/// <param name="diagnostics">The sink that records received input.</param>
/// <param name="sessionId">The session this connection is attached to.</param>
/// <param name="sendGate">The gate that serializes sends on <paramref name="socket"/>.</param>
/// <param name="cancellationToken">Cancels the host, registry and send operations.</param>
private static async Task HandleClientMessageAsync(
    string payload,
    WebSocket socket,
    SessionRegistry registry,
    ISessionHost host,
    ITerminalDiagnosticsSink diagnostics,
    string sessionId,
    SemaphoreSlim sendGate,
    CancellationToken cancellationToken)
{
    TerminalClientMessage? message;
    try
    {
        message = JsonSerializer.Deserialize<TerminalClientMessage>(payload, JsonOptions);
    }
    catch (JsonException)
    {
        // A malformed payload is dropped; the connection stays up.
        return;
    }

    if (message is null)
    {
        return;
    }

    if (string.Equals(message.Type, "input", StringComparison.OrdinalIgnoreCase))
    {
        await HandleInputMessageAsync(
            message,
            socket,
            registry,
            host,
            diagnostics,
            sessionId,
            sendGate,
            cancellationToken).ConfigureAwait(false);
        return;
    }

    if (string.Equals(message.Type, "resize", StringComparison.OrdinalIgnoreCase) &&
        message.Columns is > 0 &&
        message.Rows is > 0)
    {
        await registry.RecordResizeAsync(
            sessionId,
            message.Columns.Value,
            message.Rows.Value,
            cancellationToken).ConfigureAwait(false);
        await host.ResizeAsync(sessionId, message.Columns.Value, message.Rows.Value, cancellationToken).ConfigureAwait(false);
    }
}

/// <summary>
/// Applies one <c>input</c> message: writes non-empty input to the session host, records it,
/// and, when the message carries an input id, tracks a receipt and acknowledges the input.
/// </summary>
private static async Task HandleInputMessageAsync(
    TerminalClientMessage message,
    WebSocket socket,
    SessionRegistry registry,
    ISessionHost host,
    ITerminalDiagnosticsSink diagnostics,
    string sessionId,
    SemaphoreSlim sendGate,
    CancellationToken cancellationToken)
{
    var input = message.Input;
    if (string.IsNullOrEmpty(input))
    {
        return;
    }

    // A blank id means the client did not ask for receipt tracking or an acknowledgement.
    string? inputId = string.IsNullOrWhiteSpace(message.InputId) ? null : message.InputId;

    if (inputId is not null && !registry.TryBeginInputReceipt(sessionId, inputId, out var existingReceipt))
    {
        // The registry already has a receipt for this id: do not write the input again, only
        // repeat the acknowledgement when the existing receipt reports success.
        if (existingReceipt is not null && await existingReceipt.ConfigureAwait(false))
        {
            await SendJsonAsync(
                socket,
                new TerminalInputAckResponse(sessionId, inputId),
                sendGate,
                cancellationToken).ConfigureAwait(false);
        }

        return;
    }

    try
    {
        diagnostics.Record("backend.input.received", sessionId, SanitizeDiagnosticText(input));
        await host.WriteInputAsync(sessionId, input, cancellationToken).ConfigureAwait(false);
        await registry.RecordInputAsync(sessionId, input, cancellationToken).ConfigureAwait(false);
        if (inputId is not null)
        {
            registry.CompleteInputReceipt(sessionId, inputId, succeeded: true);
        }
    }
    catch
    {
        if (inputId is not null)
        {
            registry.CompleteInputReceipt(sessionId, inputId, succeeded: false);
        }

        throw;
    }

    // The acknowledgement is sent outside the try block on purpose: a failure to deliver it
    // must not complete the already-succeeded receipt a second time as failed.
    if (inputId is not null)
    {
        await SendJsonAsync(
            socket,
            new TerminalInputAckResponse(sessionId, inputId),
            sendGate,
            cancellationToken).ConfigureAwait(false);
    }
}
/// <summary>
/// Replaces every carriage return with the two characters <c>\r</c> and every line feed with
/// the two characters <c>\n</c>.
/// </summary>
/// <param name="input">The text to escape.</param>
/// <returns>The escaped text.</returns>
private static string SanitizeDiagnosticText(string input)
{
    var carriageReturnsEscaped = input.Replace("\r", "\\r", StringComparison.Ordinal);
    var lineFeedsEscaped = carriageReturnsEscaped.Replace("\n", "\\n", StringComparison.Ordinal);
    return lineFeedsEscaped;
}
/// <summary>
/// Converts a screen snapshot into its response form, or returns <c>null</c> when there is no
/// snapshot. The alternate buffer is mapped only when the snapshot has one.
/// </summary>
/// <param name="snapshot">The snapshot to convert; may be <c>null</c>.</param>
/// <returns>The response object, or <c>null</c> when <paramref name="snapshot"/> is <c>null</c>.</returns>
private static TerminalScreenSnapshotResponse? MapScreenSnapshot(
    TerminalScreenSnapshot? snapshot)
    => snapshot is null
        ? null
        : new TerminalScreenSnapshotResponse(
            ScreenVersion: snapshot.ScreenVersion,
            SourceSequence: snapshot.SourceSequence,
            Rows: snapshot.Rows,
            Columns: snapshot.Columns,
            CursorRow: snapshot.CursorRow,
            CursorColumn: snapshot.CursorColumn,
            CursorVisible: snapshot.CursorVisible,
            ActiveBuffer: snapshot.ActiveBuffer,
            PrimaryBuffer: MapScreenBuffer(snapshot.PrimaryBuffer),
            AlternateBuffer: snapshot.AlternateBuffer is { } alternateBuffer
                ? MapScreenBuffer(alternateBuffer)
                : null);
/// <summary>
/// Converts a screen buffer snapshot into its response form by copying the index and text of
/// every viewport line into a new array.
/// </summary>
/// <param name="buffer">The buffer snapshot to convert.</param>
/// <returns>The response object holding the converted viewport lines.</returns>
private static TerminalScreenBufferSnapshotResponse MapScreenBuffer(
    TerminalScreenBufferSnapshot buffer)
{
    var viewportLines =
        from line in buffer.Viewport
        select new TerminalScreenLineSnapshotResponse(line.Index, line.Text);

    return new TerminalScreenBufferSnapshotResponse(viewportLines.ToArray());
}
/// <summary>
/// Serializes <paramref name="response"/> to UTF-8 JSON with the handler's serializer options
/// and sends it through <c>SendAsync</c>.
/// </summary>
/// <param name="socket">The socket to send on.</param>
/// <param name="response">The object to serialize.</param>
/// <param name="sendGate">The gate that serializes sends on <paramref name="socket"/>.</param>
/// <param name="cancellationToken">Cancels waiting for the gate and the send.</param>
private static async Task SendJsonAsync(
    WebSocket socket,
    object response,
    SemaphoreSlim sendGate,
    CancellationToken cancellationToken)
{
    byte[] utf8Payload = JsonSerializer.SerializeToUtf8Bytes(response, JsonOptions);

    var sendTask = SendAsync(socket, utf8Payload, sendGate, cancellationToken);
    await sendTask.ConfigureAwait(false);
}
/// <summary>
/// Sends <paramref name="payload"/> as a single text message while holding
/// <paramref name="sendGate"/>. Nothing is sent when the socket is not open, and a
/// <see cref="WebSocketException"/> raised by the send is swallowed.
/// </summary>
/// <param name="socket">The socket to send on.</param>
/// <param name="payload">The bytes of the message.</param>
/// <param name="sendGate">The gate that serializes sends on <paramref name="socket"/>.</param>
/// <param name="cancellationToken">Cancels waiting for the gate and the send.</param>
private static async Task SendAsync(
    WebSocket socket,
    byte[] payload,
    SemaphoreSlim sendGate,
    CancellationToken cancellationToken)
{
    await sendGate.WaitAsync(cancellationToken).ConfigureAwait(false);
    try
    {
        // The state is checked only after the gate is held, right before the send.
        if (socket.State != WebSocketState.Open)
        {
            return;
        }

        await socket
            .SendAsync(payload, WebSocketMessageType.Text, endOfMessage: true, cancellationToken)
            .ConfigureAwait(false);
    }
    catch (WebSocketException)
    {
        // Send failures on the socket are ignored.
    }
    finally
    {
        sendGate.Release();
    }
}
/// <summary>
/// Message sent to the client after the attach has been recorded for the session.
/// <c>Type</c> defaults to <c>"attached"</c>.
/// </summary>
private sealed record TerminalAttachResponse(string SessionId, string Type = "attached");
/// <summary>
/// Message sent right after the attach message; it carries the values of the registry's
/// restore snapshot for the session. <c>Type</c> defaults to <c>"restore"</c>.
/// </summary>
private sealed record TerminalRestoreResponse(
string SessionId,
long Sequence,
string ScreenText,
string PendingInput,
int? CursorRow,
int? CursorColumn,
TerminalScreenSnapshotResponse? ScreenSnapshot,
string Type = "restore");
/// <summary>
/// Message carrying one output event of the session host (session id, sequence and chunk).
/// <c>Type</c> defaults to <c>"output"</c>.
/// </summary>
private sealed record TerminalOutputResponse(
string SessionId,
long Sequence,
string Chunk,
string Type = "output");
/// <summary>
/// Acknowledgement for a client input message that carried an input id.
/// <c>Type</c> defaults to <c>"inputAck"</c>.
/// </summary>
private sealed record TerminalInputAckResponse(
string SessionId,
string InputId,
string Type = "inputAck");
/// <summary>
/// Response form of a screen snapshot, embedded in the restore message. <c>AlternateBuffer</c>
/// is <c>null</c> when the source snapshot has no alternate buffer.
/// </summary>
private sealed record TerminalScreenSnapshotResponse(
long ScreenVersion,
long SourceSequence,
int Rows,
int Columns,
int CursorRow,
int CursorColumn,
bool CursorVisible,
string ActiveBuffer,
TerminalScreenBufferSnapshotResponse PrimaryBuffer,
TerminalScreenBufferSnapshotResponse? AlternateBuffer);
/// <summary>Response form of one screen buffer: the lines of its viewport.</summary>
private sealed record TerminalScreenBufferSnapshotResponse(
IReadOnlyList<TerminalScreenLineSnapshotResponse> Viewport);
/// <summary>Response form of one viewport line: its index and its text.</summary>
private sealed record TerminalScreenLineSnapshotResponse(
int Index,
string Text);
/// <summary>
/// Shape of a message received from the client. The handler acts on <c>Type</c> values
/// <c>"input"</c> (uses <c>Input</c> and the optional <c>InputId</c>) and <c>"resize"</c>
/// (uses <c>Columns</c> and <c>Rows</c>), compared case-insensitively. <c>SessionId</c> is
/// deserialized but not read by this handler, which uses the session id from the request's
/// query string instead.
/// </summary>
private sealed record TerminalClientMessage(
string Type,
string? SessionId,
string? Input,
string? InputId,
int? Columns,
int? Rows);
}