diff --git a/package.json b/package.json index b83f23f..58bca3a 100644 --- a/package.json +++ b/package.json @@ -21,6 +21,7 @@ "@nestjs/core": "^11.0.1", "@nestjs/platform-express": "^11.0.1", "@nestjs/platform-socket.io": "^11.1.6", + "@nestjs/platform-ws": "^11.1.6", "@nestjs/swagger": "^11.2.0", "@nestjs/websockets": "^11.1.6", "axios": "^1.10.0", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 4ce78f4..20c6be0 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -26,6 +26,9 @@ importers: '@nestjs/platform-socket.io': specifier: ^11.1.6 version: 11.1.6(@nestjs/common@11.1.5(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/websockets@11.1.6)(rxjs@7.8.2) + '@nestjs/platform-ws': + specifier: ^11.1.6 + version: 11.1.6(@nestjs/common@11.1.5(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/websockets@11.1.6)(rxjs@7.8.2) '@nestjs/swagger': specifier: ^11.2.0 version: 11.2.0(@nestjs/common@11.1.5(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.5)(reflect-metadata@0.2.2) @@ -845,6 +848,13 @@ packages: '@nestjs/websockets': ^11.0.0 rxjs: ^7.1.0 + '@nestjs/platform-ws@11.1.6': + resolution: {integrity: sha512-dGek3sBpjMNDpuOyadi1/j5cWpjBVHR7ApjvpMqgnTqKxvK9ljdXlQ0kglKLNQBBEa2wczHEDmte3cKk+/G4zw==} + peerDependencies: + '@nestjs/common': ^11.0.0 + '@nestjs/websockets': ^11.0.0 + rxjs: ^7.1.0 + '@nestjs/schematics@11.0.5': resolution: {integrity: sha512-T50SCNyqCZ/fDssaOD7meBKLZ87ebRLaJqZTJPvJKjlib1VYhMOCwXYsr7bjMPmuPgiQHOwvppz77xN/m6GM7A==} peerDependencies: @@ -4522,6 +4532,17 @@ snapshots: - supports-color - utf-8-validate + '@nestjs/platform-ws@11.1.6(@nestjs/common@11.1.5(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/websockets@11.1.6)(rxjs@7.8.2)': + dependencies: + '@nestjs/common': 11.1.5(reflect-metadata@0.2.2)(rxjs@7.8.2) + '@nestjs/websockets': 11.1.6(@nestjs/common@11.1.5(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.5)(@nestjs/platform-socket.io@11.1.6)(reflect-metadata@0.2.2)(rxjs@7.8.2) + rxjs: 7.8.2 + tslib: 2.8.1 + ws: 8.18.3 + transitivePeerDependencies: + - bufferutil + - utf-8-validate + '@nestjs/schematics@11.0.5(chokidar@4.0.3)(typescript@5.8.3)': dependencies: '@angular-devkit/core': 19.2.6(chokidar@4.0.3) diff --git a/src/core/ws/handlers/ping.handler.ts b/src/core/ws/handlers/ping.handler.ts new file mode 100644 index 0000000..e1ccb8d --- /dev/null +++ b/src/core/ws/handlers/ping.handler.ts @@ -0,0 +1,12 @@ +import { Injectable } from '@nestjs/common'; +import { WsTools } from '../ws.tools'; +import { IMessageHandler } from 'src/types/ws/ws.handlers.interface'; +import { AuthenticatedSocket } from '../../../types/ws/ws.interface'; + +@Injectable() +export class PingHandler implements IMessageHandler { + type = 'ping'; + async handle(socket: AuthenticatedSocket, msg: any) { + await WsTools.send(socket, { type: 'pong' }); + } +} diff --git a/src/core/ws/handlers/pong.handler.ts b/src/core/ws/handlers/pong.handler.ts new file mode 100644 index 0000000..fc3bca1 --- /dev/null +++ b/src/core/ws/handlers/pong.handler.ts @@ -0,0 +1,13 @@ +import { Injectable, Logger } from '@nestjs/common'; +import { AuthenticatedSocket } from '../../../types/ws/ws.interface'; +import { IMessageHandler } from '../../../types/ws/ws.handlers.interface'; + +@Injectable() +export class PongHandler implements IMessageHandler { + type = 'pong'; + private readonly logger = new Logger(PongHandler.name); + + async handle(socket: AuthenticatedSocket, msg: any) { + //this.logger.debug(`收到 pong 消息: ${JSON.stringify(msg)}`); + } +} diff --git a/src/core/ws/handlers/report-bots.handler.ts b/src/core/ws/handlers/report-bots.handler.ts new file mode 100644 index 0000000..785410e --- /dev/null +++ b/src/core/ws/handlers/report-bots.handler.ts @@ -0,0 +1,26 @@ +import { Inject, Injectable } from '@nestjs/common'; +import { Logger } from '@nestjs/common'; +import { IMessageHandler } from '../../../types/ws/ws.handlers.interface'; +import { RedisService } from '../../redis/redis.service'; +import { AuthenticatedSocket } from '../../../types/ws/ws.interface'; + +@Injectable() +export class ReportBotsHandler implements IMessageHandler { + type = 'reportBots'; + private readonly logger = new Logger(ReportBotsHandler.name); + + constructor( + @Inject(RedisService) + private readonly redisService: RedisService, + ) {} + + async handle(socket: AuthenticatedSocket, msg: any) { + this.logger.debug(`received reportBots: ${msg.data}`); + const clientId = msg.data[0].client; + const botsData = msg.data.slice(1); + await this.redisService.persistData('crystelfBots', botsData, clientId); + this.logger.debug( + `保存了 ${botsData.length} 个 bot(client: ${clientId})`, + ); + } +} diff --git a/src/core/ws/handlers/test.handler.ts b/src/core/ws/handlers/test.handler.ts new file mode 100644 index 0000000..fc178da --- /dev/null +++ b/src/core/ws/handlers/test.handler.ts @@ -0,0 +1,16 @@ +import { Injectable } from '@nestjs/common'; +import { WsTools } from '../ws.tools'; +import { IMessageHandler } from '../../../types/ws/ws.handlers.interface'; +import { AuthenticatedSocket } from '../../../types/ws/ws.interface'; + +@Injectable() +export class TestHandler implements IMessageHandler { + type = 'test'; + + async handle(socket: AuthenticatedSocket, msg: any) { + await WsTools.send(socket, { + type: 'test', + data: { status: 'ok' }, + }); + } +} diff --git a/src/core/ws/handlers/unknown.handler.ts b/src/core/ws/handlers/unknown.handler.ts new file mode 100644 index 0000000..04ed151 --- /dev/null +++ b/src/core/ws/handlers/unknown.handler.ts @@ -0,0 +1,18 @@ +import { Injectable, Logger } from '@nestjs/common'; +import { WsTools } from '../ws.tools'; +import { IMessageHandler } from '../../../types/ws/ws.handlers.interface'; +import { AuthenticatedSocket } from '../../../types/ws/ws.interface'; + +@Injectable() +export class UnknownHandler implements IMessageHandler { + type = 'unknown'; + private readonly logger = new Logger(UnknownHandler.name); + + async handle(socket: AuthenticatedSocket, msg: any) { + this.logger.warn(`收到未知消息类型: ${msg.type}`); + await WsTools.send(socket, { + type: 'error', + message: `未知消息类型: ${msg.type}`, + }); + } +} diff --git a/src/core/ws/ws-client.manager.ts b/src/core/ws/ws-client.manager.ts index 21a9dd4..009e5ae 100644 --- a/src/core/ws/ws-client.manager.ts +++ b/src/core/ws/ws-client.manager.ts @@ -5,28 +5,55 @@ import { v4 as uuidv4 } from 'uuid'; type ClientID = string; const pendingRequests = new Map void>(); +/** + * 客户端管理 + */ @Injectable() export class WsClientManager { private clients = new Map(); + /** + * 增添新的客户端 + * @param id 编号 + * @param socket 客户端 + */ add(id: ClientID, socket: WebSocket) { this.clients.set(id, socket); } + /** + * 移除客户端 + * @param id 编号 + */ remove(id: ClientID) { this.clients.delete(id); } + /** + * 获取客户端单例 + * @param id 编号 + */ get(id: ClientID): WebSocket | undefined { return this.clients.get(id); } + /** + * 发送消息到客户端 + * @param id 编号 + * @param data 要发送的信息 + */ async send(id: ClientID, data: any): Promise { const socket = this.clients.get(id); if (!socket || socket.readyState !== WebSocket.OPEN) return false; return this.safeSend(socket, data); } + /** + * 发送消息并等待返回 + * @param id 编号 + * @param data 消息 + * @param timeout + */ async sendAndWait(id: ClientID, data: any, timeout = 5000): Promise { const socket = this.clients.get(id); if (!socket) return; @@ -54,6 +81,11 @@ export class WsClientManager { }); } + /** + * 处理回调 + * @param requestId 请求id + * @param data 内容 + */ resolvePendingRequest(requestId: string, data: any): boolean { const callback = pendingRequests.get(requestId); if (callback) { @@ -64,6 +96,10 @@ export class WsClientManager { return false; } + /** + * 广播消息 + * @param data 内容 + */ async broadcast(data: any): Promise { const tasks = Array.from(this.clients.values()).map((socket) => { if (socket.readyState === WebSocket.OPEN) { @@ -75,6 +111,12 @@ export class WsClientManager { await Promise.all(tasks); } + /** + * 安全发送 + * @param socket + * @param data + * @private + */ private async safeSend(socket: WebSocket, data: any): Promise { return new Promise((resolve, reject) => { socket.send(JSON.stringify(data), (err) => { diff --git a/src/core/ws/ws-message.handler.ts b/src/core/ws/ws-message.handler.ts new file mode 100644 index 0000000..bd92f27 --- /dev/null +++ b/src/core/ws/ws-message.handler.ts @@ -0,0 +1,57 @@ +import { Inject, Injectable, Logger } from '@nestjs/common'; +import { WsTools } from './ws.tools'; +import { WsClientManager } from './ws-client.manager'; +import { IMessageHandler } from '../../types/ws/ws.handlers.interface'; +import { AuthenticatedSocket } from '../../types/ws/ws.interface'; +import { TestHandler } from './handlers/test.handler'; + +@Injectable() +export class WsMessageHandler { + private readonly logger = new Logger(WsMessageHandler.name); + private handlers = new Map(); + + constructor( + private readonly wsClientManager: WsClientManager, + @Inject('WS_HANDLERS') handlers: IMessageHandler[], + ) { + handlers.forEach((h) => this.handlers.set(h.type, h)); + this.logger.log(`已注册 ${handlers.length} 个 WS handler`); + } + + async handle(socket: AuthenticatedSocket, clientId: string, msg: any) { + try { + // 如果是 pendingRequests 的回包 + if ( + msg.requestId && + this.wsClientManager.resolvePendingRequest(msg.requestId, msg) + ) { + return; + } + const handler = + this.handlers.get(msg.type) || this.handlers.get('unknown'); + if (handler) { + await handler.handle(socket, msg); + } else { + await this.handleUnknown(socket, msg); + } + } catch (err) { + this.logger.error(`ws消息处理时出错: ${err}`); + await WsTools.send(socket, { + type: 'error', + message: 'error message', + }); + } + } + + private async handleUnknown(socket: AuthenticatedSocket, msg: any) { + this.logger.warn(`收到未知消息类型: ${msg.type}`); + await WsTools.send(socket, { + type: 'error', + message: `未知消息类型: ${msg.type}`, + }); + } + + public registerHandler(handler: IMessageHandler): void { + this.handlers.set(handler.type, handler); + } +} diff --git a/src/core/ws/ws.gateway.ts b/src/core/ws/ws.gateway.ts index 395ed04..975c691 100644 --- a/src/core/ws/ws.gateway.ts +++ b/src/core/ws/ws.gateway.ts @@ -8,11 +8,17 @@ import { Inject, Logger } from '@nestjs/common'; import { Server, WebSocket } from 'ws'; import { WsTools } from './ws.tools'; import { WsClientManager } from './ws-client.manager'; -import { AuthenticatedSocket, AuthMessage, WSMessage } from '../../types/ws'; +import { + AuthenticatedSocket, + AuthMessage, + WSMessage, +} from '../../types/ws/ws.interface'; import { AppConfigService } from '../../config/config.service'; +import { WsMessageHandler } from './ws-message.handler'; -@WebSocketGateway({ +@WebSocketGateway(7001, { cors: { origin: '*' }, + driver: 'ws', }) export class WsGateway implements OnGatewayConnection, OnGatewayDisconnect { private readonly logger = new Logger(WsGateway.name); @@ -24,11 +30,19 @@ export class WsGateway implements OnGatewayConnection, OnGatewayDisconnect { constructor( @Inject(AppConfigService) private readonly configService: AppConfigService, + @Inject(WsClientManager) private readonly wsClientManager: WsClientManager, + @Inject(WsMessageHandler) + private readonly wsMessageHandler: WsMessageHandler, ) { this.secret = this.configService.get('WS_SECRET'); } + /** + * 新的连接请求 + * @param client 客户端 + * @param req + */ async handleConnection(client: AuthenticatedSocket, req: any) { const ip = req.socket.remoteAddress || 'unknown'; this.logger.log(`收到来自 ${ip} 的 WebSocket 连接请求..`); @@ -49,6 +63,10 @@ export class WsGateway implements OnGatewayConnection, OnGatewayDisconnect { }); } + /** + * 断开某个连接 + * @param client 客户端 + */ async handleDisconnect(client: AuthenticatedSocket) { if (client.heartbeat) clearInterval(client.heartbeat); if (client.clientId) { @@ -57,6 +75,12 @@ export class WsGateway implements OnGatewayConnection, OnGatewayDisconnect { } } + /** + * 不合法消息 + * @param client 客户端 + * @param ip + * @private + */ private async handleInvalidMessage(client: WebSocket, ip: string) { this.logger.warn(`Invalid message received from ${ip}`); await WsTools.send(client, { @@ -65,6 +89,13 @@ export class WsGateway implements OnGatewayConnection, OnGatewayDisconnect { }); } + /** + * 消息路由 + * @param client 客户端 + * @param msg 消息 + * @param ip + * @private + */ private async routeMessage( client: AuthenticatedSocket, msg: WSMessage, @@ -86,13 +117,20 @@ export class WsGateway implements OnGatewayConnection, OnGatewayDisconnect { this.logger.debug( `Routing message from ${client.clientId}: ${JSON.stringify(msg)}`, ); - // TODO: 注入 handler 服务 + await this.wsMessageHandler.handle(client, client.clientId!, msg); } private isAuthMessage(msg: WSMessage): msg is AuthMessage { return msg.type === 'auth'; } + /** + * 连接验证 + * @param client 客户端 + * @param msg 消息 + * @param ip + * @private + */ private async handleAuth( client: AuthenticatedSocket, msg: AuthMessage, diff --git a/src/core/ws/ws.module.ts b/src/core/ws/ws.module.ts index 5a3a322..00797b2 100644 --- a/src/core/ws/ws.module.ts +++ b/src/core/ws/ws.module.ts @@ -2,10 +2,43 @@ import { Module } from '@nestjs/common'; import { WsGateway } from './ws.gateway'; import { WsClientManager } from './ws-client.manager'; import { AppConfigModule } from '../../config/config.module'; +import { WsMessageHandler } from './ws-message.handler'; +import { TestHandler } from './handlers/test.handler'; +import { PingHandler } from './handlers/ping.handler'; +import { PongHandler } from './handlers/pong.handler'; +import { ReportBotsHandler } from './handlers/report-bots.handler'; +import { UnknownHandler } from './handlers/unknown.handler'; +import { RedisModule } from '../redis/redis.module'; @Module({ - imports: [AppConfigModule], - providers: [WsGateway, WsClientManager], - exports: [WsClientManager], + imports: [AppConfigModule, RedisModule], + providers: [ + WsGateway, + WsClientManager, + WsMessageHandler, + TestHandler, + PingHandler, + PongHandler, + ReportBotsHandler, + UnknownHandler, + { + provide: 'WS_HANDLERS', + useFactory: ( + test: TestHandler, + ping: PingHandler, + pong: PongHandler, + reportBots: ReportBotsHandler, + unknown: UnknownHandler, + ) => [test, ping, pong, reportBots, unknown], + inject: [ + TestHandler, + PingHandler, + PongHandler, + ReportBotsHandler, + UnknownHandler, + ], + }, + ], + exports: [WsClientManager, WsMessageHandler, WsGateway], }) export class WsModule {} diff --git a/src/core/ws/ws.tools.ts b/src/core/ws/ws.tools.ts index 6086ce4..a2bc376 100644 --- a/src/core/ws/ws.tools.ts +++ b/src/core/ws/ws.tools.ts @@ -1,20 +1,33 @@ -import WebSocket from 'ws'; +import type { WebSocket, RawData } from 'ws'; import { Logger } from '@nestjs/common'; export class WsTools { private static readonly logger = new Logger(WsTools.name); static async send(socket: WebSocket, data: unknown): Promise { - if (socket.readyState !== WebSocket.OPEN) return false; + if (socket.readyState !== 1) { + this.logger.warn('尝试向非 OPEN 状态的 socket 发送消息,已丢弃'); + return false; + } return new Promise((resolve) => { - socket.send(JSON.stringify(data), (err) => { - resolve(!err); - }); + try { + socket.send(JSON.stringify(data), (err) => { + if (err) { + this.logger.error(`WS send error: ${err.message}`); + resolve(false); + } else { + resolve(true); + } + }); + } catch (err: any) { + this.logger.error(`WS send exception: ${err.message}`); + resolve(false); + } }); } - static parseMessage(data: WebSocket.RawData): T | null { + static parseMessage(data: RawData): T | null { try { return JSON.parse(data.toString()) as T; } catch (err) { @@ -24,9 +37,9 @@ export class WsTools { } static setUpHeartbeat(socket: WebSocket, interval = 30000): NodeJS.Timeout { - const heartbeat = () => { - if (socket.readyState === WebSocket.OPEN) { - WsTools.send(socket, { type: 'ping' }); + const heartbeat = async () => { + if (socket.readyState === 1) { + await WsTools.send(socket, { type: 'ping' }); } }; return setInterval(heartbeat, interval); diff --git a/src/main.ts b/src/main.ts index 7036821..6699525 100644 --- a/src/main.ts +++ b/src/main.ts @@ -5,6 +5,7 @@ import { SwaggerModule, DocumentBuilder } from '@nestjs/swagger'; import { ResponseInterceptor } from './common/interceptors/response.interceptor'; import { AllExceptionsFilter } from './common/filters/all-exception.filter'; import { SystemService } from './core/system/system.service'; +import { WsAdapter } from '@nestjs/platform-ws'; async function bootstrap() { Logger.log('晶灵核心初始化..'); @@ -24,6 +25,7 @@ async function bootstrap() { .build(); const document = () => SwaggerModule.createDocument(app, config); SwaggerModule.setup('', app, document); + app.useWebSocketAdapter(new WsAdapter(app)); await app.listen(7000); await systemService.checkUpdate().catch((err) => { Logger.error(`自动更新失败: ${err?.message}`, '', 'System'); diff --git a/src/types/ws/ws.handlers.interface.ts b/src/types/ws/ws.handlers.interface.ts new file mode 100644 index 0000000..81fbb90 --- /dev/null +++ b/src/types/ws/ws.handlers.interface.ts @@ -0,0 +1,6 @@ +import { AuthenticatedSocket } from './ws.interface'; + +export interface IMessageHandler { + type: string; //消息类型 + handle(socket: AuthenticatedSocket, msg: any): Promise; +} diff --git a/src/types/ws.ts b/src/types/ws/ws.interface.ts similarity index 100% rename from src/types/ws.ts rename to src/types/ws/ws.interface.ts