114514+114514

This commit is contained in:
3 2025-06-19 14:52:59 +08:00
parent 2231b8cf82
commit 7c17bb931b
7 changed files with 567 additions and 77 deletions

View File

@ -2,6 +2,7 @@ from fastapi import FastAPI, responses
from src.backend.app.api.endpoints import router
from src.backend.app.utils.logger import setup_logging
from src.backend.config import settings
from .services.switch_traffic_monitor import get_switch_monitor
# 添加正确的导入
from .services.traffic_monitor import traffic_monitor
@ -37,10 +38,12 @@ def create_app() -> FastAPI:
async def favicon():
return responses.Response(status_code=204)
# 添加API路由
app.include_router(router, prefix=settings.API_PREFIX)
return app
app = create_app()

View File

@ -1,21 +1,24 @@
from datetime import datetime
from datetime import datetime, timedelta
from fastapi import (APIRouter, HTTPException, Response, WebSocket, WebSocketDisconnect)
from typing import List
from pydantic import BaseModel
from ...app.services.ai_services import AIService
from ...app.api.network_config import SwitchConfigurator
from ...config import settings
from ..services.network_scanner import NetworkScanner
from ...app.services.traffic_monitor import traffic_monitor
from ...app.models.traffic_models import TrafficRecord
from src.backend.app.api.database import (SessionLocal)
import asyncio
from fastapi.responses import HTMLResponse
import matplotlib.pyplot as plt
import io
import base64
from ..services.switch_traffic_monitor import get_switch_monitor
from ..utils import logger
from ...app.services.ai_services import AIService
from ...app.api.network_config import SwitchConfigurator
from ...config import settings
from ..services.network_scanner import NetworkScanner
from ...app.services.traffic_monitor import traffic_monitor
from ...app.models.traffic_models import TrafficRecord, SwitchTrafficRecord
from src.backend.app.api.database import SessionLocal
router = APIRouter(prefix="", tags=["API"])
@ -34,6 +37,8 @@ async def root():
"/scan_network",
"/list_devices",
"/batch_apply_config"
"/traffic/switch/current", # 添加交换机流量端点
"/traffic/switch/history" # 添加交换机历史流量端点
]
}
@ -163,19 +168,191 @@ async def websocket_traffic(websocket: WebSocket):
print("客户端断开连接")
@router.get("/traffic/plot", response_class=HTMLResponse, summary="流量可视化图表")
async def plot_traffic(interface: str = "eth0", minutes: int = 10):
@router.get("/", include_in_schema=False)
async def root():
return {
"message": "欢迎使用AI交换机配置系统",
"docs": f"{settings.API_PREFIX}/docs",
"redoc": f"{settings.API_PREFIX}/redoc",
"endpoints": [
"/parse_command",
"/apply_config",
"/scan_network",
"/list_devices",
"/batch_apply_config",
"/traffic/switch/current", # 添加交换机流量端点
"/traffic/switch/history" # 添加交换机历史流量端点
]
}
# ... 其他路由保持不变 ...
@router.get("/traffic/switch/interfaces", summary="获取交换机的网络接口")
async def get_switch_interfaces(switch_ip: str):
"""获取指定交换机的所有接口"""
try:
monitor = get_switch_monitor(switch_ip)
interfaces = list(monitor.interface_oids.keys())
return {
"switch_ip": switch_ip,
"interfaces": interfaces
}
except Exception as e:
logger.error(f"获取交换机接口失败: {str(e)}")
raise HTTPException(500, f"获取接口失败: {str(e)}")
@router.get("/traffic/switch/current", summary="获取交换机的当前流量数据")
async def get_switch_current_traffic(switch_ip: str, interface: str = None):
"""获取交换机的当前流量数据"""
try:
monitor = get_switch_monitor(switch_ip)
# 如果没有指定接口,获取所有接口的流量
if not interface:
traffic_data = {}
for iface in monitor.interface_oids:
traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface)
return {
"switch_ip": switch_ip,
"traffic": traffic_data
}
# 获取指定接口的流量
return await get_interface_current_traffic(switch_ip, interface)
except Exception as e:
logger.error(f"获取交换机流量失败: {str(e)}")
raise HTTPException(500, f"获取流量失败: {str(e)}")
async def get_interface_current_traffic(switch_ip: str, interface: str) -> dict:
"""获取指定交换机接口的当前流量数据"""
try:
with SessionLocal() as session:
# 获取最新记录
record = session.query(SwitchTrafficRecord).filter(
SwitchTrafficRecord.switch_ip == switch_ip,
SwitchTrafficRecord.interface == interface
).order_by(SwitchTrafficRecord.timestamp.desc()).first()
if not record:
return {
"switch_ip": switch_ip,
"interface": interface,
"rate_in": 0.0,
"rate_out": 0.0,
"bytes_in": 0,
"bytes_out": 0
}
return {
"switch_ip": switch_ip,
"interface": interface,
"rate_in": record.rate_in,
"rate_out": record.rate_out,
"bytes_in": record.bytes_in,
"bytes_out": record.bytes_out
}
except Exception as e:
logger.error(f"获取接口流量失败: {str(e)}")
raise HTTPException(500, f"获取接口流量失败: {str(e)}")
@router.get("/traffic/switch/history", summary="获取交换机的流量历史数据")
async def get_switch_traffic_history(switch_ip: str, interface: str = None, minutes: int = 10):
"""获取交换机的流量历史数据"""
try:
monitor = get_switch_monitor(switch_ip)
# 如果没有指定接口,返回所有接口的历史数据
if not interface:
return {
"switch_ip": switch_ip,
"history": monitor.get_traffic_history()
}
# 获取指定接口的历史数据
with SessionLocal() as session:
# 计算时间范围
time_threshold = datetime.now() - timedelta(minutes=minutes)
# 查询历史记录
records = session.query(SwitchTrafficRecord).filter(
SwitchTrafficRecord.switch_ip == switch_ip,
SwitchTrafficRecord.interface == interface,
SwitchTrafficRecord.timestamp >= time_threshold
).order_by(SwitchTrafficRecord.timestamp.asc()).all()
# 提取数据
history_data = {
"in": [record.rate_in for record in records],
"out": [record.rate_out for record in records],
"time": [record.timestamp.isoformat() for record in records]
}
return {
"switch_ip": switch_ip,
"interface": interface,
"history": history_data
}
except Exception as e:
logger.error(f"获取历史流量失败: {str(e)}")
raise HTTPException(500, f"获取历史流量失败: {str(e)}")
@router.websocket("/ws/traffic/switch")
async def websocket_switch_traffic(websocket: WebSocket, switch_ip: str, interface: str = None):
"""交换机实时流量WebSocket"""
await websocket.accept()
try:
monitor = get_switch_monitor(switch_ip)
while True:
# 获取流量数据
if interface:
# 单个接口
traffic_data = await get_interface_current_traffic(switch_ip, interface)
await websocket.send_json(traffic_data)
else:
# 所有接口
traffic_data = {}
for iface in monitor.interface_oids:
traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface)
await websocket.send_json({
"switch_ip": switch_ip,
"traffic": traffic_data
})
# 每秒更新一次
await asyncio.sleep(1)
except WebSocketDisconnect:
logger.info(f"客户端断开交换机流量连接: {switch_ip}")
except Exception as e:
logger.error(f"交换机流量WebSocket错误: {str(e)}")
await websocket.close(code=1011, reason=str(e))
# 交换机流量可视化
@router.get("/traffic/switch/plot", response_class=HTMLResponse, summary="交换机流量可视化")
async def plot_switch_traffic(switch_ip: str, interface: str, minutes: int = 10):
"""生成交换机流量图表"""
try:
# 获取历史数据
history = traffic_monitor.get_traffic_history(interface)
time_points = history["time"][-minutes * 60:]
sent = history["sent"][-minutes * 60:]
recv = history["recv"][-minutes * 60:]
history = await get_switch_traffic_history(switch_ip, interface, minutes)
history_data = history["history"]
# 提取数据点
time_points = [datetime.fromisoformat(t) for t in history_data["time"]]
in_rates = history_data["in"]
out_rates = history_data["out"]
# 创建图表
plt.figure(figsize=(10, 6))
plt.plot(time_points, sent, label="发送流量 (B/s)")
plt.plot(time_points, recv, label="接收流量 (B/s)")
plt.title(f"{interface} 流量监控 - 最近 {minutes} 分钟")
plt.figure(figsize=(12, 6))
plt.plot(time_points, in_rates, label="流入流量 (B/s)")
plt.plot(time_points, out_rates, label="流出流量 (B/s)")
plt.title(f"交换机 {switch_ip} 接口 {interface} 流量监控 - 最近 {minutes} 分钟")
plt.xlabel("时间")
plt.ylabel("流量 (字节/秒)")
plt.legend()
@ -193,12 +370,21 @@ async def plot_traffic(interface: str = "eth0", minutes: int = 10):
return f"""
<html>
<head>
<title>网络流量监控</title>
<title>交换机流量监控</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 20px; }}
.container {{ max-width: 900px; margin: 0 auto; }}
</style>
</head>
<body>
<h1>{interface} 网络流量监控</h1>
<div class="container">
<h1>交换机 {switch_ip} 接口 {interface} 流量监控</h1>
<img src="data:image/png;base64,{image_base64}" alt="流量图表">
<p>更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
</div>
</body>
</html>
"""
except Exception as e:
logger.error(f"生成流量图表失败: {str(e)}")
return HTMLResponse(content=f"<h1>错误</h1><p>{str(e)}</p>", status_code=500)

View File

@ -1,6 +1,6 @@
# 添加正确的导入
from src.backend.app.api.database import Base # 修复:导入 Base
from sqlalchemy import Column, Integer, String, DateTime
from sqlalchemy import Column, Integer, String, DateTime, BigInteger, Float
class TrafficRecord(Base):
@ -25,3 +25,28 @@ class TrafficRecord(Base):
"packets_recv": self.packets_recv,
"timestamp": self.timestamp.isoformat()
}
class SwitchTrafficRecord(Base):
__tablename__ = "switch_traffic_records"
id = Column(Integer, primary_key=True, index=True)
switch_ip = Column(String(50), index=True)
interface = Column(String(50))
bytes_in = Column(BigInteger) # 累计流入字节数
bytes_out = Column(BigInteger) # 累计流出字节数
rate_in = Column(Float) # 当前流入速率(字节/秒)
rate_out = Column(Float) # 当前流出速率(字节/秒)
timestamp = Column(DateTime)
def to_dict(self):
return {
"id": self.id,
"switch_ip": self.switch_ip,
"interface": self.interface,
"bytes_in": self.bytes_in,
"bytes_out": self.bytes_out,
"rate_in": self.rate_in,
"rate_out": self.rate_out,
"timestamp": self.timestamp.isoformat()
}

View File

@ -0,0 +1,196 @@
import asyncio
from datetime import datetime
from collections import deque
from typing import Optional, List, Dict
from pysnmp.hlapi import *
from ..models.traffic_models import SwitchTrafficRecord
from src.backend.app.api.database import SessionLocal
from ..utils.logger import logger
class SwitchTrafficMonitor:
def __init__(
self,
switch_ip: str,
community: str = 'public',
update_interval: int = 5,
interfaces: Optional[List[str]] = None
):
self.switch_ip = switch_ip
self.community = community
self.update_interval = update_interval
self.running = False
self.task = None
self.interface_history = {}
self.history = {
"in": deque(maxlen=300),
"out": deque(maxlen=300),
"time": deque(maxlen=300)
}
# 基本接口OID映射
self.interface_oids = {
"GigabitEthernet0/0/1": {
"in": '1.3.6.1.2.1.2.2.1.10.1', # ifInOctets
"out": '1.3.6.1.2.1.2.2.1.16.1' # ifOutOctets
},
"GigabitEthernet0/0/24": {
"in": '1.3.6.1.2.1.2.2.1.10.24',
"out": '1.3.6.1.2.1.2.2.1.16.24'
}
}
# 接口过滤
if interfaces:
self.interface_oids = {
iface: oid for iface, oid in self.interface_oids.items()
if iface in interfaces
}
logger.info(f"监控指定接口: {', '.join(interfaces)}")
else:
logger.info("监控所有接口")
def start_monitoring(self):
"""启动交换机流量监控"""
if not self.running:
self.running = True
self.task = asyncio.create_task(self._monitor_loop())
logger.success(f"交换机流量监控已启动: {self.switch_ip}")
async def stop_monitoring(self):
"""停止监控"""
if self.running:
self.running = False
if self.task:
self.task.cancel()
try:
await self.task
except asyncio.CancelledError:
pass
logger.info(f"交换机流量监控已停止: {self.switch_ip}")
async def _monitor_loop(self):
"""监控主循环"""
last_values = {iface: {"in": 0, "out": 0} for iface in self.interface_oids}
last_time = datetime.now()
while self.running:
await asyncio.sleep(self.update_interval)
try:
current_time = datetime.now()
elapsed = (current_time - last_time).total_seconds()
# 获取所有接口流量
for iface, oids in self.interface_oids.items():
in_octets = self._snmp_get(oids["in"])
out_octets = self._snmp_get(oids["out"])
if in_octets is not None and out_octets is not None:
# 计算速率(字节/秒)
# 修复字典访问问题
iface_values = last_values[iface]
in_rate = (in_octets - iface_values["in"]) / elapsed if iface_values["in"] > 0 else 0
out_rate = (out_octets - iface_values["out"]) / elapsed if iface_values["out"] > 0 else 0
# 保存历史数据
self.history["in"].append(in_rate)
self.history["out"].append(out_rate)
self.history["time"].append(current_time)
# 保存到数据库
self._save_to_db(iface, in_octets, out_octets, in_rate, out_rate, current_time)
# 更新最后的值
iface_values["in"] = in_octets
iface_values["out"] = out_octets
last_time = current_time
except Exception as e:
logger.error(f"监控交换机流量出错: {str(e)}")
def _snmp_get(self, oid) -> Optional[int]:
"""执行SNMP GET请求"""
try:
# 正确格式化的SNMP请求
cmd = getCmd(
SnmpEngine(),
CommunityData(self.community),
UdpTransportTarget((self.switch_ip, 161)),
ContextData(),
ObjectType(ObjectIdentity(oid)))
# 执行命令
errorIndication, errorStatus, errorIndex, varBinds = next(cmd)
except Exception as e:
logger.error(f"SNMP请求失败: {str(e)}")
return None
if errorIndication:
logger.error(f"SNMP错误: {errorIndication}")
return None
elif errorStatus:
try:
# 修复括号问题
if errorIndex:
index_val = int(errorIndex) - 1
error_item = varBinds[index_val] if index_val < len(varBinds) else '?'
else:
error_item = '?'
error_msg = f"SNMP错误: {errorStatus.prettyPrint()} at {error_item}"
logger.error(error_msg)
except Exception as e:
logger.error(f"解析SNMP错误失败: {str(e)}")
return None
else:
for varBind in varBinds:
try:
return int(varBind[1])
except Exception as e:
logger.error(f"转换SNMP值失败: {str(e)}")
return None
return None
def _save_to_db(self, interface: str, in_octets: int, out_octets: int,
in_rate: float, out_rate: float, timestamp: datetime):
"""保存流量数据到数据库"""
try:
with SessionLocal() as session:
record = SwitchTrafficRecord(
switch_ip=self.switch_ip,
interface=interface,
bytes_in=in_octets,
bytes_out=out_octets,
rate_in=in_rate,
rate_out=out_rate,
timestamp=timestamp
)
session.add(record)
session.commit()
except Exception as e:
logger.error(f"保存流量数据到数据库失败: {str(e)}")
def get_traffic_history(self) -> Dict[str, List]:
"""获取流量历史数据"""
return {
"in": list(self.history["in"]),
"out": list(self.history["out"]),
"time": list(self.history["time"])
}
# 全局监控器字典(支持多个交换机)
switch_monitors = {}
def get_switch_monitor(switch_ip: str, community: str = 'public', interfaces: Optional[List[str]] = None):
"""获取或创建交换机监控器(添加接口过滤参数)"""
if switch_ip not in switch_monitors:
switch_monitors[switch_ip] = SwitchTrafficMonitor(
switch_ip,
community,
interfaces=interfaces
)
return switch_monitors[switch_ip]

View File

@ -1,15 +1,13 @@
import logging
from loguru import logger
import sys
from loguru import logger as loguru_logger
class InterceptHandler(logging.Handler):
def emit(self, record):
# 获取对应的Loguru日志级别
try:
level = logger.level(record.levelname).name
level = loguru_logger.level(record.levelname).name
except ValueError:
level = record.levelno
@ -20,7 +18,7 @@ class InterceptHandler(logging.Handler):
depth += 1
# 使用Loguru记录日志
logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
loguru_logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
def setup_logging():
@ -28,10 +26,10 @@ def setup_logging():
logging.basicConfig(handlers=[InterceptHandler()], level=logging.NOTSET)
# 移除所有现有处理器
logger.remove()
loguru_logger.remove()
# 直接添加处理器避免configure方法
logger.add(
# 添加新的处理器
loguru_logger.add(
sys.stdout,
format=(
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
@ -42,3 +40,50 @@ def setup_logging():
level="DEBUG",
enqueue=True
)
# 添加文件日志
loguru_logger.add(
"app.log",
rotation="10 MB",
retention="30 days",
level="INFO",
format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level} | {message}"
)
# 创建通用logger接口
class Logger:
@staticmethod
def debug(msg, *args, **kwargs):
loguru_logger.debug(msg, *args, **kwargs)
@staticmethod
def info(msg, *args, **kwargs):
loguru_logger.info(msg, *args, **kwargs)
@staticmethod
def warning(msg, *args, **kwargs):
loguru_logger.warning(msg, *args, **kwargs)
@staticmethod
def error(msg, *args, **kwargs):
loguru_logger.error(msg, *args, **kwargs)
@staticmethod
def critical(msg, *args, **kwargs):
loguru_logger.critical(msg, *args, **kwargs)
@staticmethod
def exception(msg, *args, **kwargs):
loguru_logger.exception(msg, *args, **kwargs)
@staticmethod
def success(msg, *args, **kwargs):
loguru_logger.success(msg, *args, **kwargs)
# 创建全局logger实例
logger = Logger()
# 初始化日志系统
setup_logging()

View File

@ -12,6 +12,7 @@ asyncssh==2.14.2
telnetlib3==2.0.3
httpx==0.27.0
python-nmap==0.7.1
pysnmp==4.4.12
# 异步文件操作
aiofiles==23.2.1

View File

@ -1,39 +1,73 @@
import asyncio
import logging
from src.backend.app.api.network_config import SwitchConfigurator # 导入你的核心类
from src.backend.app.api.network_config import SwitchConfigurator
#该文件用于测试
# 设置日志
logging.basicConfig(
level=logging.DEBUG,
format='%(asctime)s - %(levelname)s - %(message)s'
)
async def test_ensp():
"""eNSP测试函数"""
# 1. 初始化配置器对应eNSP设备设置
configurator = SwitchConfigurator(
ensp_mode=True, # 启用eNSP模式
ensp_port=2000, # 必须与eNSP中设备设置的Telnet端口一致
username="admin", # 默认账号
password="admin", # 默认密码
timeout=15 # 建议超时设长些
)
# 2. 执行配置示例创建VLAN100
async def test_connection(configurator):
"""测试基础连接"""
try:
version = await configurator._send_commands("127.0.0.1", ["display version"])
print("交换机版本信息:\n", version)
return True
except Exception as e:
print("❌ 连接测试失败:", str(e))
return False
async def test_vlan_config(configurator):
"""测试 VLAN 配置"""
try:
result = await configurator.safe_apply(
ip="127.0.0.1", # 本地连接固定用这个地址
"127.0.0.1",
config={
"type": "vlan",
"vlan_id": 100,
"name": "测试VLAN"
"name": "自动化测试VLAN"
}
)
print("✅ 配置结果:", result)
except Exception as e:
print("❌ 配置失败:", str(e))
print("VLAN 配置结果:", result)
# 验证配置
vlan_list = await configurator._send_commands("127.0.0.1", ["display vlan"])
print("当前VLAN列表:\n", vlan_list)
return "success" in result.get("status", "")
except Exception as e:
print("❌ VLAN 配置失败:", str(e))
return False
async def main():
"""主测试流程"""
# 尝试不同端口
for port in [2000, 2010, 2020, 23]:
print(f"\n尝试端口: {port}")
configurator = SwitchConfigurator(
ensp_mode=True,
ensp_port=port,
username="",
password="admin",
timeout=15
)
if await test_connection(configurator):
print(f"✅ 成功连接到端口 {port}")
if await test_vlan_config(configurator):
print("✅ 所有测试通过!")
return
else:
print("⚠️ VLAN 配置失败,继续尝试其他端口...")
else:
print("⚠️ 连接失败,尝试下一个端口...")
print("❌ 所有端口尝试失败,请检查配置")
# 运行测试
if __name__ == "__main__":
asyncio.run(test_ensp())
asyncio.run(main())