diff --git a/src/backend/app/__init__.py b/src/backend/app/__init__.py index f614c63..af0548e 100644 --- a/src/backend/app/__init__.py +++ b/src/backend/app/__init__.py @@ -2,6 +2,7 @@ from fastapi import FastAPI, responses from src.backend.app.api.endpoints import router from src.backend.app.utils.logger import setup_logging from src.backend.config import settings +from .services.switch_traffic_monitor import get_switch_monitor # 添加正确的导入 from .services.traffic_monitor import traffic_monitor @@ -37,10 +38,12 @@ def create_app() -> FastAPI: async def favicon(): return responses.Response(status_code=204) + # 添加API路由 app.include_router(router, prefix=settings.API_PREFIX) return app + app = create_app() \ No newline at end of file diff --git a/src/backend/app/api/endpoints.py b/src/backend/app/api/endpoints.py index 20ab3d6..e2dd6fd 100644 --- a/src/backend/app/api/endpoints.py +++ b/src/backend/app/api/endpoints.py @@ -1,20 +1,23 @@ -from datetime import datetime - -from fastapi import (APIRouter, HTTPException, Response,WebSocket, WebSocketDisconnect) +from datetime import datetime, timedelta +from fastapi import (APIRouter, HTTPException, Response, WebSocket, WebSocketDisconnect) from typing import List from pydantic import BaseModel +import asyncio +from fastapi.responses import HTMLResponse +import matplotlib.pyplot as plt +import io +import base64 + +from ..services.switch_traffic_monitor import get_switch_monitor +from ..utils import logger from ...app.services.ai_services import AIService from ...app.api.network_config import SwitchConfigurator from ...config import settings from ..services.network_scanner import NetworkScanner from ...app.services.traffic_monitor import traffic_monitor -from ...app.models.traffic_models import TrafficRecord -from src.backend.app.api.database import (SessionLocal) -import asyncio -from fastapi.responses import HTMLResponse -import matplotlib.pyplot as plt -import io -import base64 +from ...app.models.traffic_models import TrafficRecord, SwitchTrafficRecord +from src.backend.app.api.database import SessionLocal + @@ -34,6 +37,8 @@ async def root(): "/scan_network", "/list_devices", "/batch_apply_config" + "/traffic/switch/current", # 添加交换机流量端点 + "/traffic/switch/history" # 添加交换机历史流量端点 ] } @@ -163,42 +168,223 @@ async def websocket_traffic(websocket: WebSocket): print("客户端断开连接") -@router.get("/traffic/plot", response_class=HTMLResponse, summary="流量可视化图表") -async def plot_traffic(interface: str = "eth0", minutes: int = 10): - # 获取历史数据 - history = traffic_monitor.get_traffic_history(interface) - time_points = history["time"][-minutes * 60:] - sent = history["sent"][-minutes * 60:] - recv = history["recv"][-minutes * 60:] +@router.get("/", include_in_schema=False) +async def root(): + return { + "message": "欢迎使用AI交换机配置系统", + "docs": f"{settings.API_PREFIX}/docs", + "redoc": f"{settings.API_PREFIX}/redoc", + "endpoints": [ + "/parse_command", + "/apply_config", + "/scan_network", + "/list_devices", + "/batch_apply_config", + "/traffic/switch/current", # 添加交换机流量端点 + "/traffic/switch/history" # 添加交换机历史流量端点 + ] + } - # 创建图表 - plt.figure(figsize=(10, 6)) - plt.plot(time_points, sent, label="发送流量 (B/s)") - plt.plot(time_points, recv, label="接收流量 (B/s)") - plt.title(f"{interface} 流量监控 - 最近 {minutes} 分钟") - plt.xlabel("时间") - plt.ylabel("流量 (字节/秒)") - plt.legend() - plt.grid(True) - plt.xticks(rotation=45) - plt.tight_layout() - # 转换为HTML图像 - buf = io.BytesIO() - plt.savefig(buf, format="png") - buf.seek(0) - image_base64 = base64.b64encode(buf.read()).decode("utf-8") - plt.close() +# ... 其他路由保持不变 ... - return f""" - - - 网络流量监控 - - -

{interface} 网络流量监控

- 流量图表 -

更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

- - - """ \ No newline at end of file +@router.get("/traffic/switch/interfaces", summary="获取交换机的网络接口") +async def get_switch_interfaces(switch_ip: str): + """获取指定交换机的所有接口""" + try: + monitor = get_switch_monitor(switch_ip) + interfaces = list(monitor.interface_oids.keys()) + return { + "switch_ip": switch_ip, + "interfaces": interfaces + } + except Exception as e: + logger.error(f"获取交换机接口失败: {str(e)}") + raise HTTPException(500, f"获取接口失败: {str(e)}") + + +@router.get("/traffic/switch/current", summary="获取交换机的当前流量数据") +async def get_switch_current_traffic(switch_ip: str, interface: str = None): + """获取交换机的当前流量数据""" + try: + monitor = get_switch_monitor(switch_ip) + + # 如果没有指定接口,获取所有接口的流量 + if not interface: + traffic_data = {} + for iface in monitor.interface_oids: + traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface) + return { + "switch_ip": switch_ip, + "traffic": traffic_data + } + + # 获取指定接口的流量 + return await get_interface_current_traffic(switch_ip, interface) + except Exception as e: + logger.error(f"获取交换机流量失败: {str(e)}") + raise HTTPException(500, f"获取流量失败: {str(e)}") + + +async def get_interface_current_traffic(switch_ip: str, interface: str) -> dict: + """获取指定交换机接口的当前流量数据""" + try: + with SessionLocal() as session: + # 获取最新记录 + record = session.query(SwitchTrafficRecord).filter( + SwitchTrafficRecord.switch_ip == switch_ip, + SwitchTrafficRecord.interface == interface + ).order_by(SwitchTrafficRecord.timestamp.desc()).first() + + if not record: + return { + "switch_ip": switch_ip, + "interface": interface, + "rate_in": 0.0, + "rate_out": 0.0, + "bytes_in": 0, + "bytes_out": 0 + } + + return { + "switch_ip": switch_ip, + "interface": interface, + "rate_in": record.rate_in, + "rate_out": record.rate_out, + "bytes_in": record.bytes_in, + "bytes_out": record.bytes_out + } + except Exception as e: + logger.error(f"获取接口流量失败: {str(e)}") + raise HTTPException(500, f"获取接口流量失败: {str(e)}") + + +@router.get("/traffic/switch/history", summary="获取交换机的流量历史数据") +async def get_switch_traffic_history(switch_ip: str, interface: str = None, minutes: int = 10): + """获取交换机的流量历史数据""" + try: + monitor = get_switch_monitor(switch_ip) + + # 如果没有指定接口,返回所有接口的历史数据 + if not interface: + return { + "switch_ip": switch_ip, + "history": monitor.get_traffic_history() + } + + # 获取指定接口的历史数据 + with SessionLocal() as session: + # 计算时间范围 + time_threshold = datetime.now() - timedelta(minutes=minutes) + + # 查询历史记录 + records = session.query(SwitchTrafficRecord).filter( + SwitchTrafficRecord.switch_ip == switch_ip, + SwitchTrafficRecord.interface == interface, + SwitchTrafficRecord.timestamp >= time_threshold + ).order_by(SwitchTrafficRecord.timestamp.asc()).all() + + # 提取数据 + history_data = { + "in": [record.rate_in for record in records], + "out": [record.rate_out for record in records], + "time": [record.timestamp.isoformat() for record in records] + } + + return { + "switch_ip": switch_ip, + "interface": interface, + "history": history_data + } + except Exception as e: + logger.error(f"获取历史流量失败: {str(e)}") + raise HTTPException(500, f"获取历史流量失败: {str(e)}") + + +@router.websocket("/ws/traffic/switch") +async def websocket_switch_traffic(websocket: WebSocket, switch_ip: str, interface: str = None): + """交换机实时流量WebSocket""" + await websocket.accept() + try: + monitor = get_switch_monitor(switch_ip) + + while True: + # 获取流量数据 + if interface: + # 单个接口 + traffic_data = await get_interface_current_traffic(switch_ip, interface) + await websocket.send_json(traffic_data) + else: + # 所有接口 + traffic_data = {} + for iface in monitor.interface_oids: + traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface) + + await websocket.send_json({ + "switch_ip": switch_ip, + "traffic": traffic_data + }) + + # 每秒更新一次 + await asyncio.sleep(1) + except WebSocketDisconnect: + logger.info(f"客户端断开交换机流量连接: {switch_ip}") + except Exception as e: + logger.error(f"交换机流量WebSocket错误: {str(e)}") + await websocket.close(code=1011, reason=str(e)) + + +# 交换机流量可视化 +@router.get("/traffic/switch/plot", response_class=HTMLResponse, summary="交换机流量可视化") +async def plot_switch_traffic(switch_ip: str, interface: str, minutes: int = 10): + """生成交换机流量图表""" + try: + # 获取历史数据 + history = await get_switch_traffic_history(switch_ip, interface, minutes) + history_data = history["history"] + + # 提取数据点 + time_points = [datetime.fromisoformat(t) for t in history_data["time"]] + in_rates = history_data["in"] + out_rates = history_data["out"] + + # 创建图表 + plt.figure(figsize=(12, 6)) + plt.plot(time_points, in_rates, label="流入流量 (B/s)") + plt.plot(time_points, out_rates, label="流出流量 (B/s)") + plt.title(f"交换机 {switch_ip} 接口 {interface} 流量监控 - 最近 {minutes} 分钟") + plt.xlabel("时间") + plt.ylabel("流量 (字节/秒)") + plt.legend() + plt.grid(True) + plt.xticks(rotation=45) + plt.tight_layout() + + # 转换为HTML图像 + buf = io.BytesIO() + plt.savefig(buf, format="png") + buf.seek(0) + image_base64 = base64.b64encode(buf.read()).decode("utf-8") + plt.close() + + return f""" + + + 交换机流量监控 + + + +
+

交换机 {switch_ip} 接口 {interface} 流量监控

+ 流量图表 +

更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

+
+ + + """ + except Exception as e: + logger.error(f"生成流量图表失败: {str(e)}") + return HTMLResponse(content=f"

错误

{str(e)}

", status_code=500) \ No newline at end of file diff --git a/src/backend/app/models/traffic_models.py b/src/backend/app/models/traffic_models.py index f6900f2..43dceb9 100644 --- a/src/backend/app/models/traffic_models.py +++ b/src/backend/app/models/traffic_models.py @@ -1,6 +1,6 @@ # 添加正确的导入 from src.backend.app.api.database import Base # 修复:导入 Base -from sqlalchemy import Column, Integer, String, DateTime +from sqlalchemy import Column, Integer, String, DateTime, BigInteger, Float class TrafficRecord(Base): @@ -24,4 +24,29 @@ class TrafficRecord(Base): "packets_sent": self.packets_sent, "packets_recv": self.packets_recv, "timestamp": self.timestamp.isoformat() - } \ No newline at end of file + } + + +class SwitchTrafficRecord(Base): + __tablename__ = "switch_traffic_records" + + id = Column(Integer, primary_key=True, index=True) + switch_ip = Column(String(50), index=True) + interface = Column(String(50)) + bytes_in = Column(BigInteger) # 累计流入字节数 + bytes_out = Column(BigInteger) # 累计流出字节数 + rate_in = Column(Float) # 当前流入速率(字节/秒) + rate_out = Column(Float) # 当前流出速率(字节/秒) + timestamp = Column(DateTime) + + def to_dict(self): + return { + "id": self.id, + "switch_ip": self.switch_ip, + "interface": self.interface, + "bytes_in": self.bytes_in, + "bytes_out": self.bytes_out, + "rate_in": self.rate_in, + "rate_out": self.rate_out, + "timestamp": self.timestamp.isoformat() + } diff --git a/src/backend/app/services/switch_traffic_monitor.py b/src/backend/app/services/switch_traffic_monitor.py new file mode 100644 index 0000000..578d7c7 --- /dev/null +++ b/src/backend/app/services/switch_traffic_monitor.py @@ -0,0 +1,196 @@ +import asyncio +from datetime import datetime +from collections import deque +from typing import Optional, List, Dict +from pysnmp.hlapi import * +from ..models.traffic_models import SwitchTrafficRecord +from src.backend.app.api.database import SessionLocal +from ..utils.logger import logger + + +class SwitchTrafficMonitor: + def __init__( + self, + switch_ip: str, + community: str = 'public', + update_interval: int = 5, + interfaces: Optional[List[str]] = None + ): + self.switch_ip = switch_ip + self.community = community + self.update_interval = update_interval + self.running = False + self.task = None + self.interface_history = {} + self.history = { + "in": deque(maxlen=300), + "out": deque(maxlen=300), + "time": deque(maxlen=300) + } + + # 基本接口OID映射 + self.interface_oids = { + "GigabitEthernet0/0/1": { + "in": '1.3.6.1.2.1.2.2.1.10.1', # ifInOctets + "out": '1.3.6.1.2.1.2.2.1.16.1' # ifOutOctets + }, + "GigabitEthernet0/0/24": { + "in": '1.3.6.1.2.1.2.2.1.10.24', + "out": '1.3.6.1.2.1.2.2.1.16.24' + } + } + + # 接口过滤 + if interfaces: + self.interface_oids = { + iface: oid for iface, oid in self.interface_oids.items() + if iface in interfaces + } + logger.info(f"监控指定接口: {', '.join(interfaces)}") + else: + logger.info("监控所有接口") + + def start_monitoring(self): + """启动交换机流量监控""" + if not self.running: + self.running = True + self.task = asyncio.create_task(self._monitor_loop()) + logger.success(f"交换机流量监控已启动: {self.switch_ip}") + + async def stop_monitoring(self): + """停止监控""" + if self.running: + self.running = False + if self.task: + self.task.cancel() + try: + await self.task + except asyncio.CancelledError: + pass + logger.info(f"交换机流量监控已停止: {self.switch_ip}") + + async def _monitor_loop(self): + """监控主循环""" + last_values = {iface: {"in": 0, "out": 0} for iface in self.interface_oids} + last_time = datetime.now() + + while self.running: + await asyncio.sleep(self.update_interval) + + try: + current_time = datetime.now() + elapsed = (current_time - last_time).total_seconds() + + # 获取所有接口流量 + for iface, oids in self.interface_oids.items(): + in_octets = self._snmp_get(oids["in"]) + out_octets = self._snmp_get(oids["out"]) + + if in_octets is not None and out_octets is not None: + # 计算速率(字节/秒) + # 修复字典访问问题 + iface_values = last_values[iface] + in_rate = (in_octets - iface_values["in"]) / elapsed if iface_values["in"] > 0 else 0 + out_rate = (out_octets - iface_values["out"]) / elapsed if iface_values["out"] > 0 else 0 + + # 保存历史数据 + self.history["in"].append(in_rate) + self.history["out"].append(out_rate) + self.history["time"].append(current_time) + + # 保存到数据库 + self._save_to_db(iface, in_octets, out_octets, in_rate, out_rate, current_time) + + # 更新最后的值 + iface_values["in"] = in_octets + iface_values["out"] = out_octets + + last_time = current_time + except Exception as e: + logger.error(f"监控交换机流量出错: {str(e)}") + + def _snmp_get(self, oid) -> Optional[int]: + """执行SNMP GET请求""" + try: + # 正确格式化的SNMP请求 + cmd = getCmd( + SnmpEngine(), + CommunityData(self.community), + UdpTransportTarget((self.switch_ip, 161)), + ContextData(), + ObjectType(ObjectIdentity(oid))) + + # 执行命令 + errorIndication, errorStatus, errorIndex, varBinds = next(cmd) + except Exception as e: + logger.error(f"SNMP请求失败: {str(e)}") + return None + + if errorIndication: + logger.error(f"SNMP错误: {errorIndication}") + return None + elif errorStatus: + try: + # 修复括号问题 + if errorIndex: + index_val = int(errorIndex) - 1 + error_item = varBinds[index_val] if index_val < len(varBinds) else '?' + else: + error_item = '?' + + error_msg = f"SNMP错误: {errorStatus.prettyPrint()} at {error_item}" + logger.error(error_msg) + except Exception as e: + logger.error(f"解析SNMP错误失败: {str(e)}") + return None + else: + for varBind in varBinds: + try: + return int(varBind[1]) + except Exception as e: + logger.error(f"转换SNMP值失败: {str(e)}") + return None + + return None + + def _save_to_db(self, interface: str, in_octets: int, out_octets: int, + in_rate: float, out_rate: float, timestamp: datetime): + """保存流量数据到数据库""" + try: + with SessionLocal() as session: + record = SwitchTrafficRecord( + switch_ip=self.switch_ip, + interface=interface, + bytes_in=in_octets, + bytes_out=out_octets, + rate_in=in_rate, + rate_out=out_rate, + timestamp=timestamp + ) + session.add(record) + session.commit() + except Exception as e: + logger.error(f"保存流量数据到数据库失败: {str(e)}") + + def get_traffic_history(self) -> Dict[str, List]: + """获取流量历史数据""" + return { + "in": list(self.history["in"]), + "out": list(self.history["out"]), + "time": list(self.history["time"]) + } + + +# 全局监控器字典(支持多个交换机) +switch_monitors = {} + + +def get_switch_monitor(switch_ip: str, community: str = 'public', interfaces: Optional[List[str]] = None): + """获取或创建交换机监控器(添加接口过滤参数)""" + if switch_ip not in switch_monitors: + switch_monitors[switch_ip] = SwitchTrafficMonitor( + switch_ip, + community, + interfaces=interfaces + ) + return switch_monitors[switch_ip] \ No newline at end of file diff --git a/src/backend/app/utils/logger.py b/src/backend/app/utils/logger.py index f0177f8..86ff736 100644 --- a/src/backend/app/utils/logger.py +++ b/src/backend/app/utils/logger.py @@ -1,15 +1,13 @@ - import logging -from loguru import logger import sys - +from loguru import logger as loguru_logger class InterceptHandler(logging.Handler): def emit(self, record): # 获取对应的Loguru日志级别 try: - level = logger.level(record.levelname).name + level = loguru_logger.level(record.levelname).name except ValueError: level = record.levelno @@ -20,7 +18,7 @@ class InterceptHandler(logging.Handler): depth += 1 # 使用Loguru记录日志 - logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage()) + loguru_logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage()) def setup_logging(): @@ -28,10 +26,10 @@ def setup_logging(): logging.basicConfig(handlers=[InterceptHandler()], level=logging.NOTSET) # 移除所有现有处理器 - logger.remove() + loguru_logger.remove() - # 直接添加处理器(避免configure方法) - logger.add( + # 添加新的处理器 + loguru_logger.add( sys.stdout, format=( "{time:YYYY-MM-DD HH:mm:ss.SSS} | " @@ -41,4 +39,51 @@ def setup_logging(): ), level="DEBUG", enqueue=True - ) \ No newline at end of file + ) + + # 添加文件日志 + loguru_logger.add( + "app.log", + rotation="10 MB", + retention="30 days", + level="INFO", + format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level} | {message}" + ) + + +# 创建通用logger接口 +class Logger: + @staticmethod + def debug(msg, *args, **kwargs): + loguru_logger.debug(msg, *args, **kwargs) + + @staticmethod + def info(msg, *args, **kwargs): + loguru_logger.info(msg, *args, **kwargs) + + @staticmethod + def warning(msg, *args, **kwargs): + loguru_logger.warning(msg, *args, **kwargs) + + @staticmethod + def error(msg, *args, **kwargs): + loguru_logger.error(msg, *args, **kwargs) + + @staticmethod + def critical(msg, *args, **kwargs): + loguru_logger.critical(msg, *args, **kwargs) + + @staticmethod + def exception(msg, *args, **kwargs): + loguru_logger.exception(msg, *args, **kwargs) + + @staticmethod + def success(msg, *args, **kwargs): + loguru_logger.success(msg, *args, **kwargs) + + +# 创建全局logger实例 +logger = Logger() + +# 初始化日志系统 +setup_logging() \ No newline at end of file diff --git a/src/backend/requirements.txt b/src/backend/requirements.txt index 24c8fed..c10a975 100644 --- a/src/backend/requirements.txt +++ b/src/backend/requirements.txt @@ -12,6 +12,7 @@ asyncssh==2.14.2 telnetlib3==2.0.3 httpx==0.27.0 python-nmap==0.7.1 +pysnmp==4.4.12 # 异步文件操作 aiofiles==23.2.1 diff --git a/src/backend/test_ensp.py b/src/backend/test_ensp.py index f2a4140..19803a8 100644 --- a/src/backend/test_ensp.py +++ b/src/backend/test_ensp.py @@ -1,39 +1,73 @@ import asyncio import logging -from src.backend.app.api.network_config import SwitchConfigurator # 导入你的核心类 +from src.backend.app.api.network_config import SwitchConfigurator #该文件用于测试 -# 设置日志 logging.basicConfig( level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s' ) -async def test_ensp(): - """eNSP测试函数""" - # 1. 初始化配置器(对应eNSP设备设置) - configurator = SwitchConfigurator( - ensp_mode=True, # 启用eNSP模式 - ensp_port=2000, # 必须与eNSP中设备设置的Telnet端口一致 - username="admin", # 默认账号 - password="admin", # 默认密码 - timeout=15 # 建议超时设长些 - ) - # 2. 执行配置(示例:创建VLAN100) +async def test_connection(configurator): + """测试基础连接""" + try: + version = await configurator._send_commands("127.0.0.1", ["display version"]) + print("交换机版本信息:\n", version) + return True + except Exception as e: + print("❌ 连接测试失败:", str(e)) + return False + + +async def test_vlan_config(configurator): + """测试 VLAN 配置""" try: result = await configurator.safe_apply( - ip="127.0.0.1", # 本地连接固定用这个地址 + "127.0.0.1", config={ "type": "vlan", "vlan_id": 100, - "name": "测试VLAN" + "name": "自动化测试VLAN" } ) - print("✅ 配置结果:", result) - except Exception as e: - print("❌ 配置失败:", str(e)) + print("VLAN 配置结果:", result) + + # 验证配置 + vlan_list = await configurator._send_commands("127.0.0.1", ["display vlan"]) + print("当前VLAN列表:\n", vlan_list) + + return "success" in result.get("status", "") + except Exception as e: + print("❌ VLAN 配置失败:", str(e)) + return False + + +async def main(): + """主测试流程""" + # 尝试不同端口 + for port in [2000, 2010, 2020, 23]: + print(f"\n尝试端口: {port}") + configurator = SwitchConfigurator( + ensp_mode=True, + ensp_port=port, + username="", + password="admin", + timeout=15 + ) + + if await test_connection(configurator): + print(f"✅ 成功连接到端口 {port}") + if await test_vlan_config(configurator): + print("✅ 所有测试通过!") + return + else: + print("⚠️ VLAN 配置失败,继续尝试其他端口...") + else: + print("⚠️ 连接失败,尝试下一个端口...") + + print("❌ 所有端口尝试失败,请检查配置") + -# 运行测试 if __name__ == "__main__": - asyncio.run(test_ensp()) \ No newline at end of file + asyncio.run(main()) \ No newline at end of file