using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Caching.Memory;
using Shared;
using Shared.Engine;
using Shared.Models;
using Shared.Models.Events;
using System;
using System.Buffers;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Net.WebSockets;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;

namespace Lampac.Engine
{
    /// <summary>
    /// Native WebSocket hub: accepts client sockets, keeps them in static registries keyed by
    /// connection id, dispatches incoming JSON messages of the form <c>{"method": ..., "args": [...]}</c>,
    /// and sends outgoing messages serialized from <c>NwsSendModel(method, args)</c>.
    /// </summary>
    public class NativeWebSocket : INws
    {
        #region fields
        public static IMemoryCache memoryCache;

        static readonly JsonSerializerOptions serializerOptions = new JsonSerializerOptions
        {
            WriteIndented = false
        };

        /// <summary>All live connections, keyed by connection id.</summary>
        public static readonly ConcurrentDictionary _connections = new();

        // First tick after 1 minute, then every 5 seconds (see ConnectionMonitorCallback).
        static readonly Timer ConnectionMonitorTimer = new Timer(ConnectionMonitorCallback, null, TimeSpan.FromMinutes(1), TimeSpan.FromSeconds(5));

        /// <summary>Connection ids subscribed to web-log broadcasts (added by the "RegistryWebLog" method).</summary>
        public static readonly ConcurrentDictionary weblog_clients = new();

        /// <summary>Connection id to uid mapping for event subscribers (added by the "RegistryEvent" method).</summary>
        public static readonly ConcurrentDictionary event_clients = new();

        public static int CountConnection => _connections.Count;

        public int CountWeblogClients => weblog_clients.Count;

        public int CountEventClients => event_clients.Count;
        #endregion

        #region interface
        public void WebLog(string message, string plugin) => SendLog(message, plugin);

        public Task EventsAsync(string connectionId, string uid, string name, string data) => SendEvents(connectionId, uid, name, data);

        /// <summary>Sends <paramref name="method"/> with <paramref name="args"/> to the connection with the given id, if it is registered.</summary>
        public Task SendAsync(string connectionId, string method, params object[] args)
        {
            if (connectionId != null && _connections.TryGetValue(connectionId, out var client))
                return SendAsync(client, method, args);

            return Task.CompletedTask;
        }

        public ConcurrentDictionary AllConnections() => _connections;
        #endregion

        #region handle
        /// <summary>
        /// Accepts a WebSocket request, registers the connection (re-using the <c>id</c> query value when
        /// supplied, otherwise a new GUID), replies with "Connected", and runs the receive loop until the
        /// socket closes or is cancelled. Responds 400 to non-WebSocket requests.
        /// </summary>
        public static async Task HandleWebSocketAsync(HttpContext context)
        {
            if (!context.WebSockets.IsWebSocketRequest)
            {
                context.Response.StatusCode = StatusCodes.Status400BadRequest;
                return;
            }

            using (var socket = await context.WebSockets.AcceptWebSocketAsync().ConfigureAwait(false))
            {
                string connectionId = null;

                // A client may reconnect with a previous id: evict whatever is still registered under it.
                if (context.Request.Query.TryGetValue("id", out var _connectionId))
                {
                    string _id = _connectionId.ToString();
                    if (!string.IsNullOrWhiteSpace(_id))
                    {
                        connectionId = _id;
                        Cleanup(connectionId);
                    }
                }

                if (connectionId == null)
                    connectionId = Guid.NewGuid().ToString("N");

                NwsConnection connection = null;

                try
                {
                    var requestInfo = context.Features.Get();
                    connection = new NwsConnection(connectionId, socket, AppInit.Host(context), requestInfo);

                    var cancellationSource = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted);
                    connection.SetCancellationSource(cancellationSource);

                    _connections.AddOrUpdate(connectionId, connection, (k, v) => connection);

                    await SendAsync(connection, "Connected", connectionId).ConfigureAwait(false);

                    if (InvkEvent.IsNwsConnected())
                        InvkEvent.NwsConnected(new EventNwsConnected(connectionId, requestInfo, connection, cancellationSource.Token));

                    await ReceiveLoopAsync(connection, cancellationSource.Token).ConfigureAwait(false);
                }
                finally
                {
                    // Only tear down registrations that still belong to THIS connection: if a client
                    // reconnected under the same id, the newer connection must stay registered.
                    if (connection != null)
                        ReleaseConnection(connectionId, connection);

                    if (InvkEvent.IsNwsDisconnected())
                        InvkEvent.NwsDisconnected(new EventNwsDisconnected(connectionId));
                }
            }
        }
        #endregion

        #region receive loop
        /// <summary>
        /// Reads text messages from the connection's socket until it closes, the token is cancelled,
        /// or any exception occurs, and passes each complete message to <see cref="HandleMessageAsync"/>.
        /// Messages over 10,000,000 chars close the socket with MessageTooBig.
        /// </summary>
        static async Task ReceiveLoopAsync(NwsConnection connection, CancellationToken token)
        {
            WebSocket socket = connection.Socket;
            var decoder = Encoding.UTF8.GetDecoder();

            var buffer = ArrayPool.Shared.Rent(1024);
            var charBuffer = ArrayPool.Shared.Rent(1024);

            // Re-used (cleared) for messages whose first chunk fits its capacity, to avoid a per-message allocation.
            var reusableBuilder = new StringBuilder(1024);

            try
            {
                while (socket.State == WebSocketState.Open && !token.IsCancellationRequested)
                {
                    var counter = StatsCounter();
                    if (counter != null)
                        Interlocked.Increment(ref counter.receive);

                    decoder.Reset();
                    StringBuilder builder = null;
                    WebSocketReceiveResult result;

                    do
                    {
                        result = await socket.ReceiveAsync(new ArraySegment(buffer), token).ConfigureAwait(false);

                        if (result.MessageType == WebSocketMessageType.Close)
                        {
                            using (var cts = new CancellationTokenSource(TimeSpan.FromSeconds(20)))
                                await socket.CloseAsync(result.CloseStatus ?? WebSocketCloseStatus.NormalClosure, result.CloseStatusDescription, cts.Token).ConfigureAwait(false);

                            return;
                        }

                        // Binary frames are read and discarded.
                        if (result.Count > 0 && result.MessageType == WebSocketMessageType.Text)
                        {
                            if (builder == null)
                            {
                                if (reusableBuilder.Capacity > result.Count)
                                {
                                    reusableBuilder.Clear();
                                    builder = reusableBuilder;
                                }
                                else
                                {
                                    builder = new StringBuilder(result.Count);
                                }
                            }

                            AppendUtf8(decoder, buffer, result.Count, charBuffer, result.EndOfMessage, builder);

                            if (builder.Length > 10_000000)
                            {
                                using (var cts = new CancellationTokenSource(TimeSpan.FromSeconds(20)))
                                    await socket.CloseAsync(WebSocketCloseStatus.MessageTooBig, "payload too large", cts.Token).ConfigureAwait(false);

                                return;
                            }
                        }
                    }
                    while (!result.EndOfMessage);

                    connection.UpdateActivity();

                    if (builder != null)
                    {
                        // Ignore payloads shorter than 2 chars and bare "ping" keep-alive frames.
                        if (2 > builder.Length)
                            continue;

                        if (builder.Length == 4 && builder[0] == 'p' && builder[1] == 'i' && builder[2] == 'n' && builder[3] == 'g')
                            continue;

                        await HandleMessageAsync(connection, builder).ConfigureAwait(false);
                    }
                }
            }
            // Any failure just ends the loop; registry cleanup is done by HandleWebSocketAsync.
            catch (OperationCanceledException) { }
            catch (WebSocketException) { }
            catch (Exception) { }
            finally
            {
                ArrayPool.Shared.Return(buffer);
                ArrayPool.Shared.Return(charBuffer);
            }
        }

        /// <summary>
        /// Decodes the first <paramref name="count"/> bytes of <paramref name="bytes"/> with
        /// <paramref name="decoder"/> and appends the chars to <paramref name="builder"/>.
        /// Loops until <see cref="Decoder.Convert(byte[], int, int, char[], int, int, bool, out int, out int, out bool)"/>
        /// reports completion, so no input is dropped when the output does not fit
        /// <paramref name="chars"/> in a single pass (e.g. leftover bytes from the previous chunk
        /// combining into a surrogate pair).
        /// </summary>
        static void AppendUtf8(Decoder decoder, byte[] bytes, int count, char[] chars, bool flush, StringBuilder builder)
        {
            int offset = 0;
            bool completed;

            do
            {
                decoder.Convert(bytes, offset, count - offset, chars, 0, chars.Length, flush, out int bytesUsed, out int charsUsed, out completed);

                if (charsUsed > 0)
                    builder.Append(chars, 0, charsUsed);

                offset += bytesUsed;
            }
            while (!completed);
        }
        #endregion

        #region message handling
        /// <summary>
        /// Parses <paramref name="payload"/> as a JSON object with a "method" string and an optional
        /// "args" array, raises the NwsMessage event, then dispatches to <see cref="InvokeAsync"/>.
        /// Invalid JSON and non-object roots are ignored.
        /// </summary>
        static async Task HandleMessageAsync(NwsConnection connection, StringBuilder payload)
        {
            try
            {
                using JsonDocument document = JsonDocument.Parse(payload.ToString());
                JsonElement root = document.RootElement;

                if (root.ValueKind != JsonValueKind.Object)
                    return;

                if (!root.TryGetProperty("method", out var methodProp))
                    return;

                // GetString() throws InvalidOperationException for numbers/objects/etc.; that exception would
                // escape this handler and end the client's receive loop, so ignore such messages instead.
                if (methodProp.ValueKind != JsonValueKind.String && methodProp.ValueKind != JsonValueKind.Null)
                    return;

                string method = methodProp.GetString();

                JsonElement args = default;
                if (root.TryGetProperty("args", out var argsProp) && argsProp.ValueKind == JsonValueKind.Array)
                    args = argsProp;

                // NOTE(review): payload may be the StringBuilder that ReceiveLoopAsync clears and re-uses for
                // the next message — event handlers should copy it rather than keep a reference.
                if (InvkEvent.IsNwsMessage())
                    InvkEvent.NwsMessage(new EventNwsMessage(connection.ConnectionId, payload, method, args));

                await InvokeAsync(connection, method, args).ConfigureAwait(false);
            }
            catch (JsonException) { }
        }

        /// <summary>Dispatches a client method call by case-insensitive method name; unknown methods are ignored.</summary>
        static async Task InvokeAsync(NwsConnection connection, string method, JsonElement args)
        {
            if (string.IsNullOrEmpty(method))
                return;

            // Invariant lowering: culture-sensitive ToLower() maps 'I' to a dotless i under tr-TR,
            // which would stop "EventsId" from matching "eventsid".
            switch (method.ToLowerInvariant())
            {
                case "rchregistry":
                    if (AppInit.conf.rch.enable)
                    {
                        string json = GetStringArg(args, 0);
                        RchClient.Registry(connection.Ip, connection.ConnectionId, connection.Host, json, connection);
                        await SendAsync(connection, "RchRegistry", connection.Ip).ConfigureAwait(false);
                    }
                    break;

                case "rchresult":
                    if (AppInit.conf.rch.enable)
                    {
                        string id = GetStringArg(args, 0);
                        string value = GetStringArg(args, 1) ?? string.Empty;

                        if (!string.IsNullOrEmpty(id) && RchClient.rchIds.TryGetValue(id, out var rchHub))
                            rchHub.tcs.TrySetResult(value);
                    }
                    break;

                case "registryweblog":
                    if (AppInit.conf.weblog.enable && Startup.WebLogEnableController)
                    {
                        string token = GetStringArg(args, 0);
                        if (string.IsNullOrEmpty(AppInit.conf.weblog.token) || AppInit.conf.weblog.token == token)
                            weblog_clients.AddOrUpdate(connection.ConnectionId, 0, static (_, __) => 0);
                    }
                    break;

                case "weblog":
                    // NOTE(review): unlike "registryweblog" there is no token check here, so any client can
                    // broadcast to web-log subscribers — confirm this is intended.
                    SendLog(GetStringArg(args, 0), GetStringArg(args, 1));
                    break;

                case "registryevent":
                {
                    string uid = GetStringArg(args, 0);
                    if (!string.IsNullOrEmpty(uid))
                        event_clients.AddOrUpdate(connection.ConnectionId, uid, (_, __) => uid);
                    break;
                }

                case "events":
                {
                    string uid = GetStringArg(args, 0);
                    string name = GetStringArg(args, 1);
                    string data = GetStringArg(args, 2);

                    if (name == "devices")
                    {
                        // Reply to the sender with the other connections it can address: those registered
                        // under the same uid, plus (when accsdb is disabled) those from the same IP.
                        var uidClients = event_clients
                            .Where(i => i.Value == uid)
                            .ToDictionary();

                        var devices = _connections
                            .Where(i => i.Value.ConnectionId != connection.ConnectionId)
                            .Where(i => (!AppInit.conf.accsdb.enable && i.Value.Ip == connection.Ip) || uidClients.ContainsKey(i.Key))
                            .Select(i => new
                            {
                                uid = uidClients.TryGetValue(i.Value.ConnectionId, out var targetUid) ? targetUid : null,
                                i.Value.ConnectionId,
                                i.Value.RequestInfo.UserAgent
                            })
                            .ToArray();

                        await SendAsync(connection, "event", uid, name, devices).ConfigureAwait(false);
                        break;
                    }

                    await SendEvents(connection.ConnectionId, uid, name, data).ConfigureAwait(false);
                    break;
                }

                case "eventsid":
                {
                    string targetConnection = GetStringArg(args, 0);
                    string uid = GetStringArg(args, 1);
                    string name = GetStringArg(args, 2);
                    string data = GetStringArg(args, 3);

                    await SendEventToConnection(targetConnection, uid, name, data).ConfigureAwait(false);
                    break;
                }

                case "ping":
                    await SendAsync(connection, "pong").ConfigureAwait(false);
                    break;
            }
        }

        /// <summary>
        /// Returns element <paramref name="index"/> of the <paramref name="args"/> array as a string:
        /// the string value for JSON strings, null for JSON null / missing index / non-array args,
        /// and the raw JSON text for any other kind.
        /// </summary>
        static string GetStringArg(JsonElement args, int index)
        {
            if (args.ValueKind != JsonValueKind.Array || args.GetArrayLength() <= index)
                return null;

            var element = args[index];

            if (element.ValueKind == JsonValueKind.String)
                return element.GetString();

            if (element.ValueKind == JsonValueKind.Null)
                return null;

            return element.ToString();
        }
        #endregion

        #region SendAsync
        sealed record NwsSendModel(string method, object[] args);

        /// <summary>
        /// Returns the shared "stats:nws" counter from <see cref="memoryCache"/> (created on demand with a
        /// one-minute absolute expiration), or null when openstat is disabled or no cache is assigned.
        /// </summary>
        static CounterNws StatsCounter()
        {
            if (!AppInit.conf.openstat.enable || memoryCache == null)
                return null;

            return memoryCache.GetOrCreate("stats:nws", entry =>
            {
                entry.AbsoluteExpirationRelativeToNow = TimeSpan.FromMinutes(1);
                return new CounterNws();
            });
        }

        /// <summary>
        /// Serializes <c>{ method, args }</c> and sends it as a single text frame. Sends are serialized
        /// per connection through <c>SendLock</c> (waiting up to 20 s) and time out after 10 s.
        /// Socket, cancellation and disposal failures are swallowed; nothing is sent if the socket is
        /// not open, <paramref name="method"/> is empty, or the lock cannot be acquired in time.
        /// </summary>
        static async Task SendAsync(NwsConnection connection, string method, params object[] args)
        {
            if (connection.Socket.State != WebSocketState.Open || string.IsNullOrEmpty(method))
                return;

            bool lockTaken = false;

            try
            {
                var counter = StatsCounter();
                if (counter != null)
                    Interlocked.Increment(ref counter.send);

                // WaitAsync returns false on timeout. In that case we must neither send (a concurrent
                // WebSocket.SendAsync is not allowed) nor Release() a lock we never acquired.
                lockTaken = await connection.SendLock.WaitAsync(TimeSpan.FromSeconds(20)).ConfigureAwait(false);
                if (!lockTaken)
                    return;

                using (var ms = PoolInvk.msm.GetStream())
                {
                    JsonSerializer.Serialize(ms, new NwsSendModel(method, args ?? Array.Empty()), serializerOptions);
                    ms.Position = 0;

                    if (connection.Socket.State == WebSocketState.Open)
                    {
                        using (var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10)))
                        {
                            await connection.Socket
                                .SendAsync(ms.GetBuffer().AsMemory(0, (int)ms.Length), WebSocketMessageType.Text, true, cts.Token)
                                .ConfigureAwait(false);
                        }

                        connection.UpdateActivity();
                        connection.UpdateSendActivity();
                    }
                }
            }
            catch (WebSocketException) { }
            catch (OperationCanceledException) { }
            // The connection (socket / lock) may be disposed concurrently by Cleanup; treat as a failed send.
            catch (ObjectDisposedException) { }
            finally
            {
                if (lockTaken)
                    connection.SendLock.Release();
            }
        }

        /// <summary>
        /// Broadcasts a log line ("Receive", message, plugin) to every registered web-log client without
        /// awaiting the sends. No-op when web-log is disabled, either argument is empty, or the message
        /// exceeds 4,000,000 chars.
        /// </summary>
        public static void SendLog(string message, string plugin)
        {
            if (!AppInit.conf.weblog.enable || string.IsNullOrEmpty(message) || string.IsNullOrEmpty(plugin) || message.Length > 4_000000)
                return;

            if (weblog_clients.IsEmpty || !Startup.WebLogEnableController)
                return;

            foreach (string connectionId in weblog_clients.Keys)
            {
                if (_connections.TryGetValue(connectionId, out var client))
                    _ = SendAsync(client, "Receive", message, plugin);
            }
        }

        /// <summary>
        /// Sends an "event" (uid, name, data) to every connection registered under <paramref name="uid"/>,
        /// excluding <paramref name="connectionId"/> when it is non-null. Completes when all sends finish.
        /// </summary>
        public static Task SendEvents(string connectionId, string uid, string name, string data)
        {
            if (string.IsNullOrEmpty(uid) || string.IsNullOrEmpty(name))
                return Task.CompletedTask;

            var targets = event_clients
                .Where(i => i.Value == uid && (connectionId == null || i.Key != connectionId))
                .Select(i => i.Key);

            var tasks = new List();

            foreach (string targetId in targets)
            {
                if (_connections.TryGetValue(targetId, out var client))
                    tasks.Add(SendAsync(client, "event", uid, name, data ?? string.Empty));
            }

            if (tasks.Count == 0)
                return Task.CompletedTask;

            return Task.WhenAll(tasks);
        }

        /// <summary>Sends an "event" (uid, name, data) to a single connection, if it is registered.</summary>
        static Task SendEventToConnection(string connectionId, string uid, string name, string data)
        {
            if (string.IsNullOrEmpty(connectionId) || string.IsNullOrEmpty(name))
                return Task.CompletedTask;

            if (_connections.TryGetValue(connectionId, out var client))
                return SendAsync(client, "event", uid ?? string.Empty, name, data ?? string.Empty);

            return Task.CompletedTask;
        }

        /// <summary>Sends an "RchClient" request to the given connection, if it is registered.</summary>
        public static Task SendRchRequestAsync(string connectionId, string rchId, string url, string data, Dictionary headers, bool returnHeaders)
        {
            if (string.IsNullOrEmpty(connectionId))
                return Task.CompletedTask;

            if (_connections.TryGetValue(connectionId, out var client))
                return SendAsync(client, "RchClient", rchId, url, data, headers, returnHeaders);

            return Task.CompletedTask;
        }
        #endregion

        #region cleanup
        /// <summary>
        /// Removes whatever connection is registered under <paramref name="connectionId"/> (cancelling and
        /// disposing it), drops its web-log/event subscriptions and notifies <c>RchClient</c>.
        /// </summary>
        public static void Cleanup(string connectionId)
        {
            if (string.IsNullOrEmpty(connectionId))
                return;

            if (_connections.TryRemove(connectionId, out var connection))
            {
                connection.Cancel();
                connection.Dispose();
            }

            weblog_clients.TryRemove(connectionId, out _);
            event_clients.TryRemove(connectionId, out _);
            RchClient.OnDisconnected(connectionId);
        }

        /// <summary>
        /// Like <see cref="Cleanup"/>, but acts only if <paramref name="connectionId"/> is still mapped to
        /// this exact <paramref name="connection"/>. If the entry was already removed (and disposed) or
        /// replaced by a reconnect under the same id, nothing is touched, so a finishing handler cannot
        /// tear down a newer connection's registrations.
        /// </summary>
        static void ReleaseConnection(string connectionId, NwsConnection connection)
        {
            if (string.IsNullOrEmpty(connectionId))
                return;

            if (!_connections.TryRemove(KeyValuePair.Create(connectionId, connection)))
                return;

            connection.Cancel();
            connection.Dispose();

            weblog_clients.TryRemove(connectionId, out _);
            event_clients.TryRemove(connectionId, out _);
            RchClient.OnDisconnected(connectionId);
        }
        #endregion

        #region ConnectionMonitorCallback
        static int _updatingMonitorCallback = 0;

        /// <summary>
        /// Timer callback: cancels connections with no activity for 125 seconds and, when
        /// <c>WebSocket.inactiveAfterMinutes</c> is positive, connections that have not been sent anything
        /// for that many minutes. A tick is skipped while the previous one is still running.
        /// </summary>
        static void ConnectionMonitorCallback(object state)
        {
            if (_connections.IsEmpty)
                return;

            if (Interlocked.Exchange(ref _updatingMonitorCallback, 1) == 1)
                return;

            try
            {
                var now = DateTime.UtcNow;
                var activityCutoff = now.AddSeconds(-125);

                int inactiveAfterMinutes = AppInit.conf.WebSocket.inactiveAfterMinutes;
                DateTime? sendCutoff = inactiveAfterMinutes > 0 ? now.AddMinutes(-inactiveAfterMinutes) : null;

                foreach (var connection in _connections)
                {
                    if (activityCutoff >= connection.Value.LastActivityUtc)
                        connection.Value.Cancel();

                    if (sendCutoff.HasValue && sendCutoff.Value >= connection.Value.LastSendActivityUtc)
                        connection.Value.Cancel();
                }
            }
            // Never let an exception escape the timer callback.
            catch { }
            finally
            {
                Volatile.Write(ref _updatingMonitorCallback, 0);
            }
        }
        #endregion

        #region FullDispose
        /// <summary>Stops the monitor timer and cancels every registered connection.</summary>
        public static void FullDispose()
        {
            ConnectionMonitorTimer.Dispose();

            foreach (var connection in _connections)
                connection.Value.Cancel();
        }
        #endregion
    }

    /// <summary>Per-minute message counters for the "stats:nws" cache entry (incremented with Interlocked).</summary>
    public class CounterNws
    {
        public int receive;

        public int send;
    }
}