diff --git a/src/main.ts b/src/main.ts index 0310a23..396f3eb 100644 --- a/src/main.ts +++ b/src/main.ts @@ -3,7 +3,7 @@ import logger from './utils/core/logger'; import config from './utils/core/config'; import redis from './services/redis/redis'; -config.check(['PORT', 'DEBUG', 'RD_PORT', 'RD_ADD']); +config.check(['PORT', 'DEBUG', 'RD_PORT', 'RD_ADD', 'WS_SECRET', 'WS_PORT']); const PORT = config.get('PORT') || 3000; apps diff --git a/src/services/redis/redis.ts b/src/services/redis/redis.ts index 467802a..e89aebe 100644 --- a/src/services/redis/redis.ts +++ b/src/services/redis/redis.ts @@ -17,6 +17,7 @@ class RedisService { private async initialize() { await this.connectWithRetry(); this.setupEventListeners(); + //await this.test(); } private async connectWithRetry(): Promise { diff --git a/src/services/ws/handler.ts b/src/services/ws/handler.ts index ffe5831..fadca4f 100644 --- a/src/services/ws/handler.ts +++ b/src/services/ws/handler.ts @@ -1,27 +1,41 @@ -import WebSocket from 'ws'; -import WsMessage from '../../types/wsMessage'; +import { AuthenticatedSocket } from '../../types/ws'; +import wsTools from '../../utils/ws/wsTools'; +import * as ws from 'ws'; + +type WebSocket = ws.WebSocket; class WSMessageHandler { - public async handle(socket: WebSocket, clientID: string, msg: WsMessage) { - switch (msg.type) { - case 'test': - await this.reply(socket, { type: 'test', data: 'hi' }); - break; - case 'ping': - await this.reply(socket, { type: 'pong' }); - break; - default: - await this.reply(socket, { type: 'error', message: 'Unknown message' }); - break; + async handle(socket: AuthenticatedSocket, clientId: string, msg: any) { + try { + switch (msg.type) { + case 'test': + await this.handleTest(socket); + break; + case 'ping': + await wsTools.send(socket, { type: 'pong' }); + break; + default: + await this.handleUnknown(socket); + } + } catch (err) { + await wsTools.send(socket, { + type: 'error', + message: 'Processing failed', + }); } } - private async reply(socket: WebSocket, data: any): Promise { - return new Promise((resolve, reject) => { - socket.send(JSON.stringify(data), (err) => { - if (err) reject(err); - else resolve(); - }); + private async handleTest(socket: WebSocket) { + await wsTools.send(socket, { + type: 'test', + data: { status: 'ok' }, + }); + } + + private async handleUnknown(socket: WebSocket) { + await wsTools.send(socket, { + type: 'error', + message: 'Unknown message type', }); } } diff --git a/src/services/ws/wsServer.ts b/src/services/ws/wsServer.ts index 268a049..2cca1d3 100644 --- a/src/services/ws/wsServer.ts +++ b/src/services/ws/wsServer.ts @@ -1,72 +1,83 @@ import WebSocket, { WebSocketServer } from 'ws'; import config from '../../utils/core/config'; -import wsClientManager from './wsClientManager'; -import wsHandler from './handler'; import logger from '../../utils/core/logger'; - -interface AuthenticatedSocket extends WebSocket { - isAuthed?: boolean; - clientId?: string; -} +import { AuthenticatedSocket, AuthMessage, WSMessage } from '../../types/ws'; +import WsTools from '../../utils/ws/wsTools'; +import wsHandler from './handler'; +import { clearInterval } from 'node:timers'; +import wsClientManager from './wsClientManager'; class WSServer { - private wss: WebSocketServer; - private PORT = config.get('WS_PORT'); - private WS_SECRET = config.get('WS_SECRET'); + private readonly wss: WebSocketServer; + private readonly port = Number(config.get('WS_PORT')); + private readonly secret = config.get('WS_SECRET'); constructor() { - this.wss = new WebSocketServer({ port: Number(this.PORT) }); + this.wss = new WebSocketServer({ port: this.port }); this.init(); - logger.info(`WebSocket Server started at ws://localhost:${this.PORT}`); + logger.info(`WS Server listening on ws://localhost:${this.port}`); } - private init() { + private init(): void { this.wss.on('connection', (socket: AuthenticatedSocket) => { + socket.heartbeat = WsTools.setUpHeartbeat(socket); + socket.on('message', async (raw) => { - let msg: any; - try { - msg = JSON.parse(raw.toString()); - } catch { - return this.send(socket, { type: 'error', message: 'JSON 解析失败' }); - } + const msg = WsTools.parseMessage(raw); + if (!msg) return this.handleInvalidMessage(socket); - // 鉴权 - if (!socket.isAuthed) { - if (msg.type === 'auth' && msg.secret === this.WS_SECRET && msg.clientId) { - socket.isAuthed = true; - socket.clientId = msg.clientId; - wsClientManager.add(msg.clientId, socket); - return this.send(socket, { type: 'auth', success: true }); - } - return this.send(socket, { type: 'auth', success: false }); - } - - // 业务处理 - if (socket.clientId) { - try { - await wsHandler.handle(socket, socket.clientId, msg); - } catch (e) { - await this.send(socket, { type: 'error', message: '处理出错' }); - } - } + await this.routeMessage(socket, msg); }); socket.on('close', () => { - if (socket.clientId) { - wsClientManager.remove(socket.clientId); - } + this.handleDisconnect(socket); }); }); } - private async send(socket: WebSocket, data: any): Promise { - return new Promise((resolve, reject) => { - socket.send(JSON.stringify(data), (err) => { - if (err) reject(err); - else resolve(); - }); + private async handleInvalidMessage(socket: WebSocket) { + await WsTools.send(socket, { + type: 'error', + message: 'Invalid message format', }); } + + private async routeMessage(socket: AuthenticatedSocket, msg: WSMessage) { + if (!socket.isAuthed) { + if (this.isAuthMessage(msg)) { + await this.handleAuth(socket, msg); + } + return; + } + + if (socket.clientId) { + await wsHandler.handle(socket, socket.clientId, msg); + } + } + + private isAuthMessage(msg: WSMessage): msg is AuthMessage { + return ( + msg.type === 'auth' && + typeof (msg as AuthMessage).secret === 'string' && + typeof (msg as AuthMessage).clientId === 'string' + ); + } + + private async handleAuth(socket: AuthenticatedSocket, msg: AuthMessage) { + if (msg.secret === this.secret) { + socket.isAuthed = true; + socket.clientId = msg.clientId; + wsClientManager.add(msg.clientId, socket); + await WsTools.send(socket, { type: 'auth', success: true }); + } else { + await WsTools.send(socket, { type: 'auth', success: false }); + } + } + + private handleDisconnect(socket: AuthenticatedSocket) { + if (socket.heartbeat) clearInterval(socket.heartbeat); + if (socket.clientId) wsClientManager.remove(socket.clientId); + } } const wsServer = new WSServer(); diff --git a/src/types/ws.ts b/src/types/ws.ts new file mode 100644 index 0000000..43419eb --- /dev/null +++ b/src/types/ws.ts @@ -0,0 +1,18 @@ +import WebSocket from 'ws'; + +export interface AuthenticatedSocket extends WebSocket { + isAuthed?: boolean; + clientId?: string; + heartbeat?: NodeJS.Timeout; +} + +export interface WSMessage { + type: string; + [key: string]: unknown; +} + +export interface AuthMessage extends WSMessage { + type: 'auth'; + secret: string; + clientId: string; +} diff --git a/src/types/wsMessage.ts b/src/types/wsMessage.ts deleted file mode 100644 index fb2af4a..0000000 --- a/src/types/wsMessage.ts +++ /dev/null @@ -1,6 +0,0 @@ -interface wsMessage { - type: string; - [key: string]: any; -} - -export default wsMessage; diff --git a/src/utils/core/file.ts b/src/utils/core/file.ts index 6eee77b..015209f 100644 --- a/src/utils/core/file.ts +++ b/src/utils/core/file.ts @@ -36,6 +36,10 @@ class fc { } } + /** + * 输出日志到文件 + * @param message + */ public static async logToFile(message: string): Promise { const logFile = path.join(paths.get('log'), `${date.getCurrentDate()}.log`); const logMessage = `${message}\n`; diff --git a/src/utils/ws/wsTools.ts b/src/utils/ws/wsTools.ts new file mode 100644 index 0000000..2820599 --- /dev/null +++ b/src/utils/ws/wsTools.ts @@ -0,0 +1,44 @@ +import WebSocket from 'ws'; +import logger from '../core/logger'; +import { setInterval } from 'node:timers'; + +class WsTools { + /** + * 发送消息 + */ + static async send(socket: WebSocket, data: unknown): Promise { + if (socket.readyState !== WebSocket.OPEN) return false; + + return new Promise((resolve) => { + socket.send(JSON.stringify(data), (err) => { + resolve(!err); + }); + }); + } + + /** + * 解析消息 + */ + static parseMessage(data: WebSocket.RawData): T | null { + try { + return JSON.parse(data.toString()) as T; + } catch (err) { + logger.error(err); + return null; + } + } + + /** + * 心跳检测 + */ + static setUpHeartbeat(socket: WebSocket, interval = 30000): NodeJS.Timeout { + const heartbeat = () => { + if (socket.readyState === WebSocket.OPEN) { + WsTools.send(socket, { type: 'ping' }); + } + }; + return setInterval(heartbeat, interval); + } +} + +export default WsTools;