lampac/Lampac/Engine/NativeWebSocket.cs
2026-02-07 00:18:50 +03:00

568 lines
21 KiB
C#

using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Caching.Memory;
using Shared;
using Shared.Engine;
using Shared.Models;
using Shared.Models.Events;
using System;
using System.Buffers;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Net.WebSockets;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
namespace Lampac.Engine
{
public class NativeWebSocket : INws
{
#region fields
// Cache that holds the "stats:nws" send/receive counters. It is not assigned anywhere in this
// class; the stats code paths are skipped while it is null.
public static IMemoryCache memoryCache;
// Options used when serializing outgoing frames (compact, non-indented JSON).
static readonly JsonSerializerOptions serializerOptions = new JsonSerializerOptions
{
WriteIndented = false
};
// Registry of connections keyed by connection id.
public static readonly ConcurrentDictionary<string, NwsConnection> _connections = new();
// Invokes ConnectionMonitorCallback: first tick after 1 minute, then every 5 seconds.
static readonly Timer ConnectionMonitorTimer = new Timer(ConnectionMonitorCallback, null, TimeSpan.FromMinutes(1), TimeSpan.FromSeconds(5));
// Connection ids subscribed to web-log broadcasts; the byte value is a placeholder (always 0).
public readonly static ConcurrentDictionary<string, byte> weblog_clients = new();
// Connection id -> uid for connections subscribed to events.
public readonly static ConcurrentDictionary<string, string> event_clients = new();
// Number of entries currently in the connection registry.
public static int CountConnection => _connections.Count;
// Number of web-log subscribers.
public int CountWeblogClients => weblog_clients.Count;
// Number of event subscribers.
public int CountEventClients => event_clients.Count;
#endregion
#region interface
/// <summary>
/// Interface entry point for web-log messages; forwards to the static <see cref="SendLog"/>.
/// </summary>
/// <param name="message">Log text to broadcast.</param>
/// <param name="plugin">Name of the plugin that produced the message.</param>
public void WebLog(string message, string plugin)
{
    SendLog(message, plugin);
}
/// <summary>
/// Interface entry point for event fan-out; forwards to the static <see cref="SendEvents"/>.
/// </summary>
/// <param name="connectionId">Connection to exclude from delivery, or null to exclude none.</param>
/// <param name="uid">Uid whose subscribers receive the event.</param>
/// <param name="name">Event name.</param>
/// <param name="data">Event payload.</param>
/// <returns>The task returned by <see cref="SendEvents"/>.</returns>
public Task EventsAsync(string connectionId, string uid, string name, string data)
{
    return SendEvents(connectionId, uid, name, data);
}
/// <summary>
/// Sends a method call to the connection registered under <paramref name="connectionId"/>.
/// </summary>
/// <param name="connectionId">Target connection id; a null or unknown id makes this a no-op.</param>
/// <param name="method">Method name placed in the outgoing frame.</param>
/// <param name="args">Arguments placed in the outgoing frame.</param>
/// <returns>The send task, or a completed task when no connection matches.</returns>
public Task SendAsync(string connectionId, string method, params object[] args)
{
    if (connectionId == null)
        return Task.CompletedTask;

    return _connections.TryGetValue(connectionId, out var target)
        ? SendAsync(target, method, args)
        : Task.CompletedTask;
}
/// <summary>
/// Exposes the connection registry itself (the same dictionary instance, not a copy).
/// </summary>
public ConcurrentDictionary<string, NwsConnection> AllConnections()
{
    return _connections;
}
#endregion
#region handle
/// <summary>
/// Accepts a WebSocket request, registers the connection, announces its id to the client with a
/// "Connected" frame and runs the receive loop until the socket closes or is cancelled.
/// </summary>
/// <remarks>
/// The connection id comes from the non-blank <c>id</c> query parameter when present (any
/// connection already registered under that id is cleaned up first); otherwise a new GUID is used.
/// Non-WebSocket requests get a 400 response.
/// </remarks>
/// <param name="context">The HTTP context of the upgrade request.</param>
public static async Task HandleWebSocketAsync(HttpContext context)
{
    if (!context.WebSockets.IsWebSocketRequest)
    {
        context.Response.StatusCode = StatusCodes.Status400BadRequest;
        return;
    }

    using (var socket = await context.WebSockets.AcceptWebSocketAsync().ConfigureAwait(false))
    {
        string connectionId = null;

        if (context.Request.Query.TryGetValue("id", out var _connectionId))
        {
            string _id = _connectionId.ToString();
            if (!string.IsNullOrWhiteSpace(_id))
            {
                connectionId = _id;

                // Reconnect with a known id: drop whatever is currently registered under it.
                Cleanup(connectionId);
            }
        }

        if (connectionId == null)
            connectionId = Guid.NewGuid().ToString("N");

        // Hoisted out of the try so the finally block can tell whether this handler still owns the id.
        NwsConnection connection = null;

        try
        {
            var requestInfo = context.Features.Get<RequestModel>();
            connection = new NwsConnection(connectionId, socket, AppInit.Host(context), requestInfo);

            var cancellationSource = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted);
            connection.SetCancellationSource(cancellationSource);

            _connections.AddOrUpdate(connectionId, connection, (k, v) => connection);

            await SendAsync(connection, "Connected", connectionId).ConfigureAwait(false);

            if (InvkEvent.IsNwsConnected())
                InvkEvent.NwsConnected(new EventNwsConnected(connectionId, requestInfo, connection, cancellationSource.Token));

            await ReceiveLoopAsync(connection, cancellationSource.Token).ConfigureAwait(false);
        }
        finally
        {
            // When a client reconnects with the same id, the newer handler has already removed this
            // connection from the registry (via Cleanup above) and registered its own under the same
            // key. Running the id-wide Cleanup here would tear down that successor and its
            // weblog/event/rch registrations, so it is skipped when the id now maps to a different
            // connection instance.
            bool replaced = connection != null
                && _connections.TryGetValue(connectionId, out var current)
                && !ReferenceEquals(current, connection);

            if (!replaced)
                Cleanup(connectionId);

            if (InvkEvent.IsNwsDisconnected())
                InvkEvent.NwsDisconnected(new EventNwsDisconnected(connectionId));
        }
    }
}
#endregion
#region receive loop
/// <summary>
/// Reads text messages from the connection's socket until it closes, the token is cancelled or an
/// error occurs, and hands each complete message to <see cref="HandleMessageAsync"/>.
/// </summary>
/// <remarks>
/// Frames of one message are decoded incrementally as UTF-8 and accumulated in a StringBuilder.
/// Messages longer than 10,000,000 chars close the socket with MessageTooBig. Messages shorter
/// than two chars and the bare text "ping" only refresh the activity timestamp. Non-text frames
/// are read and ignored. All exceptions end the loop silently.
/// </remarks>
/// <param name="connection">The connection whose socket is read.</param>
/// <param name="token">Cancels the pending receive and ends the loop.</param>
static async Task ReceiveLoopAsync(NwsConnection connection, CancellationToken token)
{
    WebSocket socket = connection.Socket;
    var decoder = Encoding.UTF8.GetDecoder();
    var buffer = ArrayPool<byte>.Shared.Rent(1024);

    // A stateful decoder can emit one more char than the bytes passed in (a pending partial
    // sequence from the previous frame resolved by this one), and the pool may hand back a byte
    // buffer larger than requested. Size the char buffer for the worst case so Convert never
    // stops early and silently drops the tail of a frame.
    var charBuffer = ArrayPool<char>.Shared.Rent(Encoding.UTF8.GetMaxCharCount(buffer.Length));

    // Reused across messages to avoid allocating a builder per message.
    var reusableBuilder = new StringBuilder(1024);

    try
    {
        while (socket.State == WebSocketState.Open && !token.IsCancellationRequested)
        {
            #region stats
            if (AppInit.conf.openstat.enable && memoryCache != null)
            {
                var counter = memoryCache.GetOrCreate("stats:nws", entry =>
                {
                    entry.AbsoluteExpirationRelativeToNow = TimeSpan.FromMinutes(1);
                    return new CounterNws();
                });

                Interlocked.Increment(ref counter.receive);
            }
            #endregion

            // Each message starts with a clean decoder state.
            decoder.Reset();
            StringBuilder builder = null;
            WebSocketReceiveResult result;

            do
            {
                result = await socket.ReceiveAsync(new ArraySegment<byte>(buffer), token).ConfigureAwait(false);

                if (result.MessageType == WebSocketMessageType.Close)
                {
                    // Complete the close handshake, but do not wait on it for more than 20 seconds.
                    using (var cts = new CancellationTokenSource(TimeSpan.FromSeconds(20)))
                        await socket.CloseAsync(result.CloseStatus ?? WebSocketCloseStatus.NormalClosure, result.CloseStatusDescription, cts.Token).ConfigureAwait(false);

                    return;
                }

                if (result.Count > 0 && result.MessageType == WebSocketMessageType.Text)
                {
                    decoder.Convert(
                        buffer, 0, result.Count,
                        charBuffer, 0, charBuffer.Length,
                        result.EndOfMessage,
                        out _,
                        out int charsUsed,
                        out _
                    );

                    if (builder == null)
                    {
                        if (reusableBuilder.Capacity > result.Count)
                        {
                            reusableBuilder.Clear();
                            builder = reusableBuilder;
                        }
                        else
                        {
                            builder = new StringBuilder(result.Count);
                        }
                    }

                    if (charsUsed > 0)
                        builder.Append(charBuffer, 0, charsUsed);

                    if (builder.Length > 10_000000)
                    {
                        using (var cts = new CancellationTokenSource(TimeSpan.FromSeconds(20)))
                            await socket.CloseAsync(WebSocketCloseStatus.MessageTooBig, "payload too large", cts.Token).ConfigureAwait(false);

                        return;
                    }
                }
            }
            while (!result.EndOfMessage);

            connection.UpdateActivity();

            if (builder != null)
            {
                if (2 > builder.Length)
                    continue;

                // Bare "ping" text: the activity timestamp was refreshed above, nothing to dispatch.
                if (builder.Length == 4 &&
                    builder[0] == 'p' &&
                    builder[1] == 'i' &&
                    builder[2] == 'n' &&
                    builder[3] == 'g')
                {
                    continue;
                }

                await HandleMessageAsync(connection, builder).ConfigureAwait(false);
            }
        }
    }
    catch (OperationCanceledException)
    {
        // Cancellation ends the loop.
    }
    catch (WebSocketException)
    {
        // A socket failure ends the loop.
    }
    catch (Exception)
    {
        // Any other failure (including one thrown by a message handler) also ends the loop silently.
    }
    finally
    {
        ArrayPool<byte>.Shared.Return(buffer);
        ArrayPool<char>.Shared.Return(charBuffer);
    }
}
#endregion
#region message handling
/// <summary>
/// Parses one received text message as JSON of the form <c>{"method": "...", "args": [...]}</c>,
/// raises the message event when a listener is registered, and dispatches it to
/// <see cref="InvokeAsync"/>.
/// </summary>
/// <remarks>
/// Messages that are not valid JSON, whose root is not an object, that lack "method", or whose
/// "method" is neither a string nor null are ignored. A missing or non-array "args" is passed on
/// as a default <see cref="JsonElement"/>.
/// </remarks>
/// <param name="connection">The connection the message arrived on.</param>
/// <param name="payload">The complete message text.</param>
static async Task HandleMessageAsync(NwsConnection connection, StringBuilder payload)
{
    try
    {
        using JsonDocument document = JsonDocument.Parse(payload.ToString());
        JsonElement root = document.RootElement;

        if (root.ValueKind != JsonValueKind.Object)
            return;

        if (!root.TryGetProperty("method", out var methodProp))
            return;

        // GetString() throws InvalidOperationException for anything other than a string or null
        // (e.g. {"method": 1}). That exception is not a JsonException, so it would escape this
        // method and end the caller's receive loop because of a single malformed frame.
        if (methodProp.ValueKind != JsonValueKind.String && methodProp.ValueKind != JsonValueKind.Null)
            return;

        string method = methodProp.GetString();

        JsonElement args = default;
        if (root.TryGetProperty("args", out var argsProp) && argsProp.ValueKind == JsonValueKind.Array)
            args = argsProp;

        if (InvkEvent.IsNwsMessage())
            InvkEvent.NwsMessage(new EventNwsMessage(connection.ConnectionId, payload, method, args));

        await InvokeAsync(connection, method, args).ConfigureAwait(false);
    }
    catch (JsonException)
    {
        // Not valid JSON: ignore the message.
    }
}
/// <summary>
/// Executes a client-invoked method. Method names are matched case-insensitively; unknown,
/// null or empty names are ignored.
/// </summary>
/// <remarks>
/// Handled methods: rchregistry, rchresult, registryweblog, weblog, registryevent, events,
/// eventsid, ping. Positional arguments are read with <see cref="GetStringArg"/>.
/// </remarks>
/// <param name="connection">The connection that sent the call.</param>
/// <param name="method">Method name from the message.</param>
/// <param name="args">The "args" array of the message, or a default element when absent.</param>
static async Task InvokeAsync(NwsConnection connection, string method, JsonElement args)
{
    if (string.IsNullOrEmpty(method))
        return;

    // Invariant lowering: culture-sensitive ToLower() maps 'I' to a dotless 'ı' under Turkish
    // cultures, so a name such as "EventsId" would no longer match its case label.
    switch (method.ToLowerInvariant())
    {
        case "rchregistry":
            if (AppInit.conf.rch.enable)
            {
                string json = GetStringArg(args, 0);
                RchClient.Registry(connection.Ip, connection.ConnectionId, connection.Host, json, connection);
                await SendAsync(connection, "RchRegistry", connection.Ip).ConfigureAwait(false);
            }
            break;

        case "rchresult":
            if (AppInit.conf.rch.enable)
            {
                string id = GetStringArg(args, 0);
                string value = GetStringArg(args, 1) ?? string.Empty;

                // Hand the result to whoever is waiting on this request id.
                if (!string.IsNullOrEmpty(id) && RchClient.rchIds.TryGetValue(id, out var rchHub))
                    rchHub.tcs.TrySetResult(value);
            }
            break;

        case "registryweblog":
            if (AppInit.conf.weblog.enable && Startup.WebLogEnableController)
            {
                string token = GetStringArg(args, 0);

                // No configured token means anyone may subscribe; otherwise the token must match.
                if (string.IsNullOrEmpty(AppInit.conf.weblog.token) || AppInit.conf.weblog.token == token)
                    weblog_clients.AddOrUpdate(connection.ConnectionId, 0, static (_, __) => 0);
            }
            break;

        case "weblog":
            SendLog(GetStringArg(args, 0), GetStringArg(args, 1));
            break;

        case "registryevent":
            {
                string uid = GetStringArg(args, 0);
                if (!string.IsNullOrEmpty(uid))
                    event_clients.AddOrUpdate(connection.ConnectionId, uid, (_, __) => uid);
                break;
            }

        case "events":
            {
                string uid = GetStringArg(args, 0);
                string name = GetStringArg(args, 1);
                string data = GetStringArg(args, 2);

                if (name == "devices")
                {
                    // "devices" is answered directly to the caller instead of being fanned out.
                    var uidClients = event_clients
                        .Where(i => i.Value == uid)
                        .ToDictionary();

                    // Other connections that either share the caller's IP (only when accsdb is
                    // disabled) or are registered for the same uid.
                    var devices = _connections
                        .Where(i => i.Value.ConnectionId != connection.ConnectionId)
                        .Where(i =>
                            (!AppInit.conf.accsdb.enable && i.Value.Ip == connection.Ip)
                            || uidClients.ContainsKey(i.Key))
                        .Select(i => new {
                            uid = uidClients.TryGetValue(i.Value.ConnectionId, out var targetUid) ? targetUid : null,
                            i.Value.ConnectionId,
                            // RequestInfo comes from an optional HttpContext feature and may be null.
                            UserAgent = i.Value.RequestInfo?.UserAgent
                        })
                        .ToArray();

                    await SendAsync(connection, "event", uid, name, devices).ConfigureAwait(false);
                    break;
                }

                await SendEvents(connection.ConnectionId, uid, name, data).ConfigureAwait(false);
                break;
            }

        case "eventsid":
            {
                string targetConnection = GetStringArg(args, 0);
                string uid = GetStringArg(args, 1);
                string name = GetStringArg(args, 2);
                string data = GetStringArg(args, 3);
                await SendEventToConnection(targetConnection, uid, name, data).ConfigureAwait(false);
                break;
            }

        case "ping":
            await SendAsync(connection, "pong").ConfigureAwait(false);
            break;
    }
}
/// <summary>
/// Reads the positional argument at <paramref name="index"/> as a string.
/// </summary>
/// <param name="args">The "args" element of a message; anything other than an array yields null.</param>
/// <param name="index">Zero-based position in the array.</param>
/// <returns>
/// The string value for JSON strings, null for JSON null or when the index is past the end of
/// the array, and the element's <c>ToString()</c> text for every other kind.
/// </returns>
static string GetStringArg(JsonElement args, int index)
{
    bool inRange = args.ValueKind == JsonValueKind.Array && index < args.GetArrayLength();
    if (!inRange)
        return null;

    JsonElement item = args[index];

    return item.ValueKind switch
    {
        JsonValueKind.String => item.GetString(),
        JsonValueKind.Null => null,
        _ => item.ToString()
    };
}
#endregion
#region SendAsync
// Shape of every outgoing frame: {"method": "...", "args": [...]}.
sealed record NwsSendModel(string method, object[] args);

/// <summary>
/// Serializes a method call to JSON and sends it as a single text message on the connection's
/// socket. Sends on one connection are serialized through <c>connection.SendLock</c>.
/// </summary>
/// <remarks>
/// Best-effort: nothing is sent when the socket is not open, when <paramref name="method"/> is
/// null or empty, or when the send lock cannot be obtained within 20 seconds. Socket errors,
/// cancellation (the send itself is limited to 10 seconds) and a connection disposed mid-send
/// are swallowed.
/// </remarks>
/// <param name="connection">Target connection.</param>
/// <param name="method">Method name placed in the frame.</param>
/// <param name="args">Arguments placed in the frame; null is sent as an empty array.</param>
static async Task SendAsync(NwsConnection connection, string method, params object[] args)
{
    if (connection.Socket.State != WebSocketState.Open || string.IsNullOrEmpty(method))
        return;

    // Tracks whether this call actually holds SendLock, so the finally block never releases a
    // lock that was not acquired (wait timed out, or an exception was thrown before the wait).
    bool lockTaken = false;

    try
    {
        #region stats
        if (AppInit.conf.openstat.enable && memoryCache != null)
        {
            var counter = memoryCache.GetOrCreate("stats:nws", entry =>
            {
                entry.AbsoluteExpirationRelativeToNow = TimeSpan.FromMinutes(1);
                return new CounterNws();
            });

            Interlocked.Increment(ref counter.send);
        }
        #endregion

        lockTaken = await connection.SendLock.WaitAsync(TimeSpan.FromSeconds(20)).ConfigureAwait(false);
        if (!lockTaken)
        {
            // Another send has held the lock for the whole timeout. Sending anyway would overlap
            // two sends on the same socket, so the message is dropped instead.
            return;
        }

        using (var ms = PoolInvk.msm.GetStream())
        {
            JsonSerializer.Serialize(ms, new NwsSendModel(method, args ?? Array.Empty<object>()), serializerOptions);
            ms.Position = 0;

            // Re-check: the socket may have closed while waiting for the lock.
            if (connection.Socket.State == WebSocketState.Open)
            {
                using (var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10)))
                {
                    await connection.Socket
                        .SendAsync(ms.GetBuffer().AsMemory(0, (int)ms.Length), WebSocketMessageType.Text, true, cts.Token)
                        .ConfigureAwait(false);
                }

                connection.UpdateActivity();
                connection.UpdateSendActivity();
            }
        }
    }
    catch (WebSocketException)
    {
    }
    catch (OperationCanceledException)
    {
    }
    catch (ObjectDisposedException)
    {
        // The connection was torn down while this send was in flight (the socket is disposed
        // when its handler exits); nothing left to deliver to.
    }
    finally
    {
        if (lockTaken)
            connection.SendLock.Release();
    }
}
/// <summary>
/// Broadcasts a log line to every web-log subscriber as a "Receive" frame, without waiting for
/// the sends to finish.
/// </summary>
/// <remarks>
/// Does nothing when weblog is disabled in the configuration or at startup, when there are no
/// subscribers, when either argument is null or empty, or when the message exceeds
/// 4,000,000 characters.
/// </remarks>
/// <param name="message">Log text.</param>
/// <param name="plugin">Name of the plugin that produced the message.</param>
public static void SendLog(string message, string plugin)
{
    bool sendable = AppInit.conf.weblog.enable
        && !string.IsNullOrEmpty(message)
        && !string.IsNullOrEmpty(plugin)
        && message.Length <= 4_000000;

    if (!sendable)
        return;

    bool hasAudience = !weblog_clients.IsEmpty && Startup.WebLogEnableController;
    if (!hasAudience)
        return;

    foreach (string subscriberId in weblog_clients.Keys)
    {
        if (!_connections.TryGetValue(subscriberId, out var subscriber))
            continue;

        // Fire and forget: the send task is intentionally not awaited.
        _ = SendAsync(subscriber, "Receive", message, plugin);
    }
}
/// <summary>
/// Sends an "event" frame to every connection subscribed to <paramref name="uid"/>, except the
/// connection identified by <paramref name="connectionId"/>.
/// </summary>
/// <param name="connectionId">Connection to skip (typically the sender); null skips none.</param>
/// <param name="uid">Uid whose subscribers are targeted; null or empty makes this a no-op.</param>
/// <param name="name">Event name; null or empty makes this a no-op.</param>
/// <param name="data">Event payload; null is sent as an empty string.</param>
/// <returns>A task that completes when all sends have finished, or a completed task when there is nothing to send.</returns>
public static Task SendEvents(string connectionId, string uid, string name, string data)
{
    if (string.IsNullOrEmpty(uid) || string.IsNullOrEmpty(name))
        return Task.CompletedTask;

    string payload = data ?? string.Empty;

    // Allocated only once a recipient is found. The subscriber map is walked exactly once
    // rather than once for an emptiness probe and again for the sends.
    List<Task> sends = null;

    foreach (var subscription in event_clients)
    {
        if (subscription.Value != uid)
            continue;

        if (connectionId != null && subscription.Key == connectionId)
            continue;

        if (!_connections.TryGetValue(subscription.Key, out var client))
            continue;

        (sends ??= new List<Task>()).Add(SendAsync(client, "event", uid, name, payload));
    }

    return sends == null ? Task.CompletedTask : Task.WhenAll(sends);
}
/// <summary>
/// Sends an "event" frame to one specific connection.
/// </summary>
/// <param name="connectionId">Target connection id; null, empty or unknown makes this a no-op.</param>
/// <param name="uid">Uid placed in the frame; null is sent as an empty string.</param>
/// <param name="name">Event name; null or empty makes this a no-op.</param>
/// <param name="data">Event payload; null is sent as an empty string.</param>
/// <returns>The send task, or a completed task when nothing is sent.</returns>
static Task SendEventToConnection(string connectionId, string uid, string name, string data)
{
    bool addressable = !string.IsNullOrEmpty(connectionId) && !string.IsNullOrEmpty(name);

    if (addressable && _connections.TryGetValue(connectionId, out var target))
        return SendAsync(target, "event", uid ?? string.Empty, name, data ?? string.Empty);

    return Task.CompletedTask;
}
/// <summary>
/// Sends an "RchClient" frame carrying a request description to one connection.
/// </summary>
/// <param name="connectionId">Target connection id; null, empty or unknown makes this a no-op.</param>
/// <param name="rchId">Request id placed in the frame.</param>
/// <param name="url">Url placed in the frame.</param>
/// <param name="data">Data placed in the frame.</param>
/// <param name="headers">Headers placed in the frame.</param>
/// <param name="returnHeaders">Flag placed in the frame as the last argument.</param>
/// <returns>The send task, or a completed task when nothing is sent.</returns>
public static Task SendRchRequestAsync(string connectionId, string rchId, string url, string data, Dictionary<string, string> headers, bool returnHeaders)
{
    if (!string.IsNullOrEmpty(connectionId) && _connections.TryGetValue(connectionId, out var target))
        return SendAsync(target, "RchClient", rchId, url, data, headers, returnHeaders);

    return Task.CompletedTask;
}
#endregion
#region cleanup
/// <summary>
/// Removes everything registered under a connection id: the connection itself (cancelled and
/// disposed if it was present), its web-log and event subscriptions, and its RchClient state.
/// </summary>
/// <param name="connectionId">Connection id to clean up; null or empty is ignored.</param>
public static void Cleanup(string connectionId)
{
    if (!string.IsNullOrEmpty(connectionId))
    {
        bool wasRegistered = _connections.TryRemove(connectionId, out var removed);
        if (wasRegistered)
        {
            removed.Cancel();
            removed.Dispose();
        }

        // Subscriptions and rch state are dropped whether or not a connection was registered.
        weblog_clients.TryRemove(connectionId, out _);
        event_clients.TryRemove(connectionId, out _);
        RchClient.OnDisconnected(connectionId);
    }
}
#endregion
#region ConnectionMonitorCallback
// 1 while a sweep is running, 0 otherwise; keeps timer ticks from overlapping.
static int _updatingMonitorCallback = 0;

/// <summary>
/// Timer callback that cancels connections which have gone quiet.
/// </summary>
/// <remarks>
/// A connection is cancelled when its last activity is at least 125 seconds old, or, when
/// <c>AppInit.conf.WebSocket.inactiveAfterMinutes</c> is positive, when its last send activity
/// is at least that many minutes old. A tick is skipped if the previous sweep is still running.
/// </remarks>
/// <param name="state">Unused timer state.</param>
static void ConnectionMonitorCallback(object state)
{
    if (_connections.IsEmpty)
        return;

    if (Interlocked.Exchange(ref _updatingMonitorCallback, 1) == 1)
        return;

    try
    {
        var now = DateTime.UtcNow;
        var activityCutoff = now.AddSeconds(-125);

        // The send-idle limit is the same for every connection, so it is computed once per sweep.
        int inactiveAfterMinutes = AppInit.conf.WebSocket.inactiveAfterMinutes;
        DateTime? sendCutoff = inactiveAfterMinutes > 0
            ? now.AddMinutes(-inactiveAfterMinutes)
            : (DateTime?)null;

        foreach (var entry in _connections)
        {
            try
            {
                var connection = entry.Value;

                bool expired = activityCutoff >= connection.LastActivityUtc
                    || (sendCutoff.HasValue && sendCutoff.Value >= connection.LastSendActivityUtc);

                if (expired)
                    connection.Cancel();
            }
            catch
            {
                // A failure on one connection (for example, one being torn down concurrently)
                // must not stop the sweep for the remaining connections.
            }
        }
    }
    catch
    {
        // Best-effort: the timer callback must never throw.
    }
    finally
    {
        Volatile.Write(ref _updatingMonitorCallback, 0);
    }
}
#endregion
#region FullDispose
/// <summary>
/// Shutdown hook: stops the connection monitor timer and cancels every registered connection.
/// Entries are not removed from the registry here.
/// </summary>
public static void FullDispose()
{
    ConnectionMonitorTimer.Dispose();

    foreach (KeyValuePair<string, NwsConnection> entry in _connections)
    {
        entry.Value.Cancel();
    }
}
#endregion
}
// Message counters stored in the memory cache under "stats:nws" (the entry expires one minute
// after creation). They are public fields, not properties, because they are incremented with
// Interlocked.Increment(ref ...).
public class CounterNws
{
// Incremented once per receive-loop iteration.
public int receive;
// Incremented once per send attempt.
public int send;
}
}