修复了异步连接

This commit is contained in:
3 2025-08-12 01:17:39 +08:00
parent 0b6b9624a6
commit 31e3baff9a
3 changed files with 177 additions and 175 deletions

View File

@ -1,15 +1,18 @@
# File: D:\Python work\AI-powered-switches\src\backend\app\api\endpoints.py
import socket
from datetime import datetime, timedelta
from fastapi import (APIRouter, HTTPException, Response, WebSocket, WebSocketDisconnect)
from typing import List
from pydantic import BaseModel
import asyncio
from fastapi.responses import HTMLResponse
from fastapi.responses import HTMLResponse, JSONResponse
import matplotlib.pyplot as plt
import io
import base64
import psutil
import ipaddress
import json
from ..services.switch_traffic_monitor import get_switch_monitor
from ..utils import logger
@ -23,14 +26,11 @@ from src.backend.app.api.database import SessionLocal
from ..services.network_visualizer import NetworkVisualizer
from ..services.config_validator import ConfigValidator
from ..services.report_generator import ReportGenerator
from fastapi.responses import JSONResponse
router = APIRouter(prefix="", tags=["API"])
scanner = NetworkScanner()
@router.get("/", include_in_schema=False)
async def root():
return {
@ -42,16 +42,18 @@ async def root():
"/apply_config",
"/scan_network",
"/list_devices",
"/batch_apply_config"
"/batch_apply_config",
"/traffic/switch/current",
"/traffic/switch/history"
]
}
@router.get("/favicon.ico", include_in_schema=False)
async def favicon():
return Response(status_code=204)
class BatchConfigRequest(BaseModel):
config: dict
switch_ips: List[str]
@ -59,6 +61,7 @@ class BatchConfigRequest(BaseModel):
password: str = None
timeout: int = None
@router.post("/batch_apply_config")
async def batch_apply_config(request: BatchConfigRequest):
results = {}
@ -67,22 +70,22 @@ async def batch_apply_config(request: BatchConfigRequest):
configurator = SwitchConfigurator(
username=request.username,
password=request.password,
timeout=request.timeout )
timeout=request.timeout)
results[ip] = await configurator.apply_config(ip, request.config)
except Exception as e:
results[ip] = str(e)
return {"results": results}
@router.get("/test")
async def test_endpoint():
return {"message": "Hello World"}
@router.get("/scan_network", summary="扫描网络中的交换机")
async def scan_network(subnet: str = "192.168.1.0/24"):
try:
devices = await scanner.scan_subnet(subnet)
devices = await asyncio.to_thread(scanner.scan_subnet, subnet)
return {
"success": True,
"devices": devices,
@ -91,14 +94,18 @@ async def scan_network(subnet: str = "192.168.1.0/24"):
except Exception as e:
raise HTTPException(500, f"扫描失败: {str(e)}")
@router.get("/list_devices", summary="列出已发现的交换机")
async def list_devices():
return {
"devices": await scanner.load_cached_devices()
"devices": await asyncio.to_thread(scanner.load_cached_devices)
}
class CommandRequest(BaseModel):
command: str
vendor: str = "huawei"
class ConfigRequest(BaseModel):
config: dict
@ -106,15 +113,15 @@ class ConfigRequest(BaseModel):
username: str = None
password: str = None
timeout: int = None
vendor: str = "huawei"
@router.post("/parse_command", response_model=dict)
async def parse_command(request: CommandRequest):
"""
解析中文命令并返回JSON配置
"""
"""解析中文命令并返回JSON配置"""
try:
ai_service = AIService(settings.SILICONFLOW_API_KEY, settings.SILICONFLOW_API_URL)
config = await ai_service.parse_command(request.command)
config = await ai_service.parse_command(request.command, request.vendor)
return {"success": True, "config": config}
except Exception as e:
raise HTTPException(
@ -122,16 +129,16 @@ async def parse_command(request: CommandRequest):
detail=f"Failed to parse command: {str(e)}"
)
@router.post("/apply_config", response_model=dict)
async def apply_config(request: ConfigRequest):
"""
应用配置到交换机弃用
"""
"""应用配置到交换机"""
try:
configurator = SwitchConfigurator(
username=request.username,
password=request.password,
timeout=request.timeout
timeout=request.timeout,
vendor=request.vendor
)
result = await configurator.safe_apply(request.switch_ip, request.config)
return {"success": True, "result": result}
@ -165,14 +172,10 @@ class CLICommandRequest(BaseModel):
return [cmd for cmd in self.commands
if not (cmd.startswith("!username=") or cmd.startswith("!password="))]
@router.post("/execute_cli_commands", response_model=dict)
async def execute_cli_commands(request: CLICommandRequest):
"""
执行前端生成的CLI命令
支持在commands中嵌入凭据:
!username=admin
!password=cisco123
"""
"""执行前端生成的CLI命令"""
try:
username, password = request.extract_credentials()
clean_commands = request.get_clean_commands()
@ -196,27 +199,32 @@ async def execute_cli_commands(request: CLICommandRequest):
except Exception as e:
raise HTTPException(500, detail=str(e))
@router.get("/traffic/interfaces", summary="获取所有网络接口")
async def get_network_interfaces():
return {
"interfaces": traffic_monitor.get_interfaces()
"interfaces": await asyncio.to_thread(traffic_monitor.get_interfaces)
}
@router.get("/traffic/current", summary="获取当前流量数据")
async def get_current_traffic(interface: str = None):
return traffic_monitor.get_current_traffic(interface)
return await asyncio.to_thread(traffic_monitor.get_current_traffic, interface)
@router.get("/traffic/history", summary="获取流量历史数据")
async def get_traffic_history(interface: str = None, limit: int = 100):
history = traffic_monitor.get_traffic_history(interface)
history = await asyncio.to_thread(traffic_monitor.get_traffic_history, interface)
return {
"sent": history["sent"][-limit:],
"recv": history["recv"][-limit:],
"time": [t.isoformat() for t in history["time"]][-limit:]
}
@router.get("/traffic/records", summary="获取流量记录")
async def get_traffic_records(interface: str = None, limit: int = 100):
def sync_get_records():
with SessionLocal() as session:
query = session.query(TrafficRecord)
if interface:
@ -224,42 +232,23 @@ async def get_traffic_records(interface: str = None, limit: int = 100):
records = query.order_by(TrafficRecord.timestamp.desc()).limit(limit).all()
return [record.to_dict() for record in records]
return await asyncio.to_thread(sync_get_records)
@router.websocket("/ws/traffic")
async def websocket_traffic(websocket: WebSocket):
"""实时流量WebSocket"""
await websocket.accept()
try:
while True:
traffic_data = traffic_monitor.get_current_traffic()
traffic_data = await asyncio.to_thread(traffic_monitor.get_current_traffic)
await websocket.send_json(traffic_data)
await asyncio.sleep(1)
except WebSocketDisconnect:
print("客户端断开连接")
@router.get("/", include_in_schema=False)
async def root():
return {
"message": "欢迎使用AI交换机配置系统",
"docs": f"{settings.API_PREFIX}/docs",
"redoc": f"{settings.API_PREFIX}/redoc",
"endpoints": [
"/parse_command",
"/apply_config",
"/scan_network",
"/list_devices",
"/batch_apply_config",
"/traffic/switch/current",
"/traffic/switch/history"
]
}
@router.get("/traffic/switch/interfaces", summary="获取交换机的网络接口")
async def get_switch_interfaces(switch_ip: str):
"""获取指定交换机的所有接口"""
try:
monitor = get_switch_monitor(switch_ip)
interfaces = list(monitor.interface_oids.keys())
@ -272,33 +261,11 @@ async def get_switch_interfaces(switch_ip: str):
raise HTTPException(500, f"获取接口失败: {str(e)}")
@router.get("/traffic/switch/current", summary="获取交换机的当前流量数据")
async def get_switch_current_traffic(switch_ip: str, interface: str = None):
"""获取交换机的当前流量数据"""
try:
monitor = get_switch_monitor(switch_ip)
if not interface:
traffic_data = {}
for iface in monitor.interface_oids:
traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface)
return {
"switch_ip": switch_ip,
"traffic": traffic_data
}
return await get_interface_current_traffic(switch_ip, interface)
except Exception as e:
logger.error(f"获取交换机流量失败: {str(e)}")
raise HTTPException(500, f"获取流量失败: {str(e)}")
async def get_interface_current_traffic(switch_ip: str, interface: str) -> dict:
"""获取指定交换机接口的当前流量数据"""
try:
def sync_get_record():
with SessionLocal() as session:
record = session.query(SwitchTrafficRecord).filter(
SwitchTrafficRecord.switch_ip == switch_ip,
SwitchTrafficRecord.interface == interface
@ -322,26 +289,44 @@ async def get_interface_current_traffic(switch_ip: str, interface: str) -> dict:
"bytes_in": record.bytes_in,
"bytes_out": record.bytes_out
}
return await asyncio.to_thread(sync_get_record)
except Exception as e:
logger.error(f"获取接口流量失败: {str(e)}")
raise HTTPException(500, f"获取接口流量失败: {str(e)}")
@router.get("/traffic/switch/history", summary="获取交换机的流量历史数据")
async def get_switch_traffic_history(switch_ip: str, interface: str = None, minutes: int = 10):
"""获取交换机的流量历史数据"""
@router.get("/traffic/switch/current", summary="获取交换机的当前流量数据")
async def get_switch_current_traffic(switch_ip: str, interface: str = None):
try:
monitor = get_switch_monitor(switch_ip)
if not interface:
traffic_data = {}
for iface in monitor.interface_oids:
traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface)
return {
"switch_ip": switch_ip,
"traffic": traffic_data
}
return await get_interface_current_traffic(switch_ip, interface)
except Exception as e:
logger.error(f"获取交换机流量失败: {str(e)}")
raise HTTPException(500, f"获取流量失败: {str(e)}")
@router.get("/traffic/switch/history", summary="获取交换机的流量历史数据")
async def get_switch_traffic_history(switch_ip: str, interface: str = None, minutes: int = 10):
try:
monitor = get_switch_monitor(switch_ip)
if not interface:
return {
"switch_ip": switch_ip,
"history": monitor.get_traffic_history()
"history": await asyncio.to_thread(monitor.get_traffic_history)
}
def sync_get_history():
with SessionLocal() as session:
time_threshold = datetime.now() - timedelta(minutes=minutes)
records = session.query(SwitchTrafficRecord).filter(
SwitchTrafficRecord.switch_ip == switch_ip,
SwitchTrafficRecord.interface == interface,
@ -353,7 +338,9 @@ async def get_switch_traffic_history(switch_ip: str, interface: str = None, minu
"out": [record.rate_out for record in records],
"time": [record.timestamp.isoformat() for record in records]
}
return history_data
history_data = await asyncio.to_thread(sync_get_history)
return {
"switch_ip": switch_ip,
"interface": interface,
@ -366,11 +353,9 @@ async def get_switch_traffic_history(switch_ip: str, interface: str = None, minu
@router.websocket("/ws/traffic/switch")
async def websocket_switch_traffic(websocket: WebSocket, switch_ip: str, interface: str = None):
"""交换机实时流量WebSocket"""
await websocket.accept()
try:
monitor = get_switch_monitor(switch_ip)
while True:
if interface:
traffic_data = await get_interface_current_traffic(switch_ip, interface)
@ -379,12 +364,10 @@ async def websocket_switch_traffic(websocket: WebSocket, switch_ip: str, interfa
traffic_data = {}
for iface in monitor.interface_oids:
traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface)
await websocket.send_json({
"switch_ip": switch_ip,
"traffic": traffic_data
})
await asyncio.sleep(1)
except WebSocketDisconnect:
logger.info(f"客户端断开交换机流量连接: {switch_ip}")
@ -392,17 +375,17 @@ async def websocket_switch_traffic(websocket: WebSocket, switch_ip: str, interfa
logger.error(f"交换机流量WebSocket错误: {str(e)}")
await websocket.close(code=1011, reason=str(e))
@router.get("/traffic/switch/plot", response_class=HTMLResponse, summary="交换机流量可视化")
async def plot_switch_traffic(switch_ip: str, interface: str, minutes: int = 10):
"""生成交换机流量图表"""
try:
history = await get_switch_traffic_history(switch_ip, interface, minutes)
history_data = history["history"]
time_points = [datetime.fromisoformat(t) for t in history_data["time"]]
in_rates = history_data["in"]
out_rates = history_data["out"]
def generate_plot():
plt.figure(figsize=(12, 6))
plt.plot(time_points, in_rates, label="流入流量 (B/s)")
plt.plot(time_points, out_rates, label="流出流量 (B/s)")
@ -418,7 +401,9 @@ async def plot_switch_traffic(switch_ip: str, interface: str, minutes: int = 10)
buf.seek(0)
image_base64 = base64.b64encode(buf.read()).decode("utf-8")
plt.close()
return image_base64
image_base64 = await asyncio.to_thread(generate_plot)
return f"""
<html>
<head>
@ -445,15 +430,14 @@ async def plot_switch_traffic(switch_ip: str, interface: str, minutes: int = 10)
@router.get("/network_adapters", summary="获取网络适配器网段")
async def get_network_adapters():
try:
def sync_get_adapters():
net_if_addrs = psutil.net_if_addrs()
networks = []
for interface, addrs in net_if_addrs.items():
for addr in addrs:
if addr.family == socket.AF_INET:
ip = addr.address
netmask = addr.netmask
network = ipaddress.IPv4Network(f"{ip}/{netmask}", strict=False)
networks.append({
"adapter": interface,
@ -461,9 +445,10 @@ async def get_network_adapters():
"ip": ip,
"subnet_mask": netmask
})
return networks
networks = await asyncio.to_thread(sync_get_adapters)
return {"networks": networks}
except Exception as e:
return {"error": f"获取网络适配器信息失败: {str(e)}"}
@ -471,14 +456,14 @@ async def get_network_adapters():
visualizer = NetworkVisualizer()
report_gen = ReportGenerator()
@router.get("/topology/visualize", response_class=HTMLResponse)
async def visualize_topology():
"""获取网络拓扑可视化图"""
try:
devices = await list_devices() # 复用现有的设备列表接口
visualizer.update_topology(devices["devices"])
image_data = visualizer.generate_topology_image()
devices = await list_devices()
await asyncio.to_thread(visualizer.update_topology, devices["devices"])
image_data = await asyncio.to_thread(visualizer.generate_topology_image)
return f"""
<html>
<head><title>Network Topology</title></head>
@ -495,11 +480,12 @@ async def visualize_topology():
@router.post("/config/validate")
async def validate_config(config: dict):
"""验证配置有效性"""
is_valid, errors = ConfigValidator.validate_full_config(config)
is_valid, errors = await asyncio.to_thread(ConfigValidator.validate_full_config, config)
return {
"valid": is_valid,
"errors": errors,
"has_security_risks": len(ConfigValidator.check_security_risks(config.get("commands", []))) > 0
"has_security_risks": len(
await asyncio.to_thread(ConfigValidator.check_security_risks, config.get("commands", []))) > 0
}
@ -507,7 +493,7 @@ async def validate_config(config: dict):
async def get_traffic_report(ip: str, days: int = 1):
"""获取流量分析报告"""
try:
report = report_gen.generate_traffic_report(ip, days)
report = await asyncio.to_thread(report_gen.generate_traffic_report, ip, days)
return JSONResponse(content=report)
except Exception as e:
raise HTTPException(500, detail=str(e))
@ -517,7 +503,7 @@ async def get_traffic_report(ip: str, days: int = 1):
async def get_local_traffic_report(days: int = 1):
"""获取本地网络流量报告"""
try:
report = report_gen.generate_traffic_report(days=days)
report = await asyncio.to_thread(report_gen.generate_traffic_report, days=days)
return JSONResponse(content=report)
except Exception as e:
raise HTTPException(500, detail=str(e))
@ -527,7 +513,7 @@ async def get_local_traffic_report(days: int = 1):
async def get_traffic_heatmap(minutes: int = 10):
"""获取流量热力图数据"""
try:
heatmap = visualizer.get_traffic_heatmap(minutes)
heatmap = await asyncio.to_thread(visualizer.get_traffic_heatmap, minutes)
return {"heatmap": heatmap}
except Exception as e:
raise HTTPException(500, detail=str(e))

View File

@ -1,7 +1,7 @@
from typing import Dict, Any, Coroutine
import httpx
from openai import OpenAI
from openai import AsyncOpenAI
import json
from src.backend.app.utils.exceptions import SiliconFlowAPIException
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
@ -12,28 +12,44 @@ class AIService:
def __init__(self, api_key: str, api_url: str):
self.api_key = api_key
self.api_url = api_url
self.client = OpenAI(
self.client = AsyncOpenAI(
api_key=self.api_key,
base_url=self.api_url,
# timeout=httpx.Timeout(30.0)
)
async def parse_command(self, command: str) -> Any | None:
async def parse_command(self, command: str, vendor: str = "huawei") -> Any | None:
"""
调用硅基流动API解析中文命令
"""
prompt = """
你是一个网络设备配置专家精通各种类型的路由器的配置,请将以下用户的中文命令转换为网络设备配置JSON
但是请注意由于贪婪的人们追求极高的效率所以你必须严格按照 JSON 格式返回数据不要包含任何额外文本或 Markdown 代码块
vendor_prompts = {
"huawei": "华为交换机配置命令",
"cisco": "思科交换机配置命令",
"h3c": "H3C交换机配置命令",
"ruijie": "锐捷交换机配置命令",
"zte": "中兴交换机配置命令"
}
prompt = f"""
你是一个网络设备配置专家精通各种类型的路由器的配置请将以下用户的中文命令转换为{vendor_prompts.get(vendor, '网络设备')}配置JSON
但是请注意由于贪婪的人们追求极高的效率所以你必须严格按照 JSON 格式返回数据不要包含任何额外文本或 Markdown 代码块
返回格式要求
1. 必须包含'type'字段指明配置类型(vlan/interface/acl/route等)
2. 必须包含'commands'字段包含可直接执行的命令列表
3. 其他参数根据配置类型动态添加
4. 不要包含解释性文本步骤说明或注释
5.要包含使用ssh连接交换机后的完整命令包括但不完全包括system-view退出保存等完整操作注意保存还需要输入Y
5. 要包含使用ssh连接交换机后的完整命令包括但不完全包括system-view退出保存等完整操作注意保存还需要输入Y
根据厂商{vendor}的不同命令格式如下
- 华为: system-view quit save Y
- 思科: enable configure terminal exit write memory
- H3C: system-view quit save
- 锐捷: enable configure terminal exit write
- 中兴: enable configure terminal exit write memory
示例命令'创建VLAN 100名称为TEST'
示例返回{"type": "vlan", "vlan_id": 100, "name": "TEST", "commands": ["system-view","vlan 100", "name TEST","quit","quit","save","Y"]}
注意这里生成的commands中需包含登录交换机和保存等所有操作命令我们使ssh连接交换机你不需要给出连接ssh的命令你只需要给出使用ssh连接到交换机后所输入的全部命令并且注意在system-view状态下是不能save的需要再quit到用户视图
华为示例返回{{"type": "vlan", "vlan_id": 100, "name": "TEST", "commands": ["system-view","vlan 100", "name TEST","quit","quit","save","Y"]}}
思科示例返回{{"type": "vlan", "vlan_id": 100, "name": "TEST", "commands": ["enable","configure terminal","vlan 100", "name TEST","exit","exit","write memory"]}}
"""
messages = [

View File

@ -9,7 +9,7 @@ class NetworkScanner:
self.cache_path = Path(cache_path)
self.nm = nmap.PortScanner()
async def scan_subnet(self, subnet: str = "192.168.1.0/24", ports: List[int] = [22, 23, 80]) -> List[Dict]:
def scan_subnet(self, subnet: str = "192.168.1.0/24", ports: List[int] = [22, 23, 80]) -> List[Dict]:
"""扫描指定子网的设备,获取设备信息和开放端口"""
logger.info(f"Scanning subnet: {subnet}")
@ -33,16 +33,16 @@ class NetworkScanner:
except Exception as e:
logger.error(f"Error while scanning subnet: {e}")
await self._save_to_cache(devices)
self._save_to_cache(devices)
return devices
async def _save_to_cache(self, devices: List[Dict]):
def _save_to_cache(self, devices: List[Dict]):
"""保存扫描结果到本地文件"""
with open(self.cache_path, "w") as f:
json.dump(devices, f, indent=2)
logger.info(f"Saved {len(devices)} devices to cache")
async def load_cached_devices(self) -> List[Dict]:
def load_cached_devices(self) -> List[Dict]:
"""从缓存加载设备列表"""
if not self.cache_path.exists():
return []