mirror of
https://github.com/Jerryplusy/AI-powered-switches.git
synced 2025-10-14 09:49:19 +00:00
Compare commits
8 Commits
aab5adf863
...
90f127b5ce
Author | SHA1 | Date | |
---|---|---|---|
![]() |
90f127b5ce | ||
![]() |
ba1a7c216c | ||
![]() |
71eb1ee79a | ||
![]() |
7c17bb931b | ||
![]() |
2231b8cf82 | ||
![]() |
6e5cd34da7 | ||
![]() |
6f74a80036 | ||
![]() |
60359b54ee |
2
.idea/AI-powered-switches.iml
generated
2
.idea/AI-powered-switches.iml
generated
@ -9,7 +9,7 @@
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<excludeFolder url="file://$MODULE_DIR$/.venv" />
|
||||
</content>
|
||||
<orderEntry type="jdk" jdkName="Python 3.13" jdkType="Python SDK" />
|
||||
<orderEntry type="jdk" jdkName="Python 3.11 (AI-powered-switches)" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
<orderEntry type="library" name="Python 3.13 interpreter library" level="application" />
|
||||
</component>
|
||||
|
@ -1,13 +0,0 @@
|
||||
# 交换机认证配置
|
||||
SWITCH_USERNAME=admin
|
||||
SWITCH_PASSWORD=your_secure_password
|
||||
SWITCH_TIMEOUT=15
|
||||
|
||||
# 硅基流动API配置
|
||||
SILICONFLOW_API_KEY=sk-114514
|
||||
SILICONFLOW_API_URL=https://api.siliconflow.ai/v1
|
||||
|
||||
# FastAPI 配置
|
||||
UVICORN_HOST=0.0.0.0
|
||||
UVICORN_PORT=8000
|
||||
UVICORN_RELOAD=false
|
@ -1,25 +1,49 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, responses
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.middleware import Middleware # 新增导入
|
||||
from src.backend.app.api.endpoints import router
|
||||
from src.backend.app.utils.logger import setup_logging
|
||||
from src.backend.config import settings
|
||||
|
||||
api_app = FastAPI()
|
||||
api_app.include_router(router,prefix="/api")
|
||||
from .services.switch_traffic_monitor import get_switch_monitor
|
||||
from .services.traffic_monitor import traffic_monitor
|
||||
from src.backend.app.api.database import init_db
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
# 初始化数据库
|
||||
init_db()
|
||||
|
||||
# 启动流量监控
|
||||
traffic_monitor.start_monitoring()
|
||||
|
||||
# 设置日志
|
||||
setup_logging()
|
||||
|
||||
# 创建FastAPI应用
|
||||
# 创建FastAPI应用(使用新的中间件配置方式)
|
||||
app = FastAPI(
|
||||
title=settings.APP_NAME,
|
||||
debug=settings.DEBUG,
|
||||
docs_url=f"{settings.API_PREFIX}/docs",
|
||||
redoc_url=f"{settings.API_PREFIX}/redoc",
|
||||
openapi_url=f"{settings.API_PREFIX}/openapi.json"
|
||||
openapi_url=f"{settings.API_PREFIX}/openapi.json",
|
||||
middleware=[ # 这里直接配置中间件
|
||||
Middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# 添加根路径重定向
|
||||
@app.get("/", include_in_schema=False)
|
||||
async def root():
|
||||
return responses.RedirectResponse(url=f"{settings.API_PREFIX}/docs")
|
||||
|
||||
# 添加API路由
|
||||
app.include_router(router, prefix=settings.API_PREFIX)
|
||||
|
||||
return app
|
||||
return app
|
||||
|
||||
app = create_app()
|
@ -1,7 +1,8 @@
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Dict, Any
|
||||
from src.backend.app.services.ai_services import AIService
|
||||
from src.backend.config import settings
|
||||
|
||||
|
||||
class CommandParser:
|
||||
def __init__(self):
|
||||
self.ai_service = AIService(settings.SILICONFLOW_API_KEY, settings.SILICONFLOW_API_URL)
|
||||
@ -19,7 +20,7 @@ class CommandParser:
|
||||
return await self.ai_service.parse_command(command)
|
||||
|
||||
@staticmethod
|
||||
def _try_local_parse(command: str) -> Optional[Dict[str, Any]]:
|
||||
def _try_local_parse(command: str) -> dict[str, str | list[Any]] | dict[str, str] | None:
|
||||
"""
|
||||
尝试本地解析常见命令
|
||||
"""
|
||||
@ -28,13 +29,12 @@ class CommandParser:
|
||||
# VLAN配置
|
||||
if "vlan" in command and "创建" in command:
|
||||
parts = command.split()
|
||||
vlan_id_str = next((p for p in parts if p.isdigit()), None)
|
||||
if vlan_id_str:
|
||||
vlan_id = int(vlan_id_str) # 转换为整数
|
||||
vlan_id = next((p for p in parts if p.isdigit()), None)
|
||||
if vlan_id:
|
||||
return {
|
||||
"type": "vlan",
|
||||
"vlan_id": vlan_id, # 使用整数
|
||||
"name": f"VLAN{vlan_id}", # 使用转换后的整数
|
||||
"vlan_id": vlan_id,
|
||||
"name": f"VLAN{vlan_id}",
|
||||
"interfaces": []
|
||||
}
|
||||
|
||||
@ -61,9 +61,9 @@ class CommandParser:
|
||||
config["description"] = description
|
||||
|
||||
if "vlan" in command:
|
||||
vlan_id_str = next((p for p in parts if p.isdigit()), None)
|
||||
if vlan_id_str:
|
||||
config["vlan"] = int(vlan_id_str) # 转换为整数
|
||||
vlan_id = next((p for p in parts if p.isdigit()), None)
|
||||
if vlan_id:
|
||||
config["vlan"] = vlan_id
|
||||
|
||||
return config
|
||||
|
||||
|
17
src/backend/app/api/database.py
Normal file
17
src/backend/app/api/database.py
Normal file
@ -0,0 +1,17 @@
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
SQLALCHEMY_DATABASE_URL = "sqlite:///./traffic_monitor.db"
|
||||
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
|
||||
)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
def init_db():
|
||||
"""初始化数据库"""
|
||||
# 删除多余的导入
|
||||
Base.metadata.create_all(bind=engine)
|
@ -1,97 +1,73 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from typing import List, Dict
|
||||
from datetime import datetime, timedelta
|
||||
from fastapi import (APIRouter, HTTPException, Response, WebSocket, WebSocketDisconnect)
|
||||
from typing import List
|
||||
from pydantic import BaseModel
|
||||
import asyncio
|
||||
from fastapi.responses import HTMLResponse
|
||||
import matplotlib.pyplot as plt
|
||||
import io
|
||||
import base64
|
||||
|
||||
from ..services.switch_traffic_monitor import get_switch_monitor
|
||||
from ..utils import logger
|
||||
from ...app.services.ai_services import AIService
|
||||
from ...app.api.network_config import SwitchConfigurator
|
||||
from ...config import settings
|
||||
from ..services.network_scanner import NetworkScanner
|
||||
from ..api.network_config import SwitchConfigurator
|
||||
from ...app.services.traffic_monitor import traffic_monitor
|
||||
from ...app.models.traffic_models import TrafficRecord, SwitchTrafficRecord
|
||||
from src.backend.app.api.database import SessionLocal
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["API"])
|
||||
|
||||
|
||||
|
||||
router = APIRouter(prefix="", tags=["API"])
|
||||
scanner = NetworkScanner()
|
||||
|
||||
# 添加根路径路由
|
||||
@router.get("/", include_in_schema=False)
|
||||
async def root():
|
||||
return {
|
||||
"message": "欢迎使用AI交换机配置系统",
|
||||
"docs": f"{settings.API_PREFIX}/docs",
|
||||
"redoc": f"{settings.API_PREFIX}/redoc",
|
||||
"endpoints": [
|
||||
"/parse_command",
|
||||
"/apply_config",
|
||||
"/scan_network",
|
||||
"/list_devices",
|
||||
"/batch_apply_config"
|
||||
"/traffic/switch/current", # 添加交换机流量端点
|
||||
"/traffic/switch/history" # 添加交换机历史流量端点
|
||||
]
|
||||
}
|
||||
|
||||
# 添加favicon处理
|
||||
@router.get("/favicon.ico", include_in_schema=False)
|
||||
async def favicon():
|
||||
return Response(status_code=204)
|
||||
|
||||
# ====================
|
||||
# 请求模型
|
||||
# ====================
|
||||
class BatchConfigRequest(BaseModel):
|
||||
config: Dict
|
||||
switch_ips: List[str]
|
||||
config: dict
|
||||
switch_ips: List[str] # 支持多个IP
|
||||
|
||||
|
||||
class CommandRequest(BaseModel):
|
||||
command: str
|
||||
|
||||
|
||||
class ConfigRequest(BaseModel):
|
||||
config: Dict
|
||||
switch_ip: str
|
||||
|
||||
|
||||
# ====================
|
||||
# API端点
|
||||
# ====================
|
||||
@router.post("/batch_apply_config")
|
||||
async def batch_apply_config(request: BatchConfigRequest):
|
||||
"""
|
||||
批量配置交换机
|
||||
- 支持同时配置多台设备
|
||||
- 自动处理连接池
|
||||
- 返回每个设备的详细结果
|
||||
"""
|
||||
configurator = SwitchConfigurator(
|
||||
username=settings.SWITCH_USERNAME,
|
||||
password=settings.SWITCH_PASSWORD,
|
||||
timeout=settings.SWITCH_TIMEOUT
|
||||
)
|
||||
|
||||
results = {}
|
||||
try:
|
||||
for ip in request.switch_ips:
|
||||
try:
|
||||
# 使用公开的apply_config方法
|
||||
results[ip] = await configurator.apply_config(ip, request.config)
|
||||
except Exception as e:
|
||||
results[ip] = {
|
||||
"status": "failed",
|
||||
"error": str(e)
|
||||
}
|
||||
return {"results": results}
|
||||
finally:
|
||||
await configurator.close()
|
||||
for ip in request.switch_ips:
|
||||
try:
|
||||
configurator = SwitchConfigurator()
|
||||
results[ip] = await configurator.apply_config(ip, request.config)
|
||||
except Exception as e:
|
||||
results[ip] = str(e)
|
||||
return {"results": results}
|
||||
|
||||
|
||||
@router.post("/apply_config", response_model=Dict)
|
||||
async def apply_config(request: ConfigRequest):
|
||||
"""
|
||||
单设备配置
|
||||
- 更详细的错误处理
|
||||
- 自动备份和回滚
|
||||
"""
|
||||
configurator = SwitchConfigurator(
|
||||
username=settings.SWITCH_USERNAME,
|
||||
password=settings.SWITCH_PASSWORD,
|
||||
timeout=settings.SWITCH_TIMEOUT
|
||||
)
|
||||
|
||||
try:
|
||||
result = await configurator.apply_config(request.switch_ip, request.config)
|
||||
if result["status"] != "success":
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=result.get("error", "配置失败")
|
||||
)
|
||||
return result
|
||||
finally:
|
||||
await configurator.close()
|
||||
|
||||
|
||||
# ====================
|
||||
# 其他原有端点(保持不动)
|
||||
# ====================
|
||||
@router.get("/test")
|
||||
async def test_endpoint():
|
||||
return {"message": "Hello World"}
|
||||
|
||||
|
||||
@router.get("/scan_network", summary="扫描网络中的交换机")
|
||||
async def scan_network(subnet: str = "192.168.1.0/24"):
|
||||
try:
|
||||
@ -104,23 +80,25 @@ async def scan_network(subnet: str = "192.168.1.0/24"):
|
||||
except Exception as e:
|
||||
raise HTTPException(500, f"扫描失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/list_devices", summary="列出已发现的交换机")
|
||||
async def list_devices():
|
||||
return {
|
||||
"devices": scanner.load_cached_devices()
|
||||
}
|
||||
|
||||
class CommandRequest(BaseModel):
|
||||
command: str
|
||||
|
||||
@router.post("/parse_command", response_model=Dict)
|
||||
class ConfigRequest(BaseModel):
|
||||
config: dict
|
||||
switch_ip: str
|
||||
|
||||
@router.post("/parse_command", response_model=dict)
|
||||
async def parse_command(request: CommandRequest):
|
||||
"""
|
||||
解析中文命令并返回JSON配置
|
||||
- 依赖AI服务
|
||||
- 返回标准化配置
|
||||
"""
|
||||
try:
|
||||
from ..services.ai_services import AIService # 延迟导入避免循环依赖
|
||||
ai_service = AIService(settings.SILICONFLOW_API_KEY, settings.SILICONFLOW_API_URL)
|
||||
config = await ai_service.parse_command(request.command)
|
||||
return {"success": True, "config": config}
|
||||
@ -128,4 +106,285 @@ async def parse_command(request: CommandRequest):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to parse command: {str(e)}"
|
||||
)
|
||||
)
|
||||
|
||||
@router.post("/apply_config", response_model=dict)
|
||||
async def apply_config(request: ConfigRequest):
|
||||
"""
|
||||
应用配置到交换机
|
||||
"""
|
||||
try:
|
||||
configurator = SwitchConfigurator(
|
||||
username=settings.SWITCH_USERNAME,
|
||||
password=settings.SWITCH_PASSWORD,
|
||||
timeout=settings.SWITCH_TIMEOUT
|
||||
)
|
||||
result = await configurator.safe_apply(request.switch_ip, request.config)
|
||||
return {"success": True, "result": result}
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to apply config: {str(e)}"
|
||||
)
|
||||
@router.get("/traffic/interfaces", summary="获取所有网络接口")
|
||||
async def get_network_interfaces():
|
||||
return {
|
||||
"interfaces": traffic_monitor.get_interfaces()
|
||||
}
|
||||
|
||||
@router.get("/traffic/current", summary="获取当前流量数据")
|
||||
async def get_current_traffic(interface: str = None):
|
||||
return traffic_monitor.get_current_traffic(interface)
|
||||
|
||||
@router.get("/traffic/history", summary="获取流量历史数据")
|
||||
async def get_traffic_history(interface: str = None, limit: int = 100):
|
||||
history = traffic_monitor.get_traffic_history(interface)
|
||||
return {
|
||||
"sent": history["sent"][-limit:],
|
||||
"recv": history["recv"][-limit:],
|
||||
"time": [t.isoformat() for t in history["time"]][-limit:]
|
||||
}
|
||||
|
||||
@router.get("/traffic/records", summary="获取流量记录")
|
||||
async def get_traffic_records(interface: str = None, limit: int = 100):
|
||||
with SessionLocal() as session:
|
||||
query = session.query(TrafficRecord)
|
||||
if interface:
|
||||
query = query.filter(TrafficRecord.interface == interface)
|
||||
records = query.order_by(TrafficRecord.timestamp.desc()).limit(limit).all()
|
||||
return [record.to_dict() for record in records]
|
||||
|
||||
@router.websocket("/ws/traffic")
|
||||
async def websocket_traffic(websocket: WebSocket):
|
||||
"""实时流量WebSocket"""
|
||||
await websocket.accept()
|
||||
try:
|
||||
while True:
|
||||
# 获取所有接口的当前流量
|
||||
traffic_data = traffic_monitor.get_current_traffic()
|
||||
await websocket.send_json(traffic_data)
|
||||
await asyncio.sleep(1) # 每秒更新一次
|
||||
except WebSocketDisconnect:
|
||||
print("客户端断开连接")
|
||||
|
||||
|
||||
@router.get("/", include_in_schema=False)
|
||||
async def root():
|
||||
return {
|
||||
"message": "欢迎使用AI交换机配置系统",
|
||||
"docs": f"{settings.API_PREFIX}/docs",
|
||||
"redoc": f"{settings.API_PREFIX}/redoc",
|
||||
"endpoints": [
|
||||
"/parse_command",
|
||||
"/apply_config",
|
||||
"/scan_network",
|
||||
"/list_devices",
|
||||
"/batch_apply_config",
|
||||
"/traffic/switch/current", # 添加交换机流量端点
|
||||
"/traffic/switch/history" # 添加交换机历史流量端点
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
# ... 其他路由保持不变 ...
|
||||
|
||||
@router.get("/traffic/switch/interfaces", summary="获取交换机的网络接口")
|
||||
async def get_switch_interfaces(switch_ip: str):
|
||||
"""获取指定交换机的所有接口"""
|
||||
try:
|
||||
monitor = get_switch_monitor(switch_ip)
|
||||
interfaces = list(monitor.interface_oids.keys())
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"interfaces": interfaces
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取交换机接口失败: {str(e)}")
|
||||
raise HTTPException(500, f"获取接口失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/traffic/switch/current", summary="获取交换机的当前流量数据")
|
||||
async def get_switch_current_traffic(switch_ip: str, interface: str = None):
|
||||
"""获取交换机的当前流量数据"""
|
||||
try:
|
||||
monitor = get_switch_monitor(switch_ip)
|
||||
|
||||
# 如果没有指定接口,获取所有接口的流量
|
||||
if not interface:
|
||||
traffic_data = {}
|
||||
for iface in monitor.interface_oids:
|
||||
traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface)
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"traffic": traffic_data
|
||||
}
|
||||
|
||||
# 获取指定接口的流量
|
||||
return await get_interface_current_traffic(switch_ip, interface)
|
||||
except Exception as e:
|
||||
logger.error(f"获取交换机流量失败: {str(e)}")
|
||||
raise HTTPException(500, f"获取流量失败: {str(e)}")
|
||||
|
||||
|
||||
async def get_interface_current_traffic(switch_ip: str, interface: str) -> dict:
|
||||
"""获取指定交换机接口的当前流量数据"""
|
||||
try:
|
||||
with SessionLocal() as session:
|
||||
# 获取最新记录
|
||||
record = session.query(SwitchTrafficRecord).filter(
|
||||
SwitchTrafficRecord.switch_ip == switch_ip,
|
||||
SwitchTrafficRecord.interface == interface
|
||||
).order_by(SwitchTrafficRecord.timestamp.desc()).first()
|
||||
|
||||
if not record:
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"interface": interface,
|
||||
"rate_in": 0.0,
|
||||
"rate_out": 0.0,
|
||||
"bytes_in": 0,
|
||||
"bytes_out": 0
|
||||
}
|
||||
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"interface": interface,
|
||||
"rate_in": record.rate_in,
|
||||
"rate_out": record.rate_out,
|
||||
"bytes_in": record.bytes_in,
|
||||
"bytes_out": record.bytes_out
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取接口流量失败: {str(e)}")
|
||||
raise HTTPException(500, f"获取接口流量失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/traffic/switch/history", summary="获取交换机的流量历史数据")
|
||||
async def get_switch_traffic_history(switch_ip: str, interface: str = None, minutes: int = 10):
|
||||
"""获取交换机的流量历史数据"""
|
||||
try:
|
||||
monitor = get_switch_monitor(switch_ip)
|
||||
|
||||
# 如果没有指定接口,返回所有接口的历史数据
|
||||
if not interface:
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"history": monitor.get_traffic_history()
|
||||
}
|
||||
|
||||
# 获取指定接口的历史数据
|
||||
with SessionLocal() as session:
|
||||
# 计算时间范围
|
||||
time_threshold = datetime.now() - timedelta(minutes=minutes)
|
||||
|
||||
# 查询历史记录
|
||||
records = session.query(SwitchTrafficRecord).filter(
|
||||
SwitchTrafficRecord.switch_ip == switch_ip,
|
||||
SwitchTrafficRecord.interface == interface,
|
||||
SwitchTrafficRecord.timestamp >= time_threshold
|
||||
).order_by(SwitchTrafficRecord.timestamp.asc()).all()
|
||||
|
||||
# 提取数据
|
||||
history_data = {
|
||||
"in": [record.rate_in for record in records],
|
||||
"out": [record.rate_out for record in records],
|
||||
"time": [record.timestamp.isoformat() for record in records]
|
||||
}
|
||||
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"interface": interface,
|
||||
"history": history_data
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取历史流量失败: {str(e)}")
|
||||
raise HTTPException(500, f"获取历史流量失败: {str(e)}")
|
||||
|
||||
|
||||
@router.websocket("/ws/traffic/switch")
|
||||
async def websocket_switch_traffic(websocket: WebSocket, switch_ip: str, interface: str = None):
|
||||
"""交换机实时流量WebSocket"""
|
||||
await websocket.accept()
|
||||
try:
|
||||
monitor = get_switch_monitor(switch_ip)
|
||||
|
||||
while True:
|
||||
# 获取流量数据
|
||||
if interface:
|
||||
# 单个接口
|
||||
traffic_data = await get_interface_current_traffic(switch_ip, interface)
|
||||
await websocket.send_json(traffic_data)
|
||||
else:
|
||||
# 所有接口
|
||||
traffic_data = {}
|
||||
for iface in monitor.interface_oids:
|
||||
traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface)
|
||||
|
||||
await websocket.send_json({
|
||||
"switch_ip": switch_ip,
|
||||
"traffic": traffic_data
|
||||
})
|
||||
|
||||
# 每秒更新一次
|
||||
await asyncio.sleep(1)
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"客户端断开交换机流量连接: {switch_ip}")
|
||||
except Exception as e:
|
||||
logger.error(f"交换机流量WebSocket错误: {str(e)}")
|
||||
await websocket.close(code=1011, reason=str(e))
|
||||
|
||||
|
||||
# 交换机流量可视化
|
||||
@router.get("/traffic/switch/plot", response_class=HTMLResponse, summary="交换机流量可视化")
|
||||
async def plot_switch_traffic(switch_ip: str, interface: str, minutes: int = 10):
|
||||
"""生成交换机流量图表"""
|
||||
try:
|
||||
# 获取历史数据
|
||||
history = await get_switch_traffic_history(switch_ip, interface, minutes)
|
||||
history_data = history["history"]
|
||||
|
||||
# 提取数据点
|
||||
time_points = [datetime.fromisoformat(t) for t in history_data["time"]]
|
||||
in_rates = history_data["in"]
|
||||
out_rates = history_data["out"]
|
||||
|
||||
# 创建图表
|
||||
plt.figure(figsize=(12, 6))
|
||||
plt.plot(time_points, in_rates, label="流入流量 (B/s)")
|
||||
plt.plot(time_points, out_rates, label="流出流量 (B/s)")
|
||||
plt.title(f"交换机 {switch_ip} 接口 {interface} 流量监控 - 最近 {minutes} 分钟")
|
||||
plt.xlabel("时间")
|
||||
plt.ylabel("流量 (字节/秒)")
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
plt.xticks(rotation=45)
|
||||
plt.tight_layout()
|
||||
|
||||
# 转换为HTML图像
|
||||
buf = io.BytesIO()
|
||||
plt.savefig(buf, format="png")
|
||||
buf.seek(0)
|
||||
image_base64 = base64.b64encode(buf.read()).decode("utf-8")
|
||||
plt.close()
|
||||
|
||||
return f"""
|
||||
<html>
|
||||
<head>
|
||||
<title>交换机流量监控</title>
|
||||
<style>
|
||||
body {{ font-family: Arial, sans-serif; margin: 20px; }}
|
||||
.container {{ max-width: 900px; margin: 0 auto; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>交换机 {switch_ip} 接口 {interface} 流量监控</h1>
|
||||
<img src="data:image/png;base64,{image_base64}" alt="流量图表">
|
||||
<p>更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
except Exception as e:
|
||||
logger.error(f"生成流量图表失败: {str(e)}")
|
||||
return HTMLResponse(content=f"<h1>错误</h1><p>{str(e)}</p>", status_code=500)
|
@ -1,14 +1,14 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import telnetlib3
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
from pydantic import BaseModel
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
import aiofiles
|
||||
import asyncssh
|
||||
from pydantic import BaseModel
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
|
||||
# ----------------------
|
||||
@ -39,7 +39,7 @@ class SSHConnectionException(SwitchConfigException):
|
||||
|
||||
|
||||
# ----------------------
|
||||
# 核心配置器
|
||||
# 核心配置器(完整双模式)
|
||||
# ----------------------
|
||||
class SwitchConfigurator:
|
||||
def __init__(
|
||||
@ -63,35 +63,12 @@ class SwitchConfigurator:
|
||||
self.ensp_port = ensp_port
|
||||
self.ensp_delay = ensp_command_delay
|
||||
self.ssh_options = ssh_options
|
||||
self._connection_pool = {} # SSH连接池
|
||||
|
||||
# ====================
|
||||
# 公开API方法
|
||||
# ====================
|
||||
async def apply_config(self, ip: str, config: Union[Dict, SwitchConfig]) -> Dict:
|
||||
"""
|
||||
应用配置到交换机(主入口)
|
||||
返回格式:
|
||||
{
|
||||
"status": "success"|"failed",
|
||||
"output": str,
|
||||
"backup_path": str,
|
||||
"error": Optional[str],
|
||||
"timestamp": str
|
||||
}
|
||||
"""
|
||||
async def apply_config(self, ip: str, config: Union[Dict, SwitchConfig]) -> str:
|
||||
"""实际配置逻辑"""
|
||||
if isinstance(config, dict):
|
||||
config = SwitchConfig(**config)
|
||||
|
||||
result = await self.safe_apply(ip, config)
|
||||
result["timestamp"] = datetime.now().isoformat()
|
||||
return result
|
||||
|
||||
# ====================
|
||||
# 内部实现方法
|
||||
# ====================
|
||||
async def _apply_config(self, ip: str, config: SwitchConfig) -> str:
|
||||
"""实际配置逻辑"""
|
||||
commands = (
|
||||
self._generate_ensp_commands(config)
|
||||
if self.ensp_mode
|
||||
@ -107,74 +84,60 @@ class SwitchConfigurator:
|
||||
else await self._send_ssh_commands(ip, commands)
|
||||
)
|
||||
|
||||
# --------- eNSP模式专用 ---------
|
||||
async def _send_ensp_commands(self, ip: str, commands: List[str]) -> str:
|
||||
"""Telnet协议执行(eNSP)"""
|
||||
try:
|
||||
# 修复点:使用正确的timeout参数
|
||||
reader, writer = await telnetlib3.open_connection(
|
||||
host=ip,
|
||||
port=self.ensp_port,
|
||||
connect_minwait=self.timeout,
|
||||
connect_minwait=self.timeout, # telnetlib3的实际可用参数
|
||||
connect_maxwait=self.timeout
|
||||
)
|
||||
|
||||
# 登录流程
|
||||
await reader.readuntil(b"Username:")
|
||||
writer.write(f"{self.username}\n")
|
||||
await reader.readuntil(b"Password:")
|
||||
writer.write(f"{self.password}\n")
|
||||
await asyncio.sleep(1)
|
||||
# 登录流程(增加超时处理)
|
||||
try:
|
||||
await asyncio.wait_for(reader.readuntil(b"Username:"), timeout=self.timeout)
|
||||
writer.write(f"{self.username}\n")
|
||||
|
||||
await asyncio.wait_for(reader.readuntil(b"Password:"), timeout=self.timeout)
|
||||
writer.write(f"{self.password}\n")
|
||||
|
||||
# 等待登录完成
|
||||
await asyncio.sleep(1)
|
||||
except asyncio.TimeoutError:
|
||||
raise EnspConnectionException("登录超时")
|
||||
|
||||
# 执行命令
|
||||
output = ""
|
||||
for cmd in commands:
|
||||
writer.write(f"{cmd}\n")
|
||||
await asyncio.sleep(self.ensp_delay)
|
||||
while True:
|
||||
try:
|
||||
await writer.drain() # 确保命令发送完成
|
||||
|
||||
# 读取响应(增加超时处理)
|
||||
try:
|
||||
while True:
|
||||
data = await asyncio.wait_for(reader.read(1024), timeout=1)
|
||||
if not data:
|
||||
break
|
||||
output += data
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
continue # 单次读取超时不视为错误
|
||||
|
||||
# 关闭连接
|
||||
writer.close()
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
await writer.wait_closed()
|
||||
except:
|
||||
logging.debug("连接关闭时出现异常", exc_info=True) # 至少记录异常信息
|
||||
pass
|
||||
|
||||
return output
|
||||
except Exception as e:
|
||||
raise EnspConnectionException(f"eNSP连接失败: {str(e)}")
|
||||
|
||||
async def _clean_idle_connections(self):
|
||||
"""连接池清理机制"""
|
||||
now = time.time()
|
||||
for ip, (conn, last_used) in list(self._connection_pool.items()):
|
||||
if now - last_used > 300: # 5分钟空闲超时
|
||||
conn.close()
|
||||
del self._connection_pool[ip]
|
||||
|
||||
async def _send_ssh_commands(self, ip: str, commands: List[str]) -> str:
|
||||
"""SSH协议执行"""
|
||||
async with self.semaphore:
|
||||
try:
|
||||
if ip not in self._connection_pool:
|
||||
self._connection_pool[ip] = await asyncssh.connect(
|
||||
host=ip,
|
||||
username=self.username,
|
||||
password=self.password,
|
||||
connect_timeout=self.timeout,
|
||||
**self.ssh_options
|
||||
)
|
||||
|
||||
results = []
|
||||
for cmd in commands:
|
||||
result = await self._connection_pool[ip].run(cmd)
|
||||
results.append(result.stdout)
|
||||
return "\n".join(results)
|
||||
except asyncssh.Error as e:
|
||||
if ip in self._connection_pool:
|
||||
self._connection_pool[ip].close()
|
||||
del self._connection_pool[ip]
|
||||
raise SSHConnectionException(f"SSH操作失败: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def _generate_ensp_commands(config: SwitchConfig) -> List[str]:
|
||||
"""生成eNSP命令序列"""
|
||||
@ -194,6 +157,28 @@ class SwitchConfigurator:
|
||||
commands.append("return")
|
||||
return [c for c in commands if c.strip()]
|
||||
|
||||
# --------- SSH模式专用(使用AsyncSSH) ---------
|
||||
async def _send_ssh_commands(self, ip: str, commands: List[str]) -> str:
|
||||
"""AsyncSSH执行命令"""
|
||||
async with self.semaphore:
|
||||
try:
|
||||
async with asyncssh.connect(
|
||||
host=ip,
|
||||
username=self.username,
|
||||
password=self.password,
|
||||
connect_timeout=self.timeout, # AsyncSSH的正确参数名
|
||||
**self.ssh_options
|
||||
) as conn:
|
||||
results = []
|
||||
for cmd in commands:
|
||||
result = await conn.run(cmd, check=True)
|
||||
results.append(result.stdout)
|
||||
return "\n".join(results)
|
||||
except asyncssh.Error as e:
|
||||
raise SSHConnectionException(f"SSH操作失败: {str(e)}")
|
||||
except Exception as e:
|
||||
raise SSHConnectionException(f"连接异常: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def _generate_standard_commands(config: SwitchConfig) -> List[str]:
|
||||
"""生成标准CLI命令"""
|
||||
@ -211,6 +196,16 @@ class SwitchConfigurator:
|
||||
])
|
||||
return commands
|
||||
|
||||
# --------- 通用功能 ---------
|
||||
async def _validate_config(self, ip: str, config: SwitchConfig) -> bool:
|
||||
"""验证配置是否生效"""
|
||||
current = await self._get_current_config(ip)
|
||||
if config.type == "vlan":
|
||||
return f"vlan {config.vlan_id}" in current
|
||||
elif config.type == "interface" and config.vlan:
|
||||
return f"switchport access vlan {config.vlan}" in current
|
||||
return True
|
||||
|
||||
async def _get_current_config(self, ip: str) -> str:
|
||||
"""获取当前配置"""
|
||||
commands = (
|
||||
@ -259,7 +254,7 @@ class SwitchConfigurator:
|
||||
"""安全配置应用(自动回滚)"""
|
||||
backup_path = await self._backup_config(ip)
|
||||
try:
|
||||
result = await self._apply_config(ip, config)
|
||||
result = await self.apply_config(ip, config)
|
||||
if not await self._validate_config(ip, config):
|
||||
raise SwitchConfigException("配置验证失败")
|
||||
return {
|
||||
@ -276,17 +271,40 @@ class SwitchConfigurator:
|
||||
"restore_success": restore_status
|
||||
}
|
||||
|
||||
async def _validate_config(self, ip: str, config: SwitchConfig) -> bool:
|
||||
"""验证配置是否生效"""
|
||||
current = await self._get_current_config(ip)
|
||||
if config.type == "vlan":
|
||||
return f"vlan {config.vlan_id}" in current
|
||||
elif config.type == "interface" and config.vlan:
|
||||
return f"switchport access vlan {config.vlan}" in current
|
||||
return True
|
||||
|
||||
async def close(self):
|
||||
"""清理所有连接"""
|
||||
for conn in self._connection_pool.values():
|
||||
conn.close()
|
||||
self._connection_pool.clear()
|
||||
# ----------------------
|
||||
# 使用示例
|
||||
# ----------------------
|
||||
async def demo():
|
||||
# 示例1: eNSP设备配置(Telnet模式)
|
||||
ensp_configurator = SwitchConfigurator(
|
||||
ensp_mode=True,
|
||||
ensp_port=2000,
|
||||
username="admin",
|
||||
password="admin",
|
||||
timeout=15
|
||||
)
|
||||
ensp_result = await ensp_configurator.safe_apply("127.0.0.1", {
|
||||
"type": "interface",
|
||||
"interface": "GigabitEthernet0/0/1",
|
||||
"vlan": 100,
|
||||
"ip_address": "192.168.1.2 255.255.255.0"
|
||||
})
|
||||
print("eNSP配置结果:", ensp_result)
|
||||
|
||||
# 示例2: 真实设备配置(SSH模式)
|
||||
ssh_configurator = SwitchConfigurator(
|
||||
username="cisco",
|
||||
password="cisco123",
|
||||
timeout=15
|
||||
)
|
||||
ssh_result = await ssh_configurator.safe_apply("192.168.1.1", {
|
||||
"type": "vlan",
|
||||
"vlan_id": 200,
|
||||
"name": "Production"
|
||||
})
|
||||
print("SSH配置结果:", ssh_result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(demo())
|
52
src/backend/app/models/traffic_models.py
Normal file
52
src/backend/app/models/traffic_models.py
Normal file
@ -0,0 +1,52 @@
|
||||
# 添加正确的导入
|
||||
from src.backend.app.api.database import Base # 修复:导入 Base
|
||||
from sqlalchemy import Column, Integer, String, DateTime, BigInteger, Float
|
||||
|
||||
|
||||
class TrafficRecord(Base):
|
||||
"""网络流量记录模型"""
|
||||
__tablename__ = "traffic_records"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
interface = Column(String(50), index=True)
|
||||
bytes_sent = Column(Integer)
|
||||
bytes_recv = Column(Integer)
|
||||
packets_sent = Column(Integer)
|
||||
packets_recv = Column(Integer)
|
||||
timestamp = Column(DateTime)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"interface": self.interface,
|
||||
"bytes_sent": self.bytes_sent,
|
||||
"bytes_recv": self.bytes_recv,
|
||||
"packets_sent": self.packets_sent,
|
||||
"packets_recv": self.packets_recv,
|
||||
"timestamp": self.timestamp.isoformat()
|
||||
}
|
||||
|
||||
|
||||
class SwitchTrafficRecord(Base):
|
||||
__tablename__ = "switch_traffic_records"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
switch_ip = Column(String(50), index=True)
|
||||
interface = Column(String(50))
|
||||
bytes_in = Column(BigInteger) # 累计流入字节数
|
||||
bytes_out = Column(BigInteger) # 累计流出字节数
|
||||
rate_in = Column(Float) # 当前流入速率(字节/秒)
|
||||
rate_out = Column(Float) # 当前流出速率(字节/秒)
|
||||
timestamp = Column(DateTime)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"switch_ip": self.switch_ip,
|
||||
"interface": self.interface,
|
||||
"bytes_in": self.bytes_in,
|
||||
"bytes_out": self.bytes_out,
|
||||
"rate_in": self.rate_in,
|
||||
"rate_out": self.rate_out,
|
||||
"timestamp": self.timestamp.isoformat()
|
||||
}
|
@ -26,7 +26,7 @@ class AIService:
|
||||
"""
|
||||
|
||||
data = {
|
||||
"model": "text-davinci-003",
|
||||
"model": "deepseek-ai/DeepSeek-V3",
|
||||
"prompt": prompt,
|
||||
"max_tokens": 1000,
|
||||
"temperature": 0.3
|
||||
@ -35,7 +35,7 @@ class AIService:
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.api_url}/completions",
|
||||
f"{self.api_url}/chat/completions",
|
||||
headers=self.headers,
|
||||
json=data,
|
||||
timeout=30
|
||||
|
@ -1,3 +1,6 @@
|
||||
|
||||
import os
|
||||
os.environ["PATH"] += ";C:\\Program Files (x86)\\Nmap"
|
||||
import nmap
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
196
src/backend/app/services/switch_traffic_monitor.py
Normal file
196
src/backend/app/services/switch_traffic_monitor.py
Normal file
@ -0,0 +1,196 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from collections import deque
|
||||
from typing import Optional, List, Dict
|
||||
from pysnmp.hlapi import *
|
||||
from ..models.traffic_models import SwitchTrafficRecord
|
||||
from src.backend.app.api.database import SessionLocal
|
||||
from ..utils.logger import logger
|
||||
|
||||
|
||||
class SwitchTrafficMonitor:
|
||||
def __init__(
|
||||
self,
|
||||
switch_ip: str,
|
||||
community: str = 'public',
|
||||
update_interval: int = 5,
|
||||
interfaces: Optional[List[str]] = None
|
||||
):
|
||||
self.switch_ip = switch_ip
|
||||
self.community = community
|
||||
self.update_interval = update_interval
|
||||
self.running = False
|
||||
self.task = None
|
||||
self.interface_history = {}
|
||||
self.history = {
|
||||
"in": deque(maxlen=300),
|
||||
"out": deque(maxlen=300),
|
||||
"time": deque(maxlen=300)
|
||||
}
|
||||
|
||||
# 基本接口OID映射
|
||||
self.interface_oids = {
|
||||
"GigabitEthernet0/0/1": {
|
||||
"in": '1.3.6.1.2.1.2.2.1.10.1', # ifInOctets
|
||||
"out": '1.3.6.1.2.1.2.2.1.16.1' # ifOutOctets
|
||||
},
|
||||
"GigabitEthernet0/0/24": {
|
||||
"in": '1.3.6.1.2.1.2.2.1.10.24',
|
||||
"out": '1.3.6.1.2.1.2.2.1.16.24'
|
||||
}
|
||||
}
|
||||
|
||||
# 接口过滤
|
||||
if interfaces:
|
||||
self.interface_oids = {
|
||||
iface: oid for iface, oid in self.interface_oids.items()
|
||||
if iface in interfaces
|
||||
}
|
||||
logger.info(f"监控指定接口: {', '.join(interfaces)}")
|
||||
else:
|
||||
logger.info("监控所有接口")
|
||||
|
||||
def start_monitoring(self):
|
||||
"""启动交换机流量监控"""
|
||||
if not self.running:
|
||||
self.running = True
|
||||
self.task = asyncio.create_task(self._monitor_loop())
|
||||
logger.success(f"交换机流量监控已启动: {self.switch_ip}")
|
||||
|
||||
async def stop_monitoring(self):
|
||||
"""停止监控"""
|
||||
if self.running:
|
||||
self.running = False
|
||||
if self.task:
|
||||
self.task.cancel()
|
||||
try:
|
||||
await self.task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info(f"交换机流量监控已停止: {self.switch_ip}")
|
||||
|
||||
async def _monitor_loop(self):
|
||||
"""监控主循环"""
|
||||
last_values = {iface: {"in": 0, "out": 0} for iface in self.interface_oids}
|
||||
last_time = datetime.now()
|
||||
|
||||
while self.running:
|
||||
await asyncio.sleep(self.update_interval)
|
||||
|
||||
try:
|
||||
current_time = datetime.now()
|
||||
elapsed = (current_time - last_time).total_seconds()
|
||||
|
||||
# 获取所有接口流量
|
||||
for iface, oids in self.interface_oids.items():
|
||||
in_octets = self._snmp_get(oids["in"])
|
||||
out_octets = self._snmp_get(oids["out"])
|
||||
|
||||
if in_octets is not None and out_octets is not None:
|
||||
# 计算速率(字节/秒)
|
||||
# 修复字典访问问题
|
||||
iface_values = last_values[iface]
|
||||
in_rate = (in_octets - iface_values["in"]) / elapsed if iface_values["in"] > 0 else 0
|
||||
out_rate = (out_octets - iface_values["out"]) / elapsed if iface_values["out"] > 0 else 0
|
||||
|
||||
# 保存历史数据
|
||||
self.history["in"].append(in_rate)
|
||||
self.history["out"].append(out_rate)
|
||||
self.history["time"].append(current_time)
|
||||
|
||||
# 保存到数据库
|
||||
self._save_to_db(iface, in_octets, out_octets, in_rate, out_rate, current_time)
|
||||
|
||||
# 更新最后的值
|
||||
iface_values["in"] = in_octets
|
||||
iface_values["out"] = out_octets
|
||||
|
||||
last_time = current_time
|
||||
except Exception as e:
|
||||
logger.error(f"监控交换机流量出错: {str(e)}")
|
||||
|
||||
def _snmp_get(self, oid) -> Optional[int]:
|
||||
"""执行SNMP GET请求"""
|
||||
try:
|
||||
# 正确格式化的SNMP请求
|
||||
cmd = getCmd(
|
||||
SnmpEngine(),
|
||||
CommunityData(self.community),
|
||||
UdpTransportTarget((self.switch_ip, 161)),
|
||||
ContextData(),
|
||||
ObjectType(ObjectIdentity(oid)))
|
||||
|
||||
# 执行命令
|
||||
errorIndication, errorStatus, errorIndex, varBinds = next(cmd)
|
||||
except Exception as e:
|
||||
logger.error(f"SNMP请求失败: {str(e)}")
|
||||
return None
|
||||
|
||||
if errorIndication:
|
||||
logger.error(f"SNMP错误: {errorIndication}")
|
||||
return None
|
||||
elif errorStatus:
|
||||
try:
|
||||
# 修复括号问题
|
||||
if errorIndex:
|
||||
index_val = int(errorIndex) - 1
|
||||
error_item = varBinds[index_val] if index_val < len(varBinds) else '?'
|
||||
else:
|
||||
error_item = '?'
|
||||
|
||||
error_msg = f"SNMP错误: {errorStatus.prettyPrint()} at {error_item}"
|
||||
logger.error(error_msg)
|
||||
except Exception as e:
|
||||
logger.error(f"解析SNMP错误失败: {str(e)}")
|
||||
return None
|
||||
else:
|
||||
for varBind in varBinds:
|
||||
try:
|
||||
return int(varBind[1])
|
||||
except Exception as e:
|
||||
logger.error(f"转换SNMP值失败: {str(e)}")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def _save_to_db(self, interface: str, in_octets: int, out_octets: int,
|
||||
in_rate: float, out_rate: float, timestamp: datetime):
|
||||
"""保存流量数据到数据库"""
|
||||
try:
|
||||
with SessionLocal() as session:
|
||||
record = SwitchTrafficRecord(
|
||||
switch_ip=self.switch_ip,
|
||||
interface=interface,
|
||||
bytes_in=in_octets,
|
||||
bytes_out=out_octets,
|
||||
rate_in=in_rate,
|
||||
rate_out=out_rate,
|
||||
timestamp=timestamp
|
||||
)
|
||||
session.add(record)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"保存流量数据到数据库失败: {str(e)}")
|
||||
|
||||
def get_traffic_history(self) -> Dict[str, List]:
|
||||
"""获取流量历史数据"""
|
||||
return {
|
||||
"in": list(self.history["in"]),
|
||||
"out": list(self.history["out"]),
|
||||
"time": list(self.history["time"])
|
||||
}
|
||||
|
||||
|
||||
# 全局监控器字典(支持多个交换机)
|
||||
switch_monitors = {}
|
||||
|
||||
|
||||
def get_switch_monitor(switch_ip: str, community: str = 'public', interfaces: Optional[List[str]] = None):
|
||||
"""获取或创建交换机监控器(添加接口过滤参数)"""
|
||||
if switch_ip not in switch_monitors:
|
||||
switch_monitors[switch_ip] = SwitchTrafficMonitor(
|
||||
switch_ip,
|
||||
community,
|
||||
interfaces=interfaces
|
||||
)
|
||||
return switch_monitors[switch_ip]
|
147
src/backend/app/services/traffic_monitor.py
Normal file
147
src/backend/app/services/traffic_monitor.py
Normal file
@ -0,0 +1,147 @@
|
||||
import psutil
|
||||
import time
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from collections import deque
|
||||
from typing import Dict, Optional, List
|
||||
|
||||
|
||||
from ..models.traffic_models import TrafficRecord
|
||||
from src.backend.app.api.database import SessionLocal # 修复:导入 SessionLocal
|
||||
|
||||
|
||||
class TrafficMonitor:
|
||||
def __init__(self, history_size: int = 300):
|
||||
self.history_size = history_size # 保存历史大小
|
||||
self.history = {
|
||||
"sent": deque(maxlen=history_size),
|
||||
"recv": deque(maxlen=history_size),
|
||||
"time": deque(maxlen=history_size),
|
||||
"interfaces": {}
|
||||
}
|
||||
self.running = False
|
||||
self.task = None
|
||||
self.update_interval = 1.0 # 秒
|
||||
|
||||
@staticmethod
|
||||
def get_interfaces() -> List[str]:
|
||||
"""获取所有网络接口名称"""
|
||||
return list(psutil.net_io_counters(pernic=True).keys())
|
||||
|
||||
def start_monitoring(self):
|
||||
"""启动流量监控"""
|
||||
if not self.running:
|
||||
self.running = True
|
||||
self.task = asyncio.create_task(self._monitor_loop())
|
||||
print("流量监控已启动")
|
||||
|
||||
async def stop_monitoring(self):
|
||||
"""停止流量监控"""
|
||||
if self.running:
|
||||
self.running = False
|
||||
self.task.cancel()
|
||||
try:
|
||||
await self.task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
print("流量监控已停止")
|
||||
|
||||
async def _monitor_loop(self):
|
||||
"""监控主循环"""
|
||||
last_stats = psutil.net_io_counters(pernic=True)
|
||||
last_time = time.time()
|
||||
|
||||
while self.running:
|
||||
await asyncio.sleep(self.update_interval)
|
||||
|
||||
current_time = time.time()
|
||||
current_stats = psutil.net_io_counters(pernic=True)
|
||||
elapsed = current_time - last_time
|
||||
|
||||
# 计算每个接口的流量速率
|
||||
for iface in current_stats:
|
||||
if iface not in self.history["interfaces"]:
|
||||
# 修复:使用 self.history_size
|
||||
self.history["interfaces"][iface] = {
|
||||
"sent": deque(maxlen=self.history_size),
|
||||
"recv": deque(maxlen=self.history_size)
|
||||
}
|
||||
|
||||
if iface in last_stats:
|
||||
sent_rate = (current_stats[iface].bytes_sent - last_stats[iface].bytes_sent) / elapsed
|
||||
recv_rate = (current_stats[iface].bytes_recv - last_stats[iface].bytes_recv) / elapsed
|
||||
|
||||
# 保存到历史数据
|
||||
self.history["sent"].append(sent_rate)
|
||||
self.history["recv"].append(recv_rate)
|
||||
self.history["time"].append(datetime.now())
|
||||
|
||||
# 保存到接口特定历史
|
||||
self.history["interfaces"][iface]["sent"].append(sent_rate)
|
||||
self.history["interfaces"][iface]["recv"].append(recv_rate)
|
||||
|
||||
# 保存到数据库
|
||||
self._save_to_db(current_stats)
|
||||
|
||||
last_stats = current_stats
|
||||
last_time = current_time
|
||||
|
||||
@staticmethod
|
||||
def _save_to_db(stats):
|
||||
"""保存流量数据到数据库"""
|
||||
with SessionLocal() as session:
|
||||
for iface, counters in stats.items():
|
||||
record = TrafficRecord(
|
||||
interface=iface,
|
||||
bytes_sent=counters.bytes_sent,
|
||||
bytes_recv=counters.bytes_recv,
|
||||
packets_sent=counters.packets_sent,
|
||||
packets_recv=counters.packets_recv,
|
||||
timestamp=datetime.now()
|
||||
)
|
||||
session.add(record)
|
||||
session.commit()
|
||||
|
||||
def get_current_traffic(self, interface: Optional[str] = None) -> Dict:
|
||||
"""获取当前流量数据"""
|
||||
stats = psutil.net_io_counters(pernic=True)
|
||||
|
||||
if interface:
|
||||
if interface in stats:
|
||||
return self._format_interface_stats(stats[interface])
|
||||
return {}
|
||||
|
||||
return {iface: self._format_interface_stats(data) for iface, data in stats.items()}
|
||||
|
||||
@staticmethod
|
||||
def _format_interface_stats(counters) -> Dict:
|
||||
"""格式化接口统计数据"""
|
||||
return {
|
||||
"bytes_sent": counters.bytes_sent,
|
||||
"bytes_recv": counters.bytes_recv,
|
||||
"packets_sent": counters.packets_sent,
|
||||
"packets_recv": counters.packets_recv,
|
||||
"errin": counters.errin,
|
||||
"errout": counters.errout,
|
||||
"dropin": counters.dropin,
|
||||
"dropout": counters.dropout
|
||||
}
|
||||
|
||||
def get_traffic_history(self, interface: Optional[str] = None) -> Dict:
|
||||
"""获取流量历史数据"""
|
||||
if interface and interface in self.history["interfaces"]:
|
||||
return {
|
||||
"sent": list(self.history["interfaces"][interface]["sent"]),
|
||||
"recv": list(self.history["interfaces"][interface]["recv"]),
|
||||
"time": list(self.history["time"])
|
||||
}
|
||||
|
||||
return {
|
||||
"sent": list(self.history["sent"]),
|
||||
"recv": list(self.history["recv"]),
|
||||
"time": list(self.history["time"])
|
||||
}
|
||||
|
||||
|
||||
# 全局流量监控实例
|
||||
traffic_monitor = TrafficMonitor()
|
@ -1,15 +1,5 @@
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
class SiliconFlowAPIException(Exception):
|
||||
"""硅基流动API异常"""
|
||||
def __init__(self, detail: str, status_code: int = 500):
|
||||
self.detail = detail
|
||||
self.status_code = status_code
|
||||
super().__init__(detail)
|
||||
|
||||
def __str__(self):
|
||||
return f"SiliconFlowAPI Error [{self.status_code}]: {self.detail}"
|
||||
|
||||
class AICommandParseException(HTTPException):
|
||||
def __init__(self, detail: str):
|
||||
super().__init__(
|
||||
@ -34,7 +24,8 @@ class ConfigBackupException(SwitchConfigException):
|
||||
super().__init__(
|
||||
detail=f"无法备份设备 {ip} 的配置"
|
||||
)
|
||||
self.recovery_guide = "检查设备存储空间或权限" # 在子类中存储恢复指南
|
||||
# 将恢复指南作为实例属性
|
||||
self.recovery_guide = "检查设备存储空间或权限"
|
||||
|
||||
class ConfigRollbackException(SwitchConfigException):
|
||||
"""回滚失败异常"""
|
||||
@ -43,4 +34,10 @@ class ConfigRollbackException(SwitchConfigException):
|
||||
detail=f"设备 {ip} 回滚失败(原始错误:{original_error})",
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
)
|
||||
self.recovery_guide = "尝试手动恢复配置或重启设备"
|
||||
|
||||
class SiliconFlowAPIException(HTTPException):
|
||||
def __init__(self, detail: str, status_code: int = 400):
|
||||
super().__init__(
|
||||
status_code=status_code,
|
||||
detail=f"SiliconFlow API error: {detail}"
|
||||
)
|
@ -1,33 +1,89 @@
|
||||
import logging
|
||||
from loguru import logger
|
||||
import sys
|
||||
from loguru import logger as loguru_logger
|
||||
|
||||
|
||||
class InterceptHandler(logging.Handler):
|
||||
def emit(self, record):
|
||||
# Get corresponding Loguru level if it exists
|
||||
# 获取对应的Loguru日志级别
|
||||
try:
|
||||
level = logger.level(record.levelname).name
|
||||
level = loguru_logger.level(record.levelname).name
|
||||
except ValueError:
|
||||
level = record.levelno
|
||||
|
||||
# Find caller from where originated the logged message
|
||||
# 查找日志来源
|
||||
frame, depth = logging.currentframe(), 2
|
||||
while frame.f_code.co_filename == logging.__file__:
|
||||
while frame and frame.f_code.co_filename == logging.__file__:
|
||||
frame = frame.f_back
|
||||
depth += 1
|
||||
|
||||
logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
|
||||
# 使用Loguru记录日志
|
||||
loguru_logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
|
||||
|
||||
|
||||
def setup_logging():
|
||||
# 拦截标准logging
|
||||
logging.basicConfig(handlers=[InterceptHandler()], level=0)
|
||||
logging.basicConfig(handlers=[InterceptHandler()], level=logging.NOTSET)
|
||||
|
||||
# 配置loguru
|
||||
logger.configure(
|
||||
handlers=[
|
||||
{"sink": sys.stdout,
|
||||
"format": "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"}
|
||||
]
|
||||
)
|
||||
# 移除所有现有处理器
|
||||
loguru_logger.remove()
|
||||
|
||||
# 添加新的处理器
|
||||
loguru_logger.add(
|
||||
sys.stdout,
|
||||
format=(
|
||||
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
|
||||
"<level>{level: <8}</level> | "
|
||||
"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
|
||||
"<level>{message}</level>"
|
||||
),
|
||||
level="DEBUG",
|
||||
enqueue=True
|
||||
)
|
||||
|
||||
# 添加文件日志
|
||||
loguru_logger.add(
|
||||
"app.log",
|
||||
rotation="10 MB",
|
||||
retention="30 days",
|
||||
level="INFO",
|
||||
format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level} | {message}"
|
||||
)
|
||||
|
||||
|
||||
# 创建通用logger接口
|
||||
class Logger:
|
||||
@staticmethod
|
||||
def debug(msg, *args, **kwargs):
|
||||
loguru_logger.debug(msg, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def info(msg, *args, **kwargs):
|
||||
loguru_logger.info(msg, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def warning(msg, *args, **kwargs):
|
||||
loguru_logger.warning(msg, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def error(msg, *args, **kwargs):
|
||||
loguru_logger.error(msg, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def critical(msg, *args, **kwargs):
|
||||
loguru_logger.critical(msg, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def exception(msg, *args, **kwargs):
|
||||
loguru_logger.exception(msg, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def success(msg, *args, **kwargs):
|
||||
loguru_logger.success(msg, *args, **kwargs)
|
||||
|
||||
|
||||
# 创建全局logger实例
|
||||
logger = Logger()
|
||||
|
||||
# 初始化日志系统
|
||||
setup_logging()
|
@ -1,25 +1,30 @@
|
||||
from pydantic import BaseModel
|
||||
from pydantic_settings import BaseSettings
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
|
||||
load_dotenv()
|
||||
|
||||
import os
|
||||
|
||||
class Settings(BaseSettings):
|
||||
APP_NAME: str = "AI Network Configurator"
|
||||
DEBUG: bool = True
|
||||
API_PREFIX: str = "/api"
|
||||
|
||||
# 硅基流动API配置
|
||||
SILICONFLOW_API_KEY: str = os.getenv("SILICON_API_KEY", "")
|
||||
SILICONFLOW_API_URL: str = os.getenv("SILICONFLOW_API_URL", "https://api.siliconflow.ai/v1")
|
||||
SILICONFLOW_API_KEY: str = os.getenv("SILICONFLOW_API_KEY", "sk-mhzuedasunrgdrxfkcxmxgaypgjnxgodvvmrzzdbqrwtkqej")
|
||||
SILICONFLOW_API_URL: str = os.getenv("SILICONFLOW_API_URL", "https://api.siliconflow.cn/v1")
|
||||
|
||||
# 交换机配置
|
||||
SWITCH_USERNAME: str = os.getenv("SWITCH_USERNAME", "admin")
|
||||
SWITCH_PASSWORD: str = os.getenv("SWITCH_PASSWORD", "admin")
|
||||
SWITCH_TIMEOUT: int = os.getenv("SWITCH_TIMEOUT", 10)
|
||||
|
||||
# eNSP配置
|
||||
ENSP_DEFAULT_IP: str = "172.17.99.201"
|
||||
ENSP_DEFAULT_PORT: int = 2000
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
extra = "ignore"
|
||||
|
||||
|
||||
settings = Settings()
|
@ -1,19 +1,34 @@
|
||||
fastapi>=0.95.2
|
||||
uvicorn>=0.22.0
|
||||
python-dotenv>=1.0.0
|
||||
requests>=2.28.2
|
||||
paramiko>=3.3.0
|
||||
pydantic>=1.10.7
|
||||
loguru>=0.7.0
|
||||
python-nmap>=0.7.1
|
||||
tenacity>=9.1.2
|
||||
typing-extensions>=4.0.0
|
||||
aiofiles>=24.1.0
|
||||
telnetlib3>=2.0.4
|
||||
asyncssh>=2.14.0
|
||||
aiofiles>=24.1.0
|
||||
networkx==3.1
|
||||
scipy==1.11.1
|
||||
stable-baselines3==2.0.0
|
||||
plotly==5.15.0
|
||||
pandas==2.0.3
|
||||
# 核心依赖 -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
fastapi==0.110.0
|
||||
uvicorn==0.29.0
|
||||
python-dotenv==1.0.1
|
||||
|
||||
# Pydantic 模型 -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
pydantic==2.6.4
|
||||
pydantic-settings==2.2.1
|
||||
|
||||
# 网络操作 -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
asyncssh==2.14.2
|
||||
telnetlib3==2.0.3
|
||||
httpx==0.27.0
|
||||
python-nmap==0.7.1
|
||||
|
||||
|
||||
# 异步文件操作 -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
aiofiles==23.2.1
|
||||
|
||||
# 日志管理 -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
loguru==0.7.2
|
||||
|
||||
# 重试机制 -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
tenacity==8.2.3
|
||||
|
||||
# 其他工具 -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
asyncio==3.4.3
|
||||
typing_extensions==4.10.0
|
||||
|
||||
#监控依赖 Y
|
||||
|
||||
psutil==5.9.8
|
||||
matplotlib==3.8.3
|
||||
sqlalchemy==2.0.28
|
||||
|
@ -1,11 +1,8 @@
|
||||
import uvicorn
|
||||
from src.backend.app import create_app
|
||||
|
||||
app = create_app()
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(
|
||||
"src.backend.app:app",
|
||||
"src.backend.app:app", # 使用字符串引用方式
|
||||
host="0.0.0.0",
|
||||
port=8000,
|
||||
log_level="info",
|
||||
|
@ -3,37 +3,71 @@ import logging
|
||||
from src.backend.app.api.network_config import SwitchConfigurator
|
||||
#该文件用于测试
|
||||
|
||||
# 设置日志
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
|
||||
async def test_ensp():
|
||||
"""eNSP测试函数"""
|
||||
# 1. 初始化配置器(对应eNSP设备设置)
|
||||
configurator = SwitchConfigurator(
|
||||
ensp_mode=True, # 启用eNSP模式
|
||||
ensp_port=2000, # 必须与eNSP中设备设置的Telnet端口一致
|
||||
username="admin", # 默认账号
|
||||
password="admin", # 默认密码
|
||||
timeout=15 # 建议超时设长些
|
||||
)
|
||||
|
||||
# 2. 执行配置(示例:创建VLAN100)
|
||||
async def test_connection(configurator):
|
||||
"""测试基础连接"""
|
||||
try:
|
||||
version = await configurator._send_commands("127.0.0.1", ["display version"])
|
||||
print("交换机版本信息:\n", version)
|
||||
return True
|
||||
except Exception as e:
|
||||
print("❌ 连接测试失败:", str(e))
|
||||
return False
|
||||
|
||||
|
||||
async def test_vlan_config(configurator):
|
||||
"""测试 VLAN 配置"""
|
||||
try:
|
||||
result = await configurator.safe_apply(
|
||||
ip="127.0.0.1", # 本地连接固定用这个地址
|
||||
"127.0.0.1",
|
||||
config={
|
||||
"type": "vlan",
|
||||
"vlan_id": 100,
|
||||
"name": "测试VLAN"
|
||||
"name": "自动化测试VLAN"
|
||||
}
|
||||
)
|
||||
print("✅ 配置结果:", result)
|
||||
except Exception as e:
|
||||
print("❌ 配置失败:", str(e))
|
||||
print("VLAN 配置结果:", result)
|
||||
|
||||
# 验证配置
|
||||
vlan_list = await configurator._send_commands("127.0.0.1", ["display vlan"])
|
||||
print("当前VLAN列表:\n", vlan_list)
|
||||
|
||||
return "success" in result.get("status", "")
|
||||
except Exception as e:
|
||||
print("❌ VLAN 配置失败:", str(e))
|
||||
return False
|
||||
|
||||
|
||||
async def main():
|
||||
"""主测试流程"""
|
||||
# 尝试不同端口
|
||||
for port in [2000, 2010, 2020, 23]:
|
||||
print(f"\n尝试端口: {port}")
|
||||
configurator = SwitchConfigurator(
|
||||
ensp_mode=True,
|
||||
ensp_port=port,
|
||||
username="",
|
||||
password="admin",
|
||||
timeout=15
|
||||
)
|
||||
|
||||
if await test_connection(configurator):
|
||||
print(f"✅ 成功连接到端口 {port}")
|
||||
if await test_vlan_config(configurator):
|
||||
print("✅ 所有测试通过!")
|
||||
return
|
||||
else:
|
||||
print("⚠️ VLAN 配置失败,继续尝试其他端口...")
|
||||
else:
|
||||
print("⚠️ 连接失败,尝试下一个端口...")
|
||||
|
||||
print("❌ 所有端口尝试失败,请检查配置")
|
||||
|
||||
|
||||
# 运行测试
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_ensp())
|
||||
asyncio.run(main())
|
Loading…
x
Reference in New Issue
Block a user