114514+114514

This commit is contained in:
3 2025-06-19 14:52:59 +08:00
parent 2231b8cf82
commit 7c17bb931b
7 changed files with 567 additions and 77 deletions

View File

@ -2,6 +2,7 @@ from fastapi import FastAPI, responses
from src.backend.app.api.endpoints import router from src.backend.app.api.endpoints import router
from src.backend.app.utils.logger import setup_logging from src.backend.app.utils.logger import setup_logging
from src.backend.config import settings from src.backend.config import settings
from .services.switch_traffic_monitor import get_switch_monitor
# 添加正确的导入 # 添加正确的导入
from .services.traffic_monitor import traffic_monitor from .services.traffic_monitor import traffic_monitor
@ -37,10 +38,12 @@ def create_app() -> FastAPI:
async def favicon(): async def favicon():
return responses.Response(status_code=204) return responses.Response(status_code=204)
# 添加API路由 # 添加API路由
app.include_router(router, prefix=settings.API_PREFIX) app.include_router(router, prefix=settings.API_PREFIX)
return app return app
app = create_app() app = create_app()

View File

@ -1,20 +1,23 @@
from datetime import datetime from datetime import datetime, timedelta
from fastapi import (APIRouter, HTTPException, Response, WebSocket, WebSocketDisconnect)
from fastapi import (APIRouter, HTTPException, Response,WebSocket, WebSocketDisconnect)
from typing import List from typing import List
from pydantic import BaseModel from pydantic import BaseModel
import asyncio
from fastapi.responses import HTMLResponse
import matplotlib.pyplot as plt
import io
import base64
from ..services.switch_traffic_monitor import get_switch_monitor
from ..utils import logger
from ...app.services.ai_services import AIService from ...app.services.ai_services import AIService
from ...app.api.network_config import SwitchConfigurator from ...app.api.network_config import SwitchConfigurator
from ...config import settings from ...config import settings
from ..services.network_scanner import NetworkScanner from ..services.network_scanner import NetworkScanner
from ...app.services.traffic_monitor import traffic_monitor from ...app.services.traffic_monitor import traffic_monitor
from ...app.models.traffic_models import TrafficRecord from ...app.models.traffic_models import TrafficRecord, SwitchTrafficRecord
from src.backend.app.api.database import (SessionLocal) from src.backend.app.api.database import SessionLocal
import asyncio
from fastapi.responses import HTMLResponse
import matplotlib.pyplot as plt
import io
import base64
@ -34,6 +37,8 @@ async def root():
"/scan_network", "/scan_network",
"/list_devices", "/list_devices",
"/batch_apply_config" "/batch_apply_config"
"/traffic/switch/current", # 添加交换机流量端点
"/traffic/switch/history" # 添加交换机历史流量端点
] ]
} }
@ -163,42 +168,223 @@ async def websocket_traffic(websocket: WebSocket):
print("客户端断开连接") print("客户端断开连接")
@router.get("/traffic/plot", response_class=HTMLResponse, summary="流量可视化图表") @router.get("/", include_in_schema=False)
async def plot_traffic(interface: str = "eth0", minutes: int = 10): async def root():
# 获取历史数据 return {
history = traffic_monitor.get_traffic_history(interface) "message": "欢迎使用AI交换机配置系统",
time_points = history["time"][-minutes * 60:] "docs": f"{settings.API_PREFIX}/docs",
sent = history["sent"][-minutes * 60:] "redoc": f"{settings.API_PREFIX}/redoc",
recv = history["recv"][-minutes * 60:] "endpoints": [
"/parse_command",
"/apply_config",
"/scan_network",
"/list_devices",
"/batch_apply_config",
"/traffic/switch/current", # 添加交换机流量端点
"/traffic/switch/history" # 添加交换机历史流量端点
]
}
# 创建图表
plt.figure(figsize=(10, 6))
plt.plot(time_points, sent, label="发送流量 (B/s)")
plt.plot(time_points, recv, label="接收流量 (B/s)")
plt.title(f"{interface} 流量监控 - 最近 {minutes} 分钟")
plt.xlabel("时间")
plt.ylabel("流量 (字节/秒)")
plt.legend()
plt.grid(True)
plt.xticks(rotation=45)
plt.tight_layout()
# 转换为HTML图像 # ... 其他路由保持不变 ...
buf = io.BytesIO()
plt.savefig(buf, format="png")
buf.seek(0)
image_base64 = base64.b64encode(buf.read()).decode("utf-8")
plt.close()
return f""" @router.get("/traffic/switch/interfaces", summary="获取交换机的网络接口")
<html> async def get_switch_interfaces(switch_ip: str):
<head> """获取指定交换机的所有接口"""
<title>网络流量监控</title> try:
</head> monitor = get_switch_monitor(switch_ip)
<body> interfaces = list(monitor.interface_oids.keys())
<h1>{interface} 网络流量监控</h1> return {
<img src="data:image/png;base64,{image_base64}" alt="流量图表"> "switch_ip": switch_ip,
<p>更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p> "interfaces": interfaces
</body> }
</html> except Exception as e:
""" logger.error(f"获取交换机接口失败: {str(e)}")
raise HTTPException(500, f"获取接口失败: {str(e)}")
@router.get("/traffic/switch/current", summary="获取交换机的当前流量数据")
async def get_switch_current_traffic(switch_ip: str, interface: str = None):
"""获取交换机的当前流量数据"""
try:
monitor = get_switch_monitor(switch_ip)
# 如果没有指定接口,获取所有接口的流量
if not interface:
traffic_data = {}
for iface in monitor.interface_oids:
traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface)
return {
"switch_ip": switch_ip,
"traffic": traffic_data
}
# 获取指定接口的流量
return await get_interface_current_traffic(switch_ip, interface)
except Exception as e:
logger.error(f"获取交换机流量失败: {str(e)}")
raise HTTPException(500, f"获取流量失败: {str(e)}")
async def get_interface_current_traffic(switch_ip: str, interface: str) -> dict:
"""获取指定交换机接口的当前流量数据"""
try:
with SessionLocal() as session:
# 获取最新记录
record = session.query(SwitchTrafficRecord).filter(
SwitchTrafficRecord.switch_ip == switch_ip,
SwitchTrafficRecord.interface == interface
).order_by(SwitchTrafficRecord.timestamp.desc()).first()
if not record:
return {
"switch_ip": switch_ip,
"interface": interface,
"rate_in": 0.0,
"rate_out": 0.0,
"bytes_in": 0,
"bytes_out": 0
}
return {
"switch_ip": switch_ip,
"interface": interface,
"rate_in": record.rate_in,
"rate_out": record.rate_out,
"bytes_in": record.bytes_in,
"bytes_out": record.bytes_out
}
except Exception as e:
logger.error(f"获取接口流量失败: {str(e)}")
raise HTTPException(500, f"获取接口流量失败: {str(e)}")
@router.get("/traffic/switch/history", summary="获取交换机的流量历史数据")
async def get_switch_traffic_history(switch_ip: str, interface: str = None, minutes: int = 10):
"""获取交换机的流量历史数据"""
try:
monitor = get_switch_monitor(switch_ip)
# 如果没有指定接口,返回所有接口的历史数据
if not interface:
return {
"switch_ip": switch_ip,
"history": monitor.get_traffic_history()
}
# 获取指定接口的历史数据
with SessionLocal() as session:
# 计算时间范围
time_threshold = datetime.now() - timedelta(minutes=minutes)
# 查询历史记录
records = session.query(SwitchTrafficRecord).filter(
SwitchTrafficRecord.switch_ip == switch_ip,
SwitchTrafficRecord.interface == interface,
SwitchTrafficRecord.timestamp >= time_threshold
).order_by(SwitchTrafficRecord.timestamp.asc()).all()
# 提取数据
history_data = {
"in": [record.rate_in for record in records],
"out": [record.rate_out for record in records],
"time": [record.timestamp.isoformat() for record in records]
}
return {
"switch_ip": switch_ip,
"interface": interface,
"history": history_data
}
except Exception as e:
logger.error(f"获取历史流量失败: {str(e)}")
raise HTTPException(500, f"获取历史流量失败: {str(e)}")
@router.websocket("/ws/traffic/switch")
async def websocket_switch_traffic(websocket: WebSocket, switch_ip: str, interface: str = None):
"""交换机实时流量WebSocket"""
await websocket.accept()
try:
monitor = get_switch_monitor(switch_ip)
while True:
# 获取流量数据
if interface:
# 单个接口
traffic_data = await get_interface_current_traffic(switch_ip, interface)
await websocket.send_json(traffic_data)
else:
# 所有接口
traffic_data = {}
for iface in monitor.interface_oids:
traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface)
await websocket.send_json({
"switch_ip": switch_ip,
"traffic": traffic_data
})
# 每秒更新一次
await asyncio.sleep(1)
except WebSocketDisconnect:
logger.info(f"客户端断开交换机流量连接: {switch_ip}")
except Exception as e:
logger.error(f"交换机流量WebSocket错误: {str(e)}")
await websocket.close(code=1011, reason=str(e))
# 交换机流量可视化
@router.get("/traffic/switch/plot", response_class=HTMLResponse, summary="交换机流量可视化")
async def plot_switch_traffic(switch_ip: str, interface: str, minutes: int = 10):
"""生成交换机流量图表"""
try:
# 获取历史数据
history = await get_switch_traffic_history(switch_ip, interface, minutes)
history_data = history["history"]
# 提取数据点
time_points = [datetime.fromisoformat(t) for t in history_data["time"]]
in_rates = history_data["in"]
out_rates = history_data["out"]
# 创建图表
plt.figure(figsize=(12, 6))
plt.plot(time_points, in_rates, label="流入流量 (B/s)")
plt.plot(time_points, out_rates, label="流出流量 (B/s)")
plt.title(f"交换机 {switch_ip} 接口 {interface} 流量监控 - 最近 {minutes} 分钟")
plt.xlabel("时间")
plt.ylabel("流量 (字节/秒)")
plt.legend()
plt.grid(True)
plt.xticks(rotation=45)
plt.tight_layout()
# 转换为HTML图像
buf = io.BytesIO()
plt.savefig(buf, format="png")
buf.seek(0)
image_base64 = base64.b64encode(buf.read()).decode("utf-8")
plt.close()
return f"""
<html>
<head>
<title>交换机流量监控</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 20px; }}
.container {{ max-width: 900px; margin: 0 auto; }}
</style>
</head>
<body>
<div class="container">
<h1>交换机 {switch_ip} 接口 {interface} 流量监控</h1>
<img src="data:image/png;base64,{image_base64}" alt="流量图表">
<p>更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
</div>
</body>
</html>
"""
except Exception as e:
logger.error(f"生成流量图表失败: {str(e)}")
return HTMLResponse(content=f"<h1>错误</h1><p>{str(e)}</p>", status_code=500)

View File

@ -1,6 +1,6 @@
# 添加正确的导入 # 添加正确的导入
from src.backend.app.api.database import Base # 修复:导入 Base from src.backend.app.api.database import Base # 修复:导入 Base
from sqlalchemy import Column, Integer, String, DateTime from sqlalchemy import Column, Integer, String, DateTime, BigInteger, Float
class TrafficRecord(Base): class TrafficRecord(Base):
@ -25,3 +25,28 @@ class TrafficRecord(Base):
"packets_recv": self.packets_recv, "packets_recv": self.packets_recv,
"timestamp": self.timestamp.isoformat() "timestamp": self.timestamp.isoformat()
} }
class SwitchTrafficRecord(Base):
__tablename__ = "switch_traffic_records"
id = Column(Integer, primary_key=True, index=True)
switch_ip = Column(String(50), index=True)
interface = Column(String(50))
bytes_in = Column(BigInteger) # 累计流入字节数
bytes_out = Column(BigInteger) # 累计流出字节数
rate_in = Column(Float) # 当前流入速率(字节/秒)
rate_out = Column(Float) # 当前流出速率(字节/秒)
timestamp = Column(DateTime)
def to_dict(self):
return {
"id": self.id,
"switch_ip": self.switch_ip,
"interface": self.interface,
"bytes_in": self.bytes_in,
"bytes_out": self.bytes_out,
"rate_in": self.rate_in,
"rate_out": self.rate_out,
"timestamp": self.timestamp.isoformat()
}

View File

@ -0,0 +1,196 @@
import asyncio
from datetime import datetime
from collections import deque
from typing import Optional, List, Dict
from pysnmp.hlapi import *
from ..models.traffic_models import SwitchTrafficRecord
from src.backend.app.api.database import SessionLocal
from ..utils.logger import logger
class SwitchTrafficMonitor:
def __init__(
self,
switch_ip: str,
community: str = 'public',
update_interval: int = 5,
interfaces: Optional[List[str]] = None
):
self.switch_ip = switch_ip
self.community = community
self.update_interval = update_interval
self.running = False
self.task = None
self.interface_history = {}
self.history = {
"in": deque(maxlen=300),
"out": deque(maxlen=300),
"time": deque(maxlen=300)
}
# 基本接口OID映射
self.interface_oids = {
"GigabitEthernet0/0/1": {
"in": '1.3.6.1.2.1.2.2.1.10.1', # ifInOctets
"out": '1.3.6.1.2.1.2.2.1.16.1' # ifOutOctets
},
"GigabitEthernet0/0/24": {
"in": '1.3.6.1.2.1.2.2.1.10.24',
"out": '1.3.6.1.2.1.2.2.1.16.24'
}
}
# 接口过滤
if interfaces:
self.interface_oids = {
iface: oid for iface, oid in self.interface_oids.items()
if iface in interfaces
}
logger.info(f"监控指定接口: {', '.join(interfaces)}")
else:
logger.info("监控所有接口")
def start_monitoring(self):
"""启动交换机流量监控"""
if not self.running:
self.running = True
self.task = asyncio.create_task(self._monitor_loop())
logger.success(f"交换机流量监控已启动: {self.switch_ip}")
async def stop_monitoring(self):
"""停止监控"""
if self.running:
self.running = False
if self.task:
self.task.cancel()
try:
await self.task
except asyncio.CancelledError:
pass
logger.info(f"交换机流量监控已停止: {self.switch_ip}")
async def _monitor_loop(self):
"""监控主循环"""
last_values = {iface: {"in": 0, "out": 0} for iface in self.interface_oids}
last_time = datetime.now()
while self.running:
await asyncio.sleep(self.update_interval)
try:
current_time = datetime.now()
elapsed = (current_time - last_time).total_seconds()
# 获取所有接口流量
for iface, oids in self.interface_oids.items():
in_octets = self._snmp_get(oids["in"])
out_octets = self._snmp_get(oids["out"])
if in_octets is not None and out_octets is not None:
# 计算速率(字节/秒)
# 修复字典访问问题
iface_values = last_values[iface]
in_rate = (in_octets - iface_values["in"]) / elapsed if iface_values["in"] > 0 else 0
out_rate = (out_octets - iface_values["out"]) / elapsed if iface_values["out"] > 0 else 0
# 保存历史数据
self.history["in"].append(in_rate)
self.history["out"].append(out_rate)
self.history["time"].append(current_time)
# 保存到数据库
self._save_to_db(iface, in_octets, out_octets, in_rate, out_rate, current_time)
# 更新最后的值
iface_values["in"] = in_octets
iface_values["out"] = out_octets
last_time = current_time
except Exception as e:
logger.error(f"监控交换机流量出错: {str(e)}")
def _snmp_get(self, oid) -> Optional[int]:
"""执行SNMP GET请求"""
try:
# 正确格式化的SNMP请求
cmd = getCmd(
SnmpEngine(),
CommunityData(self.community),
UdpTransportTarget((self.switch_ip, 161)),
ContextData(),
ObjectType(ObjectIdentity(oid)))
# 执行命令
errorIndication, errorStatus, errorIndex, varBinds = next(cmd)
except Exception as e:
logger.error(f"SNMP请求失败: {str(e)}")
return None
if errorIndication:
logger.error(f"SNMP错误: {errorIndication}")
return None
elif errorStatus:
try:
# 修复括号问题
if errorIndex:
index_val = int(errorIndex) - 1
error_item = varBinds[index_val] if index_val < len(varBinds) else '?'
else:
error_item = '?'
error_msg = f"SNMP错误: {errorStatus.prettyPrint()} at {error_item}"
logger.error(error_msg)
except Exception as e:
logger.error(f"解析SNMP错误失败: {str(e)}")
return None
else:
for varBind in varBinds:
try:
return int(varBind[1])
except Exception as e:
logger.error(f"转换SNMP值失败: {str(e)}")
return None
return None
def _save_to_db(self, interface: str, in_octets: int, out_octets: int,
in_rate: float, out_rate: float, timestamp: datetime):
"""保存流量数据到数据库"""
try:
with SessionLocal() as session:
record = SwitchTrafficRecord(
switch_ip=self.switch_ip,
interface=interface,
bytes_in=in_octets,
bytes_out=out_octets,
rate_in=in_rate,
rate_out=out_rate,
timestamp=timestamp
)
session.add(record)
session.commit()
except Exception as e:
logger.error(f"保存流量数据到数据库失败: {str(e)}")
def get_traffic_history(self) -> Dict[str, List]:
"""获取流量历史数据"""
return {
"in": list(self.history["in"]),
"out": list(self.history["out"]),
"time": list(self.history["time"])
}
# 全局监控器字典(支持多个交换机)
switch_monitors = {}
def get_switch_monitor(switch_ip: str, community: str = 'public', interfaces: Optional[List[str]] = None):
"""获取或创建交换机监控器(添加接口过滤参数)"""
if switch_ip not in switch_monitors:
switch_monitors[switch_ip] = SwitchTrafficMonitor(
switch_ip,
community,
interfaces=interfaces
)
return switch_monitors[switch_ip]

View File

@ -1,15 +1,13 @@
import logging import logging
from loguru import logger
import sys import sys
from loguru import logger as loguru_logger
class InterceptHandler(logging.Handler): class InterceptHandler(logging.Handler):
def emit(self, record): def emit(self, record):
# 获取对应的Loguru日志级别 # 获取对应的Loguru日志级别
try: try:
level = logger.level(record.levelname).name level = loguru_logger.level(record.levelname).name
except ValueError: except ValueError:
level = record.levelno level = record.levelno
@ -20,7 +18,7 @@ class InterceptHandler(logging.Handler):
depth += 1 depth += 1
# 使用Loguru记录日志 # 使用Loguru记录日志
logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage()) loguru_logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
def setup_logging(): def setup_logging():
@ -28,10 +26,10 @@ def setup_logging():
logging.basicConfig(handlers=[InterceptHandler()], level=logging.NOTSET) logging.basicConfig(handlers=[InterceptHandler()], level=logging.NOTSET)
# 移除所有现有处理器 # 移除所有现有处理器
logger.remove() loguru_logger.remove()
# 直接添加处理器避免configure方法 # 添加新的处理器
logger.add( loguru_logger.add(
sys.stdout, sys.stdout,
format=( format=(
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | " "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
@ -42,3 +40,50 @@ def setup_logging():
level="DEBUG", level="DEBUG",
enqueue=True enqueue=True
) )
# 添加文件日志
loguru_logger.add(
"app.log",
rotation="10 MB",
retention="30 days",
level="INFO",
format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level} | {message}"
)
# 创建通用logger接口
class Logger:
@staticmethod
def debug(msg, *args, **kwargs):
loguru_logger.debug(msg, *args, **kwargs)
@staticmethod
def info(msg, *args, **kwargs):
loguru_logger.info(msg, *args, **kwargs)
@staticmethod
def warning(msg, *args, **kwargs):
loguru_logger.warning(msg, *args, **kwargs)
@staticmethod
def error(msg, *args, **kwargs):
loguru_logger.error(msg, *args, **kwargs)
@staticmethod
def critical(msg, *args, **kwargs):
loguru_logger.critical(msg, *args, **kwargs)
@staticmethod
def exception(msg, *args, **kwargs):
loguru_logger.exception(msg, *args, **kwargs)
@staticmethod
def success(msg, *args, **kwargs):
loguru_logger.success(msg, *args, **kwargs)
# 创建全局logger实例
logger = Logger()
# 初始化日志系统
setup_logging()

View File

@ -12,6 +12,7 @@ asyncssh==2.14.2
telnetlib3==2.0.3 telnetlib3==2.0.3
httpx==0.27.0 httpx==0.27.0
python-nmap==0.7.1 python-nmap==0.7.1
pysnmp==4.4.12
# 异步文件操作 # 异步文件操作
aiofiles==23.2.1 aiofiles==23.2.1

View File

@ -1,39 +1,73 @@
import asyncio import asyncio
import logging import logging
from src.backend.app.api.network_config import SwitchConfigurator # 导入你的核心类 from src.backend.app.api.network_config import SwitchConfigurator
#该文件用于测试 #该文件用于测试
# 设置日志
logging.basicConfig( logging.basicConfig(
level=logging.DEBUG, level=logging.DEBUG,
format='%(asctime)s - %(levelname)s - %(message)s' format='%(asctime)s - %(levelname)s - %(message)s'
) )
async def test_ensp():
"""eNSP测试函数"""
# 1. 初始化配置器对应eNSP设备设置
configurator = SwitchConfigurator(
ensp_mode=True, # 启用eNSP模式
ensp_port=2000, # 必须与eNSP中设备设置的Telnet端口一致
username="admin", # 默认账号
password="admin", # 默认密码
timeout=15 # 建议超时设长些
)
# 2. 执行配置示例创建VLAN100 async def test_connection(configurator):
"""测试基础连接"""
try:
version = await configurator._send_commands("127.0.0.1", ["display version"])
print("交换机版本信息:\n", version)
return True
except Exception as e:
print("❌ 连接测试失败:", str(e))
return False
async def test_vlan_config(configurator):
"""测试 VLAN 配置"""
try: try:
result = await configurator.safe_apply( result = await configurator.safe_apply(
ip="127.0.0.1", # 本地连接固定用这个地址 "127.0.0.1",
config={ config={
"type": "vlan", "type": "vlan",
"vlan_id": 100, "vlan_id": 100,
"name": "测试VLAN" "name": "自动化测试VLAN"
} }
) )
print("✅ 配置结果:", result) print("VLAN 配置结果:", result)
except Exception as e:
print("❌ 配置失败:", str(e)) # 验证配置
vlan_list = await configurator._send_commands("127.0.0.1", ["display vlan"])
print("当前VLAN列表:\n", vlan_list)
return "success" in result.get("status", "")
except Exception as e:
print("❌ VLAN 配置失败:", str(e))
return False
async def main():
"""主测试流程"""
# 尝试不同端口
for port in [2000, 2010, 2020, 23]:
print(f"\n尝试端口: {port}")
configurator = SwitchConfigurator(
ensp_mode=True,
ensp_port=port,
username="",
password="admin",
timeout=15
)
if await test_connection(configurator):
print(f"✅ 成功连接到端口 {port}")
if await test_vlan_config(configurator):
print("✅ 所有测试通过!")
return
else:
print("⚠️ VLAN 配置失败,继续尝试其他端口...")
else:
print("⚠️ 连接失败,尝试下一个端口...")
print("❌ 所有端口尝试失败,请检查配置")
# 运行测试
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(test_ensp()) asyncio.run(main())