Compare commits

...

8 Commits

Author SHA1 Message Date
3
90f127b5ce Merge remote-tracking branch 'origin/main'
# Conflicts:
#	src/backend/.envExample
#	src/backend/app/__init__.py
#	src/backend/app/api/command_parser.py
#	src/backend/app/api/endpoints.py
#	src/backend/app/api/network_config.py
#	src/backend/app/utils/exceptions.py
#	src/backend/requirements.txt
2025-07-09 13:26:56 +08:00
3
ba1a7c216c 1 2025-07-09 13:26:10 +08:00
3
71eb1ee79a 家里电脑版本:保留了连接功能 2025-06-21 11:57:09 +08:00
3
7c17bb931b 114514+114514 2025-06-19 14:52:59 +08:00
3
2231b8cf82 114514+114514 2025-06-18 18:36:52 +08:00
3
6e5cd34da7 114514+114514 2025-06-18 17:00:14 +08:00
3
6f74a80036 114514 2025-06-18 16:49:42 +08:00
3
60359b54ee 114514 2025-06-18 16:18:42 +08:00
18 changed files with 1081 additions and 274 deletions

View File

@ -9,7 +9,7 @@
<content url="file://$MODULE_DIR$"> <content url="file://$MODULE_DIR$">
<excludeFolder url="file://$MODULE_DIR$/.venv" /> <excludeFolder url="file://$MODULE_DIR$/.venv" />
</content> </content>
<orderEntry type="jdk" jdkName="Python 3.13" jdkType="Python SDK" /> <orderEntry type="jdk" jdkName="Python 3.11 (AI-powered-switches)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" /> <orderEntry type="sourceFolder" forTests="false" />
<orderEntry type="library" name="Python 3.13 interpreter library" level="application" /> <orderEntry type="library" name="Python 3.13 interpreter library" level="application" />
</component> </component>

View File

@ -1,13 +0,0 @@
# 交换机认证配置
SWITCH_USERNAME=admin
SWITCH_PASSWORD=your_secure_password
SWITCH_TIMEOUT=15
# 硅基流动API配置
SILICONFLOW_API_KEY=sk-114514
SILICONFLOW_API_URL=https://api.siliconflow.ai/v1
# FastAPI 配置
UVICORN_HOST=0.0.0.0
UVICORN_PORT=8000
UVICORN_RELOAD=false

View File

@ -1,25 +1,49 @@
from fastapi import FastAPI from fastapi import FastAPI, responses
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware import Middleware # 新增导入
from src.backend.app.api.endpoints import router from src.backend.app.api.endpoints import router
from src.backend.app.utils.logger import setup_logging from src.backend.app.utils.logger import setup_logging
from src.backend.config import settings from src.backend.config import settings
from .services.switch_traffic_monitor import get_switch_monitor
api_app = FastAPI() from .services.traffic_monitor import traffic_monitor
api_app.include_router(router,prefix="/api") from src.backend.app.api.database import init_db
def create_app() -> FastAPI: def create_app() -> FastAPI:
# 初始化数据库
init_db()
# 启动流量监控
traffic_monitor.start_monitoring()
# 设置日志 # 设置日志
setup_logging() setup_logging()
# 创建FastAPI应用 # 创建FastAPI应用(使用新的中间件配置方式)
app = FastAPI( app = FastAPI(
title=settings.APP_NAME, title=settings.APP_NAME,
debug=settings.DEBUG, debug=settings.DEBUG,
docs_url=f"{settings.API_PREFIX}/docs", docs_url=f"{settings.API_PREFIX}/docs",
redoc_url=f"{settings.API_PREFIX}/redoc", redoc_url=f"{settings.API_PREFIX}/redoc",
openapi_url=f"{settings.API_PREFIX}/openapi.json" openapi_url=f"{settings.API_PREFIX}/openapi.json",
middleware=[ # 这里直接配置中间件
Middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
) )
]
)
# 添加根路径重定向
@app.get("/", include_in_schema=False)
async def root():
return responses.RedirectResponse(url=f"{settings.API_PREFIX}/docs")
# 添加API路由 # 添加API路由
app.include_router(router, prefix=settings.API_PREFIX) app.include_router(router, prefix=settings.API_PREFIX)
return app return app
app = create_app()

View File

@ -1,7 +1,8 @@
from typing import Dict, Any, Optional from typing import Dict, Any
from src.backend.app.services.ai_services import AIService from src.backend.app.services.ai_services import AIService
from src.backend.config import settings from src.backend.config import settings
class CommandParser: class CommandParser:
def __init__(self): def __init__(self):
self.ai_service = AIService(settings.SILICONFLOW_API_KEY, settings.SILICONFLOW_API_URL) self.ai_service = AIService(settings.SILICONFLOW_API_KEY, settings.SILICONFLOW_API_URL)
@ -19,7 +20,7 @@ class CommandParser:
return await self.ai_service.parse_command(command) return await self.ai_service.parse_command(command)
@staticmethod @staticmethod
def _try_local_parse(command: str) -> Optional[Dict[str, Any]]: def _try_local_parse(command: str) -> dict[str, str | list[Any]] | dict[str, str] | None:
""" """
尝试本地解析常见命令 尝试本地解析常见命令
""" """
@ -28,13 +29,12 @@ class CommandParser:
# VLAN配置 # VLAN配置
if "vlan" in command and "创建" in command: if "vlan" in command and "创建" in command:
parts = command.split() parts = command.split()
vlan_id_str = next((p for p in parts if p.isdigit()), None) vlan_id = next((p for p in parts if p.isdigit()), None)
if vlan_id_str: if vlan_id:
vlan_id = int(vlan_id_str) # 转换为整数
return { return {
"type": "vlan", "type": "vlan",
"vlan_id": vlan_id, # 使用整数 "vlan_id": vlan_id,
"name": f"VLAN{vlan_id}", # 使用转换后的整数 "name": f"VLAN{vlan_id}",
"interfaces": [] "interfaces": []
} }
@ -61,9 +61,9 @@ class CommandParser:
config["description"] = description config["description"] = description
if "vlan" in command: if "vlan" in command:
vlan_id_str = next((p for p in parts if p.isdigit()), None) vlan_id = next((p for p in parts if p.isdigit()), None)
if vlan_id_str: if vlan_id:
config["vlan"] = int(vlan_id_str) # 转换为整数 config["vlan"] = vlan_id
return config return config

View File

@ -0,0 +1,17 @@
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
SQLALCHEMY_DATABASE_URL = "sqlite:///./traffic_monitor.db"
engine = create_engine(
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
def init_db():
"""初始化数据库"""
# 删除多余的导入
Base.metadata.create_all(bind=engine)

View File

@ -1,97 +1,73 @@
from fastapi import APIRouter, HTTPException from datetime import datetime, timedelta
from typing import List, Dict from fastapi import (APIRouter, HTTPException, Response, WebSocket, WebSocketDisconnect)
from typing import List
from pydantic import BaseModel from pydantic import BaseModel
import asyncio
from fastapi.responses import HTMLResponse
import matplotlib.pyplot as plt
import io
import base64
from ..services.switch_traffic_monitor import get_switch_monitor
from ..utils import logger
from ...app.services.ai_services import AIService
from ...app.api.network_config import SwitchConfigurator
from ...config import settings from ...config import settings
from ..services.network_scanner import NetworkScanner from ..services.network_scanner import NetworkScanner
from ..api.network_config import SwitchConfigurator from ...app.services.traffic_monitor import traffic_monitor
from ...app.models.traffic_models import TrafficRecord, SwitchTrafficRecord
from src.backend.app.api.database import SessionLocal
router = APIRouter(prefix="/api", tags=["API"])
router = APIRouter(prefix="", tags=["API"])
scanner = NetworkScanner() scanner = NetworkScanner()
# 添加根路径路由
@router.get("/", include_in_schema=False)
async def root():
return {
"message": "欢迎使用AI交换机配置系统",
"docs": f"{settings.API_PREFIX}/docs",
"redoc": f"{settings.API_PREFIX}/redoc",
"endpoints": [
"/parse_command",
"/apply_config",
"/scan_network",
"/list_devices",
"/batch_apply_config"
"/traffic/switch/current", # 添加交换机流量端点
"/traffic/switch/history" # 添加交换机历史流量端点
]
}
# 添加favicon处理
@router.get("/favicon.ico", include_in_schema=False)
async def favicon():
return Response(status_code=204)
# ====================
# 请求模型
# ====================
class BatchConfigRequest(BaseModel): class BatchConfigRequest(BaseModel):
config: Dict config: dict
switch_ips: List[str] switch_ips: List[str] # 支持多个IP
class CommandRequest(BaseModel):
command: str
class ConfigRequest(BaseModel):
config: Dict
switch_ip: str
# ====================
# API端点
# ====================
@router.post("/batch_apply_config") @router.post("/batch_apply_config")
async def batch_apply_config(request: BatchConfigRequest): async def batch_apply_config(request: BatchConfigRequest):
"""
批量配置交换机
- 支持同时配置多台设备
- 自动处理连接池
- 返回每个设备的详细结果
"""
configurator = SwitchConfigurator(
username=settings.SWITCH_USERNAME,
password=settings.SWITCH_PASSWORD,
timeout=settings.SWITCH_TIMEOUT
)
results = {} results = {}
try:
for ip in request.switch_ips: for ip in request.switch_ips:
try: try:
# 使用公开的apply_config方法 configurator = SwitchConfigurator()
results[ip] = await configurator.apply_config(ip, request.config) results[ip] = await configurator.apply_config(ip, request.config)
except Exception as e: except Exception as e:
results[ip] = { results[ip] = str(e)
"status": "failed",
"error": str(e)
}
return {"results": results} return {"results": results}
finally:
await configurator.close()
@router.post("/apply_config", response_model=Dict)
async def apply_config(request: ConfigRequest):
"""
单设备配置
- 更详细的错误处理
- 自动备份和回滚
"""
configurator = SwitchConfigurator(
username=settings.SWITCH_USERNAME,
password=settings.SWITCH_PASSWORD,
timeout=settings.SWITCH_TIMEOUT
)
try:
result = await configurator.apply_config(request.switch_ip, request.config)
if result["status"] != "success":
raise HTTPException(
status_code=500,
detail=result.get("error", "配置失败")
)
return result
finally:
await configurator.close()
# ====================
# 其他原有端点(保持不动)
# ====================
@router.get("/test") @router.get("/test")
async def test_endpoint(): async def test_endpoint():
return {"message": "Hello World"} return {"message": "Hello World"}
@router.get("/scan_network", summary="扫描网络中的交换机") @router.get("/scan_network", summary="扫描网络中的交换机")
async def scan_network(subnet: str = "192.168.1.0/24"): async def scan_network(subnet: str = "192.168.1.0/24"):
try: try:
@ -104,23 +80,25 @@ async def scan_network(subnet: str = "192.168.1.0/24"):
except Exception as e: except Exception as e:
raise HTTPException(500, f"扫描失败: {str(e)}") raise HTTPException(500, f"扫描失败: {str(e)}")
@router.get("/list_devices", summary="列出已发现的交换机") @router.get("/list_devices", summary="列出已发现的交换机")
async def list_devices(): async def list_devices():
return { return {
"devices": scanner.load_cached_devices() "devices": scanner.load_cached_devices()
} }
class CommandRequest(BaseModel):
command: str
@router.post("/parse_command", response_model=Dict) class ConfigRequest(BaseModel):
config: dict
switch_ip: str
@router.post("/parse_command", response_model=dict)
async def parse_command(request: CommandRequest): async def parse_command(request: CommandRequest):
""" """
解析中文命令并返回JSON配置 解析中文命令并返回JSON配置
- 依赖AI服务
- 返回标准化配置
""" """
try: try:
from ..services.ai_services import AIService # 延迟导入避免循环依赖
ai_service = AIService(settings.SILICONFLOW_API_KEY, settings.SILICONFLOW_API_URL) ai_service = AIService(settings.SILICONFLOW_API_KEY, settings.SILICONFLOW_API_URL)
config = await ai_service.parse_command(request.command) config = await ai_service.parse_command(request.command)
return {"success": True, "config": config} return {"success": True, "config": config}
@ -129,3 +107,284 @@ async def parse_command(request: CommandRequest):
status_code=400, status_code=400,
detail=f"Failed to parse command: {str(e)}" detail=f"Failed to parse command: {str(e)}"
) )
@router.post("/apply_config", response_model=dict)
async def apply_config(request: ConfigRequest):
"""
应用配置到交换机
"""
try:
configurator = SwitchConfigurator(
username=settings.SWITCH_USERNAME,
password=settings.SWITCH_PASSWORD,
timeout=settings.SWITCH_TIMEOUT
)
result = await configurator.safe_apply(request.switch_ip, request.config)
return {"success": True, "result": result}
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to apply config: {str(e)}"
)
@router.get("/traffic/interfaces", summary="获取所有网络接口")
async def get_network_interfaces():
return {
"interfaces": traffic_monitor.get_interfaces()
}
@router.get("/traffic/current", summary="获取当前流量数据")
async def get_current_traffic(interface: str = None):
return traffic_monitor.get_current_traffic(interface)
@router.get("/traffic/history", summary="获取流量历史数据")
async def get_traffic_history(interface: str = None, limit: int = 100):
history = traffic_monitor.get_traffic_history(interface)
return {
"sent": history["sent"][-limit:],
"recv": history["recv"][-limit:],
"time": [t.isoformat() for t in history["time"]][-limit:]
}
@router.get("/traffic/records", summary="获取流量记录")
async def get_traffic_records(interface: str = None, limit: int = 100):
with SessionLocal() as session:
query = session.query(TrafficRecord)
if interface:
query = query.filter(TrafficRecord.interface == interface)
records = query.order_by(TrafficRecord.timestamp.desc()).limit(limit).all()
return [record.to_dict() for record in records]
@router.websocket("/ws/traffic")
async def websocket_traffic(websocket: WebSocket):
"""实时流量WebSocket"""
await websocket.accept()
try:
while True:
# 获取所有接口的当前流量
traffic_data = traffic_monitor.get_current_traffic()
await websocket.send_json(traffic_data)
await asyncio.sleep(1) # 每秒更新一次
except WebSocketDisconnect:
print("客户端断开连接")
@router.get("/", include_in_schema=False)
async def root():
return {
"message": "欢迎使用AI交换机配置系统",
"docs": f"{settings.API_PREFIX}/docs",
"redoc": f"{settings.API_PREFIX}/redoc",
"endpoints": [
"/parse_command",
"/apply_config",
"/scan_network",
"/list_devices",
"/batch_apply_config",
"/traffic/switch/current", # 添加交换机流量端点
"/traffic/switch/history" # 添加交换机历史流量端点
]
}
# ... 其他路由保持不变 ...
@router.get("/traffic/switch/interfaces", summary="获取交换机的网络接口")
async def get_switch_interfaces(switch_ip: str):
"""获取指定交换机的所有接口"""
try:
monitor = get_switch_monitor(switch_ip)
interfaces = list(monitor.interface_oids.keys())
return {
"switch_ip": switch_ip,
"interfaces": interfaces
}
except Exception as e:
logger.error(f"获取交换机接口失败: {str(e)}")
raise HTTPException(500, f"获取接口失败: {str(e)}")
@router.get("/traffic/switch/current", summary="获取交换机的当前流量数据")
async def get_switch_current_traffic(switch_ip: str, interface: str = None):
"""获取交换机的当前流量数据"""
try:
monitor = get_switch_monitor(switch_ip)
# 如果没有指定接口,获取所有接口的流量
if not interface:
traffic_data = {}
for iface in monitor.interface_oids:
traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface)
return {
"switch_ip": switch_ip,
"traffic": traffic_data
}
# 获取指定接口的流量
return await get_interface_current_traffic(switch_ip, interface)
except Exception as e:
logger.error(f"获取交换机流量失败: {str(e)}")
raise HTTPException(500, f"获取流量失败: {str(e)}")
async def get_interface_current_traffic(switch_ip: str, interface: str) -> dict:
"""获取指定交换机接口的当前流量数据"""
try:
with SessionLocal() as session:
# 获取最新记录
record = session.query(SwitchTrafficRecord).filter(
SwitchTrafficRecord.switch_ip == switch_ip,
SwitchTrafficRecord.interface == interface
).order_by(SwitchTrafficRecord.timestamp.desc()).first()
if not record:
return {
"switch_ip": switch_ip,
"interface": interface,
"rate_in": 0.0,
"rate_out": 0.0,
"bytes_in": 0,
"bytes_out": 0
}
return {
"switch_ip": switch_ip,
"interface": interface,
"rate_in": record.rate_in,
"rate_out": record.rate_out,
"bytes_in": record.bytes_in,
"bytes_out": record.bytes_out
}
except Exception as e:
logger.error(f"获取接口流量失败: {str(e)}")
raise HTTPException(500, f"获取接口流量失败: {str(e)}")
@router.get("/traffic/switch/history", summary="获取交换机的流量历史数据")
async def get_switch_traffic_history(switch_ip: str, interface: str = None, minutes: int = 10):
"""获取交换机的流量历史数据"""
try:
monitor = get_switch_monitor(switch_ip)
# 如果没有指定接口,返回所有接口的历史数据
if not interface:
return {
"switch_ip": switch_ip,
"history": monitor.get_traffic_history()
}
# 获取指定接口的历史数据
with SessionLocal() as session:
# 计算时间范围
time_threshold = datetime.now() - timedelta(minutes=minutes)
# 查询历史记录
records = session.query(SwitchTrafficRecord).filter(
SwitchTrafficRecord.switch_ip == switch_ip,
SwitchTrafficRecord.interface == interface,
SwitchTrafficRecord.timestamp >= time_threshold
).order_by(SwitchTrafficRecord.timestamp.asc()).all()
# 提取数据
history_data = {
"in": [record.rate_in for record in records],
"out": [record.rate_out for record in records],
"time": [record.timestamp.isoformat() for record in records]
}
return {
"switch_ip": switch_ip,
"interface": interface,
"history": history_data
}
except Exception as e:
logger.error(f"获取历史流量失败: {str(e)}")
raise HTTPException(500, f"获取历史流量失败: {str(e)}")
@router.websocket("/ws/traffic/switch")
async def websocket_switch_traffic(websocket: WebSocket, switch_ip: str, interface: str = None):
"""交换机实时流量WebSocket"""
await websocket.accept()
try:
monitor = get_switch_monitor(switch_ip)
while True:
# 获取流量数据
if interface:
# 单个接口
traffic_data = await get_interface_current_traffic(switch_ip, interface)
await websocket.send_json(traffic_data)
else:
# 所有接口
traffic_data = {}
for iface in monitor.interface_oids:
traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface)
await websocket.send_json({
"switch_ip": switch_ip,
"traffic": traffic_data
})
# 每秒更新一次
await asyncio.sleep(1)
except WebSocketDisconnect:
logger.info(f"客户端断开交换机流量连接: {switch_ip}")
except Exception as e:
logger.error(f"交换机流量WebSocket错误: {str(e)}")
await websocket.close(code=1011, reason=str(e))
# 交换机流量可视化
@router.get("/traffic/switch/plot", response_class=HTMLResponse, summary="交换机流量可视化")
async def plot_switch_traffic(switch_ip: str, interface: str, minutes: int = 10):
"""生成交换机流量图表"""
try:
# 获取历史数据
history = await get_switch_traffic_history(switch_ip, interface, minutes)
history_data = history["history"]
# 提取数据点
time_points = [datetime.fromisoformat(t) for t in history_data["time"]]
in_rates = history_data["in"]
out_rates = history_data["out"]
# 创建图表
plt.figure(figsize=(12, 6))
plt.plot(time_points, in_rates, label="流入流量 (B/s)")
plt.plot(time_points, out_rates, label="流出流量 (B/s)")
plt.title(f"交换机 {switch_ip} 接口 {interface} 流量监控 - 最近 {minutes} 分钟")
plt.xlabel("时间")
plt.ylabel("流量 (字节/秒)")
plt.legend()
plt.grid(True)
plt.xticks(rotation=45)
plt.tight_layout()
# 转换为HTML图像
buf = io.BytesIO()
plt.savefig(buf, format="png")
buf.seek(0)
image_base64 = base64.b64encode(buf.read()).decode("utf-8")
plt.close()
return f"""
<html>
<head>
<title>交换机流量监控</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 20px; }}
.container {{ max-width: 900px; margin: 0 auto; }}
</style>
</head>
<body>
<div class="container">
<h1>交换机 {switch_ip} 接口 {interface} 流量监控</h1>
<img src="data:image/png;base64,{image_base64}" alt="流量图表">
<p>更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
</div>
</body>
</html>
"""
except Exception as e:
logger.error(f"生成流量图表失败: {str(e)}")
return HTMLResponse(content=f"<h1>错误</h1><p>{str(e)}</p>", status_code=500)

View File

@ -1,14 +1,14 @@
import asyncio import asyncio
import logging import logging
import telnetlib3 import telnetlib3
import time
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_exponential
import aiofiles import aiofiles
import asyncssh import asyncssh
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_exponential
# ---------------------- # ----------------------
@ -39,7 +39,7 @@ class SSHConnectionException(SwitchConfigException):
# ---------------------- # ----------------------
# 核心配置器 # 核心配置器(完整双模式)
# ---------------------- # ----------------------
class SwitchConfigurator: class SwitchConfigurator:
def __init__( def __init__(
@ -63,35 +63,12 @@ class SwitchConfigurator:
self.ensp_port = ensp_port self.ensp_port = ensp_port
self.ensp_delay = ensp_command_delay self.ensp_delay = ensp_command_delay
self.ssh_options = ssh_options self.ssh_options = ssh_options
self._connection_pool = {} # SSH连接池
# ==================== async def apply_config(self, ip: str, config: Union[Dict, SwitchConfig]) -> str:
# 公开API方法 """实际配置逻辑"""
# ====================
async def apply_config(self, ip: str, config: Union[Dict, SwitchConfig]) -> Dict:
"""
应用配置到交换机主入口
返回格式:
{
"status": "success"|"failed",
"output": str,
"backup_path": str,
"error": Optional[str],
"timestamp": str
}
"""
if isinstance(config, dict): if isinstance(config, dict):
config = SwitchConfig(**config) config = SwitchConfig(**config)
result = await self.safe_apply(ip, config)
result["timestamp"] = datetime.now().isoformat()
return result
# ====================
# 内部实现方法
# ====================
async def _apply_config(self, ip: str, config: SwitchConfig) -> str:
"""实际配置逻辑"""
commands = ( commands = (
self._generate_ensp_commands(config) self._generate_ensp_commands(config)
if self.ensp_mode if self.ensp_mode
@ -107,74 +84,60 @@ class SwitchConfigurator:
else await self._send_ssh_commands(ip, commands) else await self._send_ssh_commands(ip, commands)
) )
# --------- eNSP模式专用 ---------
async def _send_ensp_commands(self, ip: str, commands: List[str]) -> str: async def _send_ensp_commands(self, ip: str, commands: List[str]) -> str:
"""Telnet协议执行eNSP""" """Telnet协议执行eNSP"""
try: try:
# 修复点使用正确的timeout参数
reader, writer = await telnetlib3.open_connection( reader, writer = await telnetlib3.open_connection(
host=ip, host=ip,
port=self.ensp_port, port=self.ensp_port,
connect_minwait=self.timeout, connect_minwait=self.timeout, # telnetlib3的实际可用参数
connect_maxwait=self.timeout connect_maxwait=self.timeout
) )
# 登录流程 # 登录流程(增加超时处理)
await reader.readuntil(b"Username:") try:
await asyncio.wait_for(reader.readuntil(b"Username:"), timeout=self.timeout)
writer.write(f"{self.username}\n") writer.write(f"{self.username}\n")
await reader.readuntil(b"Password:")
await asyncio.wait_for(reader.readuntil(b"Password:"), timeout=self.timeout)
writer.write(f"{self.password}\n") writer.write(f"{self.password}\n")
# 等待登录完成
await asyncio.sleep(1) await asyncio.sleep(1)
except asyncio.TimeoutError:
raise EnspConnectionException("登录超时")
# 执行命令 # 执行命令
output = "" output = ""
for cmd in commands: for cmd in commands:
writer.write(f"{cmd}\n") writer.write(f"{cmd}\n")
await asyncio.sleep(self.ensp_delay) await writer.drain() # 确保命令发送完成
while True:
# 读取响应(增加超时处理)
try: try:
while True:
data = await asyncio.wait_for(reader.read(1024), timeout=1) data = await asyncio.wait_for(reader.read(1024), timeout=1)
if not data: if not data:
break break
output += data output += data
except asyncio.TimeoutError: except asyncio.TimeoutError:
break continue # 单次读取超时不视为错误
# 关闭连接
writer.close() writer.close()
# noinspection PyBroadException
try:
await writer.wait_closed()
except:
logging.debug("连接关闭时出现异常", exc_info=True) # 至少记录异常信息
pass
return output return output
except Exception as e: except Exception as e:
raise EnspConnectionException(f"eNSP连接失败: {str(e)}") raise EnspConnectionException(f"eNSP连接失败: {str(e)}")
async def _clean_idle_connections(self):
"""连接池清理机制"""
now = time.time()
for ip, (conn, last_used) in list(self._connection_pool.items()):
if now - last_used > 300: # 5分钟空闲超时
conn.close()
del self._connection_pool[ip]
async def _send_ssh_commands(self, ip: str, commands: List[str]) -> str:
"""SSH协议执行"""
async with self.semaphore:
try:
if ip not in self._connection_pool:
self._connection_pool[ip] = await asyncssh.connect(
host=ip,
username=self.username,
password=self.password,
connect_timeout=self.timeout,
**self.ssh_options
)
results = []
for cmd in commands:
result = await self._connection_pool[ip].run(cmd)
results.append(result.stdout)
return "\n".join(results)
except asyncssh.Error as e:
if ip in self._connection_pool:
self._connection_pool[ip].close()
del self._connection_pool[ip]
raise SSHConnectionException(f"SSH操作失败: {str(e)}")
@staticmethod @staticmethod
def _generate_ensp_commands(config: SwitchConfig) -> List[str]: def _generate_ensp_commands(config: SwitchConfig) -> List[str]:
"""生成eNSP命令序列""" """生成eNSP命令序列"""
@ -194,6 +157,28 @@ class SwitchConfigurator:
commands.append("return") commands.append("return")
return [c for c in commands if c.strip()] return [c for c in commands if c.strip()]
# --------- SSH模式专用使用AsyncSSH ---------
async def _send_ssh_commands(self, ip: str, commands: List[str]) -> str:
"""AsyncSSH执行命令"""
async with self.semaphore:
try:
async with asyncssh.connect(
host=ip,
username=self.username,
password=self.password,
connect_timeout=self.timeout, # AsyncSSH的正确参数名
**self.ssh_options
) as conn:
results = []
for cmd in commands:
result = await conn.run(cmd, check=True)
results.append(result.stdout)
return "\n".join(results)
except asyncssh.Error as e:
raise SSHConnectionException(f"SSH操作失败: {str(e)}")
except Exception as e:
raise SSHConnectionException(f"连接异常: {str(e)}")
@staticmethod @staticmethod
def _generate_standard_commands(config: SwitchConfig) -> List[str]: def _generate_standard_commands(config: SwitchConfig) -> List[str]:
"""生成标准CLI命令""" """生成标准CLI命令"""
@ -211,6 +196,16 @@ class SwitchConfigurator:
]) ])
return commands return commands
# --------- 通用功能 ---------
async def _validate_config(self, ip: str, config: SwitchConfig) -> bool:
"""验证配置是否生效"""
current = await self._get_current_config(ip)
if config.type == "vlan":
return f"vlan {config.vlan_id}" in current
elif config.type == "interface" and config.vlan:
return f"switchport access vlan {config.vlan}" in current
return True
async def _get_current_config(self, ip: str) -> str: async def _get_current_config(self, ip: str) -> str:
"""获取当前配置""" """获取当前配置"""
commands = ( commands = (
@ -259,7 +254,7 @@ class SwitchConfigurator:
"""安全配置应用(自动回滚)""" """安全配置应用(自动回滚)"""
backup_path = await self._backup_config(ip) backup_path = await self._backup_config(ip)
try: try:
result = await self._apply_config(ip, config) result = await self.apply_config(ip, config)
if not await self._validate_config(ip, config): if not await self._validate_config(ip, config):
raise SwitchConfigException("配置验证失败") raise SwitchConfigException("配置验证失败")
return { return {
@ -276,17 +271,40 @@ class SwitchConfigurator:
"restore_success": restore_status "restore_success": restore_status
} }
async def _validate_config(self, ip: str, config: SwitchConfig) -> bool:
"""验证配置是否生效"""
current = await self._get_current_config(ip)
if config.type == "vlan":
return f"vlan {config.vlan_id}" in current
elif config.type == "interface" and config.vlan:
return f"switchport access vlan {config.vlan}" in current
return True
async def close(self): # ----------------------
"""清理所有连接""" # 使用示例
for conn in self._connection_pool.values(): # ----------------------
conn.close() async def demo():
self._connection_pool.clear() # 示例1: eNSP设备配置Telnet模式
ensp_configurator = SwitchConfigurator(
ensp_mode=True,
ensp_port=2000,
username="admin",
password="admin",
timeout=15
)
ensp_result = await ensp_configurator.safe_apply("127.0.0.1", {
"type": "interface",
"interface": "GigabitEthernet0/0/1",
"vlan": 100,
"ip_address": "192.168.1.2 255.255.255.0"
})
print("eNSP配置结果:", ensp_result)
# 示例2: 真实设备配置SSH模式
ssh_configurator = SwitchConfigurator(
username="cisco",
password="cisco123",
timeout=15
)
ssh_result = await ssh_configurator.safe_apply("192.168.1.1", {
"type": "vlan",
"vlan_id": 200,
"name": "Production"
})
print("SSH配置结果:", ssh_result)
if __name__ == "__main__":
asyncio.run(demo())

View File

@ -0,0 +1,52 @@
# 添加正确的导入
from src.backend.app.api.database import Base # 修复:导入 Base
from sqlalchemy import Column, Integer, String, DateTime, BigInteger, Float
class TrafficRecord(Base):
"""网络流量记录模型"""
__tablename__ = "traffic_records"
id = Column(Integer, primary_key=True, index=True)
interface = Column(String(50), index=True)
bytes_sent = Column(Integer)
bytes_recv = Column(Integer)
packets_sent = Column(Integer)
packets_recv = Column(Integer)
timestamp = Column(DateTime)
def to_dict(self):
return {
"id": self.id,
"interface": self.interface,
"bytes_sent": self.bytes_sent,
"bytes_recv": self.bytes_recv,
"packets_sent": self.packets_sent,
"packets_recv": self.packets_recv,
"timestamp": self.timestamp.isoformat()
}
class SwitchTrafficRecord(Base):
__tablename__ = "switch_traffic_records"
id = Column(Integer, primary_key=True, index=True)
switch_ip = Column(String(50), index=True)
interface = Column(String(50))
bytes_in = Column(BigInteger) # 累计流入字节数
bytes_out = Column(BigInteger) # 累计流出字节数
rate_in = Column(Float) # 当前流入速率(字节/秒)
rate_out = Column(Float) # 当前流出速率(字节/秒)
timestamp = Column(DateTime)
def to_dict(self):
return {
"id": self.id,
"switch_ip": self.switch_ip,
"interface": self.interface,
"bytes_in": self.bytes_in,
"bytes_out": self.bytes_out,
"rate_in": self.rate_in,
"rate_out": self.rate_out,
"timestamp": self.timestamp.isoformat()
}

View File

@ -26,7 +26,7 @@ class AIService:
""" """
data = { data = {
"model": "text-davinci-003", "model": "deepseek-ai/DeepSeek-V3",
"prompt": prompt, "prompt": prompt,
"max_tokens": 1000, "max_tokens": 1000,
"temperature": 0.3 "temperature": 0.3
@ -35,7 +35,7 @@ class AIService:
try: try:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{self.api_url}/completions", f"{self.api_url}/chat/completions",
headers=self.headers, headers=self.headers,
json=data, json=data,
timeout=30 timeout=30

View File

@ -1,3 +1,6 @@
import os
os.environ["PATH"] += ";C:\\Program Files (x86)\\Nmap"
import nmap import nmap
import json import json
from pathlib import Path from pathlib import Path

View File

@ -0,0 +1,196 @@
import asyncio
from datetime import datetime
from collections import deque
from typing import Optional, List, Dict
from pysnmp.hlapi import *
from ..models.traffic_models import SwitchTrafficRecord
from src.backend.app.api.database import SessionLocal
from ..utils.logger import logger
class SwitchTrafficMonitor:
def __init__(
self,
switch_ip: str,
community: str = 'public',
update_interval: int = 5,
interfaces: Optional[List[str]] = None
):
self.switch_ip = switch_ip
self.community = community
self.update_interval = update_interval
self.running = False
self.task = None
self.interface_history = {}
self.history = {
"in": deque(maxlen=300),
"out": deque(maxlen=300),
"time": deque(maxlen=300)
}
# 基本接口OID映射
self.interface_oids = {
"GigabitEthernet0/0/1": {
"in": '1.3.6.1.2.1.2.2.1.10.1', # ifInOctets
"out": '1.3.6.1.2.1.2.2.1.16.1' # ifOutOctets
},
"GigabitEthernet0/0/24": {
"in": '1.3.6.1.2.1.2.2.1.10.24',
"out": '1.3.6.1.2.1.2.2.1.16.24'
}
}
# 接口过滤
if interfaces:
self.interface_oids = {
iface: oid for iface, oid in self.interface_oids.items()
if iface in interfaces
}
logger.info(f"监控指定接口: {', '.join(interfaces)}")
else:
logger.info("监控所有接口")
def start_monitoring(self):
"""启动交换机流量监控"""
if not self.running:
self.running = True
self.task = asyncio.create_task(self._monitor_loop())
logger.success(f"交换机流量监控已启动: {self.switch_ip}")
async def stop_monitoring(self):
"""停止监控"""
if self.running:
self.running = False
if self.task:
self.task.cancel()
try:
await self.task
except asyncio.CancelledError:
pass
logger.info(f"交换机流量监控已停止: {self.switch_ip}")
async def _monitor_loop(self):
"""监控主循环"""
last_values = {iface: {"in": 0, "out": 0} for iface in self.interface_oids}
last_time = datetime.now()
while self.running:
await asyncio.sleep(self.update_interval)
try:
current_time = datetime.now()
elapsed = (current_time - last_time).total_seconds()
# 获取所有接口流量
for iface, oids in self.interface_oids.items():
in_octets = self._snmp_get(oids["in"])
out_octets = self._snmp_get(oids["out"])
if in_octets is not None and out_octets is not None:
# 计算速率(字节/秒)
# 修复字典访问问题
iface_values = last_values[iface]
in_rate = (in_octets - iface_values["in"]) / elapsed if iface_values["in"] > 0 else 0
out_rate = (out_octets - iface_values["out"]) / elapsed if iface_values["out"] > 0 else 0
# 保存历史数据
self.history["in"].append(in_rate)
self.history["out"].append(out_rate)
self.history["time"].append(current_time)
# 保存到数据库
self._save_to_db(iface, in_octets, out_octets, in_rate, out_rate, current_time)
# 更新最后的值
iface_values["in"] = in_octets
iface_values["out"] = out_octets
last_time = current_time
except Exception as e:
logger.error(f"监控交换机流量出错: {str(e)}")
def _snmp_get(self, oid) -> Optional[int]:
"""执行SNMP GET请求"""
try:
# 正确格式化的SNMP请求
cmd = getCmd(
SnmpEngine(),
CommunityData(self.community),
UdpTransportTarget((self.switch_ip, 161)),
ContextData(),
ObjectType(ObjectIdentity(oid)))
# 执行命令
errorIndication, errorStatus, errorIndex, varBinds = next(cmd)
except Exception as e:
logger.error(f"SNMP请求失败: {str(e)}")
return None
if errorIndication:
logger.error(f"SNMP错误: {errorIndication}")
return None
elif errorStatus:
try:
# 修复括号问题
if errorIndex:
index_val = int(errorIndex) - 1
error_item = varBinds[index_val] if index_val < len(varBinds) else '?'
else:
error_item = '?'
error_msg = f"SNMP错误: {errorStatus.prettyPrint()} at {error_item}"
logger.error(error_msg)
except Exception as e:
logger.error(f"解析SNMP错误失败: {str(e)}")
return None
else:
for varBind in varBinds:
try:
return int(varBind[1])
except Exception as e:
logger.error(f"转换SNMP值失败: {str(e)}")
return None
return None
def _save_to_db(self, interface: str, in_octets: int, out_octets: int,
in_rate: float, out_rate: float, timestamp: datetime):
"""保存流量数据到数据库"""
try:
with SessionLocal() as session:
record = SwitchTrafficRecord(
switch_ip=self.switch_ip,
interface=interface,
bytes_in=in_octets,
bytes_out=out_octets,
rate_in=in_rate,
rate_out=out_rate,
timestamp=timestamp
)
session.add(record)
session.commit()
except Exception as e:
logger.error(f"保存流量数据到数据库失败: {str(e)}")
def get_traffic_history(self) -> Dict[str, List]:
"""获取流量历史数据"""
return {
"in": list(self.history["in"]),
"out": list(self.history["out"]),
"time": list(self.history["time"])
}
# 全局监控器字典(支持多个交换机)
switch_monitors = {}
def get_switch_monitor(switch_ip: str, community: str = 'public', interfaces: Optional[List[str]] = None):
"""获取或创建交换机监控器(添加接口过滤参数)"""
if switch_ip not in switch_monitors:
switch_monitors[switch_ip] = SwitchTrafficMonitor(
switch_ip,
community,
interfaces=interfaces
)
return switch_monitors[switch_ip]

View File

@ -0,0 +1,147 @@
import psutil
import time
import asyncio
from datetime import datetime
from collections import deque
from typing import Dict, Optional, List
from ..models.traffic_models import TrafficRecord
from src.backend.app.api.database import SessionLocal # 修复:导入 SessionLocal
class TrafficMonitor:
def __init__(self, history_size: int = 300):
self.history_size = history_size # 保存历史大小
self.history = {
"sent": deque(maxlen=history_size),
"recv": deque(maxlen=history_size),
"time": deque(maxlen=history_size),
"interfaces": {}
}
self.running = False
self.task = None
self.update_interval = 1.0 # 秒
@staticmethod
def get_interfaces() -> List[str]:
"""获取所有网络接口名称"""
return list(psutil.net_io_counters(pernic=True).keys())
def start_monitoring(self):
"""启动流量监控"""
if not self.running:
self.running = True
self.task = asyncio.create_task(self._monitor_loop())
print("流量监控已启动")
async def stop_monitoring(self):
"""停止流量监控"""
if self.running:
self.running = False
self.task.cancel()
try:
await self.task
except asyncio.CancelledError:
pass
print("流量监控已停止")
async def _monitor_loop(self):
"""监控主循环"""
last_stats = psutil.net_io_counters(pernic=True)
last_time = time.time()
while self.running:
await asyncio.sleep(self.update_interval)
current_time = time.time()
current_stats = psutil.net_io_counters(pernic=True)
elapsed = current_time - last_time
# 计算每个接口的流量速率
for iface in current_stats:
if iface not in self.history["interfaces"]:
# 修复:使用 self.history_size
self.history["interfaces"][iface] = {
"sent": deque(maxlen=self.history_size),
"recv": deque(maxlen=self.history_size)
}
if iface in last_stats:
sent_rate = (current_stats[iface].bytes_sent - last_stats[iface].bytes_sent) / elapsed
recv_rate = (current_stats[iface].bytes_recv - last_stats[iface].bytes_recv) / elapsed
# 保存到历史数据
self.history["sent"].append(sent_rate)
self.history["recv"].append(recv_rate)
self.history["time"].append(datetime.now())
# 保存到接口特定历史
self.history["interfaces"][iface]["sent"].append(sent_rate)
self.history["interfaces"][iface]["recv"].append(recv_rate)
# 保存到数据库
self._save_to_db(current_stats)
last_stats = current_stats
last_time = current_time
@staticmethod
def _save_to_db(stats):
"""保存流量数据到数据库"""
with SessionLocal() as session:
for iface, counters in stats.items():
record = TrafficRecord(
interface=iface,
bytes_sent=counters.bytes_sent,
bytes_recv=counters.bytes_recv,
packets_sent=counters.packets_sent,
packets_recv=counters.packets_recv,
timestamp=datetime.now()
)
session.add(record)
session.commit()
def get_current_traffic(self, interface: Optional[str] = None) -> Dict:
"""获取当前流量数据"""
stats = psutil.net_io_counters(pernic=True)
if interface:
if interface in stats:
return self._format_interface_stats(stats[interface])
return {}
return {iface: self._format_interface_stats(data) for iface, data in stats.items()}
@staticmethod
def _format_interface_stats(counters) -> Dict:
"""格式化接口统计数据"""
return {
"bytes_sent": counters.bytes_sent,
"bytes_recv": counters.bytes_recv,
"packets_sent": counters.packets_sent,
"packets_recv": counters.packets_recv,
"errin": counters.errin,
"errout": counters.errout,
"dropin": counters.dropin,
"dropout": counters.dropout
}
def get_traffic_history(self, interface: Optional[str] = None) -> Dict:
"""获取流量历史数据"""
if interface and interface in self.history["interfaces"]:
return {
"sent": list(self.history["interfaces"][interface]["sent"]),
"recv": list(self.history["interfaces"][interface]["recv"]),
"time": list(self.history["time"])
}
return {
"sent": list(self.history["sent"]),
"recv": list(self.history["recv"]),
"time": list(self.history["time"])
}
# 全局流量监控实例
traffic_monitor = TrafficMonitor()

View File

@ -1,15 +1,5 @@
from fastapi import HTTPException, status from fastapi import HTTPException, status
class SiliconFlowAPIException(Exception):
"""硅基流动API异常"""
def __init__(self, detail: str, status_code: int = 500):
self.detail = detail
self.status_code = status_code
super().__init__(detail)
def __str__(self):
return f"SiliconFlowAPI Error [{self.status_code}]: {self.detail}"
class AICommandParseException(HTTPException): class AICommandParseException(HTTPException):
def __init__(self, detail: str): def __init__(self, detail: str):
super().__init__( super().__init__(
@ -34,7 +24,8 @@ class ConfigBackupException(SwitchConfigException):
super().__init__( super().__init__(
detail=f"无法备份设备 {ip} 的配置" detail=f"无法备份设备 {ip} 的配置"
) )
self.recovery_guide = "检查设备存储空间或权限" # 在子类中存储恢复指南 # 将恢复指南作为实例属性
self.recovery_guide = "检查设备存储空间或权限"
class ConfigRollbackException(SwitchConfigException): class ConfigRollbackException(SwitchConfigException):
"""回滚失败异常""" """回滚失败异常"""
@ -43,4 +34,10 @@ class ConfigRollbackException(SwitchConfigException):
detail=f"设备 {ip} 回滚失败(原始错误:{original_error}", detail=f"设备 {ip} 回滚失败(原始错误:{original_error}",
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
) )
self.recovery_guide = "尝试手动恢复配置或重启设备"
class SiliconFlowAPIException(HTTPException):
def __init__(self, detail: str, status_code: int = 400):
super().__init__(
status_code=status_code,
detail=f"SiliconFlow API error: {detail}"
)

View File

@ -1,33 +1,89 @@
import logging import logging
from loguru import logger
import sys import sys
from loguru import logger as loguru_logger
class InterceptHandler(logging.Handler): class InterceptHandler(logging.Handler):
def emit(self, record): def emit(self, record):
# Get corresponding Loguru level if it exists # 获取对应的Loguru日志级别
try: try:
level = logger.level(record.levelname).name level = loguru_logger.level(record.levelname).name
except ValueError: except ValueError:
level = record.levelno level = record.levelno
# Find caller from where originated the logged message # 查找日志来源
frame, depth = logging.currentframe(), 2 frame, depth = logging.currentframe(), 2
while frame.f_code.co_filename == logging.__file__: while frame and frame.f_code.co_filename == logging.__file__:
frame = frame.f_back frame = frame.f_back
depth += 1 depth += 1
logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage()) # 使用Loguru记录日志
loguru_logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
def setup_logging(): def setup_logging():
# 拦截标准logging # 拦截标准logging
logging.basicConfig(handlers=[InterceptHandler()], level=0) logging.basicConfig(handlers=[InterceptHandler()], level=logging.NOTSET)
# 配置loguru # 移除所有现有处理器
logger.configure( loguru_logger.remove()
handlers=[
{"sink": sys.stdout, # 添加新的处理器
"format": "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"} loguru_logger.add(
] sys.stdout,
format=(
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
"<level>{level: <8}</level> | "
"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
"<level>{message}</level>"
),
level="DEBUG",
enqueue=True
) )
# 添加文件日志
loguru_logger.add(
"app.log",
rotation="10 MB",
retention="30 days",
level="INFO",
format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level} | {message}"
)
# 创建通用logger接口
class Logger:
@staticmethod
def debug(msg, *args, **kwargs):
loguru_logger.debug(msg, *args, **kwargs)
@staticmethod
def info(msg, *args, **kwargs):
loguru_logger.info(msg, *args, **kwargs)
@staticmethod
def warning(msg, *args, **kwargs):
loguru_logger.warning(msg, *args, **kwargs)
@staticmethod
def error(msg, *args, **kwargs):
loguru_logger.error(msg, *args, **kwargs)
@staticmethod
def critical(msg, *args, **kwargs):
loguru_logger.critical(msg, *args, **kwargs)
@staticmethod
def exception(msg, *args, **kwargs):
loguru_logger.exception(msg, *args, **kwargs)
@staticmethod
def success(msg, *args, **kwargs):
loguru_logger.success(msg, *args, **kwargs)
# 创建全局logger实例
logger = Logger()
# 初始化日志系统
setup_logging()

View File

@ -1,25 +1,30 @@
from pydantic import BaseModel
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
from dotenv import load_dotenv from dotenv import load_dotenv
import os
load_dotenv() load_dotenv()
import os
class Settings(BaseSettings): class Settings(BaseSettings):
APP_NAME: str = "AI Network Configurator" APP_NAME: str = "AI Network Configurator"
DEBUG: bool = True DEBUG: bool = True
API_PREFIX: str = "/api" API_PREFIX: str = "/api"
# 硅基流动API配置 # 硅基流动API配置
SILICONFLOW_API_KEY: str = os.getenv("SILICON_API_KEY", "") SILICONFLOW_API_KEY: str = os.getenv("SILICONFLOW_API_KEY", "sk-mhzuedasunrgdrxfkcxmxgaypgjnxgodvvmrzzdbqrwtkqej")
SILICONFLOW_API_URL: str = os.getenv("SILICONFLOW_API_URL", "https://api.siliconflow.ai/v1") SILICONFLOW_API_URL: str = os.getenv("SILICONFLOW_API_URL", "https://api.siliconflow.cn/v1")
# 交换机配置 # 交换机配置
SWITCH_USERNAME: str = os.getenv("SWITCH_USERNAME", "admin")
SWITCH_PASSWORD: str = os.getenv("SWITCH_PASSWORD", "admin")
SWITCH_TIMEOUT: int = os.getenv("SWITCH_TIMEOUT", 10) SWITCH_TIMEOUT: int = os.getenv("SWITCH_TIMEOUT", 10)
# eNSP配置
ENSP_DEFAULT_IP: str = "172.17.99.201"
ENSP_DEFAULT_PORT: int = 2000
class Config: class Config:
env_file = ".env" env_file = ".env"
extra = "ignore"
settings = Settings() settings = Settings()

View File

@ -1,19 +1,34 @@
fastapi>=0.95.2 # 核心依赖 -i https://pypi.tuna.tsinghua.edu.cn/simple
uvicorn>=0.22.0 fastapi==0.110.0
python-dotenv>=1.0.0 uvicorn==0.29.0
requests>=2.28.2 python-dotenv==1.0.1
paramiko>=3.3.0
pydantic>=1.10.7 # Pydantic 模型 -i https://pypi.tuna.tsinghua.edu.cn/simple
loguru>=0.7.0 pydantic==2.6.4
python-nmap>=0.7.1 pydantic-settings==2.2.1
tenacity>=9.1.2
typing-extensions>=4.0.0 # 网络操作 -i https://pypi.tuna.tsinghua.edu.cn/simple
aiofiles>=24.1.0 asyncssh==2.14.2
telnetlib3>=2.0.4 telnetlib3==2.0.3
asyncssh>=2.14.0 httpx==0.27.0
aiofiles>=24.1.0 python-nmap==0.7.1
networkx==3.1
scipy==1.11.1
stable-baselines3==2.0.0 # 异步文件操作 -i https://pypi.tuna.tsinghua.edu.cn/simple
plotly==5.15.0 aiofiles==23.2.1
pandas==2.0.3
# 日志管理 -i https://pypi.tuna.tsinghua.edu.cn/simple
loguru==0.7.2
# 重试机制 -i https://pypi.tuna.tsinghua.edu.cn/simple
tenacity==8.2.3
# 其他工具 -i https://pypi.tuna.tsinghua.edu.cn/simple
asyncio==3.4.3
typing_extensions==4.10.0
#监控依赖 Y
psutil==5.9.8
matplotlib==3.8.3
sqlalchemy==2.0.28

View File

@ -1,11 +1,8 @@
import uvicorn import uvicorn
from src.backend.app import create_app
app = create_app()
if __name__ == "__main__": if __name__ == "__main__":
uvicorn.run( uvicorn.run(
"src.backend.app:app", "src.backend.app:app", # 使用字符串引用方式
host="0.0.0.0", host="0.0.0.0",
port=8000, port=8000,
log_level="info", log_level="info",

View File

@ -3,37 +3,71 @@ import logging
from src.backend.app.api.network_config import SwitchConfigurator from src.backend.app.api.network_config import SwitchConfigurator
#该文件用于测试 #该文件用于测试
# 设置日志
logging.basicConfig( logging.basicConfig(
level=logging.DEBUG, level=logging.DEBUG,
format='%(asctime)s - %(levelname)s - %(message)s' format='%(asctime)s - %(levelname)s - %(message)s'
) )
async def test_ensp():
"""eNSP测试函数"""
# 1. 初始化配置器对应eNSP设备设置
configurator = SwitchConfigurator(
ensp_mode=True, # 启用eNSP模式
ensp_port=2000, # 必须与eNSP中设备设置的Telnet端口一致
username="admin", # 默认账号
password="admin", # 默认密码
timeout=15 # 建议超时设长些
)
# 2. 执行配置示例创建VLAN100 async def test_connection(configurator):
"""测试基础连接"""
try:
version = await configurator._send_commands("127.0.0.1", ["display version"])
print("交换机版本信息:\n", version)
return True
except Exception as e:
print("❌ 连接测试失败:", str(e))
return False
async def test_vlan_config(configurator):
"""测试 VLAN 配置"""
try: try:
result = await configurator.safe_apply( result = await configurator.safe_apply(
ip="127.0.0.1", # 本地连接固定用这个地址 "127.0.0.1",
config={ config={
"type": "vlan", "type": "vlan",
"vlan_id": 100, "vlan_id": 100,
"name": "测试VLAN" "name": "自动化测试VLAN"
} }
) )
print("✅ 配置结果:", result) print("VLAN 配置结果:", result)
except Exception as e:
print("❌ 配置失败:", str(e)) # 验证配置
vlan_list = await configurator._send_commands("127.0.0.1", ["display vlan"])
print("当前VLAN列表:\n", vlan_list)
return "success" in result.get("status", "")
except Exception as e:
print("❌ VLAN 配置失败:", str(e))
return False
async def main():
"""主测试流程"""
# 尝试不同端口
for port in [2000, 2010, 2020, 23]:
print(f"\n尝试端口: {port}")
configurator = SwitchConfigurator(
ensp_mode=True,
ensp_port=port,
username="",
password="admin",
timeout=15
)
if await test_connection(configurator):
print(f"✅ 成功连接到端口 {port}")
if await test_vlan_config(configurator):
print("✅ 所有测试通过!")
return
else:
print("⚠️ VLAN 配置失败,继续尝试其他端口...")
else:
print("⚠️ 连接失败,尝试下一个端口...")
print("❌ 所有端口尝试失败,请检查配置")
# 运行测试
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(test_ensp()) asyncio.run(main())