解决异步问题

This commit is contained in:
Jerry 2025-08-12 17:15:34 +08:00
parent 29f016bdab
commit 3260af32fc
5 changed files with 326 additions and 85 deletions

View File

@ -10,7 +10,6 @@ src/backend/
│ ├── __init__.py # 创建 Flask 应用实例
│ ├── api/ # API 路由模块
│ │ ├—── __init__.py # 注册 API 蓝图
│ │ ├── command_parser.py # /api/parse_command 接口
│ │ └── network_config.py # /api/apply_config 接口
│ └── services/ # 核心服务逻辑
│ └── ai_services.py # 调用外部 AI 服务生成配置

View File

@ -1,14 +0,0 @@
from typing import Dict, Any
from src.backend.app.services.ai_services import AIService
from src.backend.config import settings
class CommandParser:
def __init__(self):
self.ai_service = AIService(settings.SILICONFLOW_API_KEY, settings.SILICONFLOW_API_URL)
async def parse(self, command: str) -> Dict[str, Any]:
"""
解析中文命令并返回配置
"""
return await self.ai_service.parse_command(command)

View File

@ -1,19 +1,30 @@
import socket
from fastapi import (APIRouter, HTTPException, Response)
from datetime import datetime, timedelta
from fastapi import (APIRouter, HTTPException, Response, WebSocket, WebSocketDisconnect)
from typing import List
from pydantic import BaseModel
import asyncio
from fastapi.responses import HTMLResponse, JSONResponse
import matplotlib.pyplot as plt
import io
import base64
import psutil
import ipaddress
from ..services.switch_traffic_monitor import get_switch_monitor
from ..utils import logger
from ...app.services.ai_services import AIService
from ...app.api.network_config import SwitchConfigurator
from ...config import settings
from ..services.network_scanner import NetworkScanner
from ...app.services.traffic_monitor import traffic_monitor
from ...app.models.traffic_models import TrafficRecord, SwitchTrafficRecord
from src.backend.app.api.database import SessionLocal
router = APIRouter(prefix="", tags=["API"])
scanner = NetworkScanner()
@router.get("/", include_in_schema=False)
async def root():
return {
@ -25,16 +36,18 @@ async def root():
"/apply_config",
"/scan_network",
"/list_devices",
"/batch_apply_config"
"/batch_apply_config",
"/traffic/switch/current",
"/traffic/switch/history"
]
}
@router.get("/favicon.ico", include_in_schema=False)
async def favicon():
return Response(status_code=204)
class BatchConfigRequest(BaseModel):
config: dict
switch_ips: List[str]
@ -42,14 +55,31 @@ class BatchConfigRequest(BaseModel):
password: str = None
timeout: int = None
@router.post("/batch_apply_config")
async def batch_apply_config(request: BatchConfigRequest):
results = {}
for ip in request.switch_ips:
try:
configurator = SwitchConfigurator(
username=request.username,
password=request.password,
timeout=request.timeout)
results[ip] = await configurator.apply_config(ip, request.config)
except Exception as e:
results[ip] = str(e)
return {"results": results}
@router.get("/test")
async def test_endpoint():
return {"message": "Hello World"}
@router.get("/scan_network", summary="扫描网络中的交换机")
async def scan_network(subnet: str = "192.168.1.0/24"):
try:
devices = await scanner.scan_subnet(subnet)
devices = await asyncio.to_thread(scanner.scan_subnet, subnet)
return {
"success": True,
"devices": devices,
@ -58,14 +88,18 @@ async def scan_network(subnet: str = "192.168.1.0/24"):
except Exception as e:
raise HTTPException(500, f"扫描失败: {str(e)}")
@router.get("/list_devices", summary="列出已发现的交换机")
async def list_devices():
return {
"devices": await scanner.load_cached_devices()
"devices": await asyncio.to_thread(scanner.load_cached_devices)
}
class CommandRequest(BaseModel):
command: str
vendor: str = "huawei"
class ConfigRequest(BaseModel):
config: dict
@ -73,15 +107,15 @@ class ConfigRequest(BaseModel):
username: str = None
password: str = None
timeout: int = None
vendor: str = "huawei"
@router.post("/parse_command", response_model=dict)
async def parse_command(request: CommandRequest):
"""
解析中文命令并返回JSON配置
"""
"""解析中文命令并返回JSON配置"""
try:
ai_service = AIService(settings.SILICONFLOW_API_KEY, settings.SILICONFLOW_API_URL)
config = await ai_service.parse_command(request.command)
config = await ai_service.parse_command(request.command, request.vendor)
return {"success": True, "config": config}
except Exception as e:
raise HTTPException(
@ -89,16 +123,16 @@ async def parse_command(request: CommandRequest):
detail=f"Failed to parse command: {str(e)}"
)
@router.post("/apply_config", response_model=dict)
async def apply_config(request: ConfigRequest):
"""
应用配置到交换机弃用
"""
"""应用配置到交换机"""
try:
configurator = SwitchConfigurator(
username=request.username,
password=request.password,
timeout=request.timeout
timeout=request.timeout,
vendor=request.vendor
)
result = await configurator.safe_apply(request.switch_ip, request.config)
return {"success": True, "result": result}
@ -132,14 +166,10 @@ class CLICommandRequest(BaseModel):
return [cmd for cmd in self.commands
if not (cmd.startswith("!username=") or cmd.startswith("!password="))]
@router.post("/execute_cli_commands", response_model=dict)
async def execute_cli_commands(request: CLICommandRequest):
"""
执行前端生成的CLI命令
支持在commands中嵌入凭据:
!username=admin
!password=cisco123
"""
"""执行前端生成的CLI命令"""
try:
username, password = request.extract_credentials()
clean_commands = request.get_clean_commands()
@ -163,34 +193,245 @@ async def execute_cli_commands(request: CLICommandRequest):
except Exception as e:
raise HTTPException(500, detail=str(e))
@router.get("/", include_in_schema=False)
async def root():
@router.get("/traffic/interfaces", summary="获取所有网络接口")
async def get_network_interfaces():
return {
"message": "欢迎使用AI交换机配置系统",
"docs": f"{settings.API_PREFIX}/docs",
"redoc": f"{settings.API_PREFIX}/redoc",
"endpoints": [
"/parse_command",
"/apply_config",
"/scan_network",
"/list_devices",
"/batch_apply_config",
"/traffic/switch/current",
"/traffic/switch/history"
]
"interfaces": await asyncio.to_thread(traffic_monitor.get_interfaces)
}
@router.get("/traffic/current", summary="获取当前流量数据")
async def get_current_traffic(interface: str = None):
return await asyncio.to_thread(traffic_monitor.get_current_traffic, interface)
@router.get("/traffic/history", summary="获取流量历史数据")
async def get_traffic_history(interface: str = None, limit: int = 100):
history = await asyncio.to_thread(traffic_monitor.get_traffic_history, interface)
return {
"sent": history["sent"][-limit:],
"recv": history["recv"][-limit:],
"time": [t.isoformat() for t in history["time"]][-limit:]
}
@router.get("/traffic/records", summary="获取流量记录")
async def get_traffic_records(interface: str = None, limit: int = 100):
def sync_get_records():
with SessionLocal() as session:
query = session.query(TrafficRecord)
if interface:
query = query.filter(TrafficRecord.interface == interface)
records = query.order_by(TrafficRecord.timestamp.desc()).limit(limit).all()
return [record.to_dict() for record in records]
return await asyncio.to_thread(sync_get_records)
@router.websocket("/ws/traffic")
async def websocket_traffic(websocket: WebSocket):
await websocket.accept()
try:
while True:
traffic_data = await asyncio.to_thread(traffic_monitor.get_current_traffic)
await websocket.send_json(traffic_data)
await asyncio.sleep(1)
except WebSocketDisconnect:
print("客户端断开连接")
@router.get("/traffic/switch/interfaces", summary="获取交换机的网络接口")
async def get_switch_interfaces(switch_ip: str):
try:
monitor = get_switch_monitor(switch_ip)
interfaces = list(monitor.interface_oids.keys())
return {
"switch_ip": switch_ip,
"interfaces": interfaces
}
except Exception as e:
logger.error(f"获取交换机接口失败: {str(e)}")
raise HTTPException(500, f"获取接口失败: {str(e)}")
async def get_interface_current_traffic(switch_ip: str, interface: str) -> dict:
"""获取指定交换机接口的当前流量数据"""
try:
def sync_get_record():
with SessionLocal() as session:
record = session.query(SwitchTrafficRecord).filter(
SwitchTrafficRecord.switch_ip == switch_ip,
SwitchTrafficRecord.interface == interface
).order_by(SwitchTrafficRecord.timestamp.desc()).first()
if not record:
return {
"switch_ip": switch_ip,
"interface": interface,
"rate_in": 0.0,
"rate_out": 0.0,
"bytes_in": 0,
"bytes_out": 0
}
return {
"switch_ip": switch_ip,
"interface": interface,
"rate_in": record.rate_in,
"rate_out": record.rate_out,
"bytes_in": record.bytes_in,
"bytes_out": record.bytes_out
}
return await asyncio.to_thread(sync_get_record)
except Exception as e:
logger.error(f"获取接口流量失败: {str(e)}")
raise HTTPException(500, f"获取接口流量失败: {str(e)}")
@router.get("/traffic/switch/current", summary="获取交换机的当前流量数据")
async def get_switch_current_traffic(switch_ip: str, interface: str = None):
try:
monitor = get_switch_monitor(switch_ip)
if not interface:
traffic_data = {}
for iface in monitor.interface_oids:
traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface)
return {
"switch_ip": switch_ip,
"traffic": traffic_data
}
return await get_interface_current_traffic(switch_ip, interface)
except Exception as e:
logger.error(f"获取交换机流量失败: {str(e)}")
raise HTTPException(500, f"获取流量失败: {str(e)}")
@router.get("/traffic/switch/history", summary="获取交换机的流量历史数据")
async def get_switch_traffic_history(switch_ip: str, interface: str = None, minutes: int = 10):
try:
monitor = get_switch_monitor(switch_ip)
if not interface:
return {
"switch_ip": switch_ip,
"history": await asyncio.to_thread(monitor.get_traffic_history)
}
def sync_get_history():
with SessionLocal() as session:
time_threshold = datetime.now() - timedelta(minutes=minutes)
records = session.query(SwitchTrafficRecord).filter(
SwitchTrafficRecord.switch_ip == switch_ip,
SwitchTrafficRecord.interface == interface,
SwitchTrafficRecord.timestamp >= time_threshold
).order_by(SwitchTrafficRecord.timestamp.asc()).all()
history_data = {
"in": [record.rate_in for record in records],
"out": [record.rate_out for record in records],
"time": [record.timestamp.isoformat() for record in records]
}
return history_data
history_data = await asyncio.to_thread(sync_get_history)
return {
"switch_ip": switch_ip,
"interface": interface,
"history": history_data
}
except Exception as e:
logger.error(f"获取历史流量失败: {str(e)}")
raise HTTPException(500, f"获取历史流量失败: {str(e)}")
@router.websocket("/ws/traffic/switch")
async def websocket_switch_traffic(websocket: WebSocket, switch_ip: str, interface: str = None):
await websocket.accept()
try:
monitor = get_switch_monitor(switch_ip)
while True:
if interface:
traffic_data = await get_interface_current_traffic(switch_ip, interface)
await websocket.send_json(traffic_data)
else:
traffic_data = {}
for iface in monitor.interface_oids:
traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface)
await websocket.send_json({
"switch_ip": switch_ip,
"traffic": traffic_data
})
await asyncio.sleep(1)
except WebSocketDisconnect:
logger.info(f"客户端断开交换机流量连接: {switch_ip}")
except Exception as e:
logger.error(f"交换机流量WebSocket错误: {str(e)}")
await websocket.close(code=1011, reason=str(e))
@router.get("/traffic/switch/plot", response_class=HTMLResponse, summary="交换机流量可视化")
async def plot_switch_traffic(switch_ip: str, interface: str, minutes: int = 10):
try:
history = await get_switch_traffic_history(switch_ip, interface, minutes)
history_data = history["history"]
time_points = [datetime.fromisoformat(t) for t in history_data["time"]]
in_rates = history_data["in"]
out_rates = history_data["out"]
def generate_plot():
plt.figure(figsize=(12, 6))
plt.plot(time_points, in_rates, label="流入流量 (B/s)")
plt.plot(time_points, out_rates, label="流出流量 (B/s)")
plt.title(f"交换机 {switch_ip} 接口 {interface} 流量监控 - 最近 {minutes} 分钟")
plt.xlabel("时间")
plt.ylabel("流量 (字节/秒)")
plt.legend()
plt.grid(True)
plt.xticks(rotation=45)
plt.tight_layout()
buf = io.BytesIO()
plt.savefig(buf, format="png")
buf.seek(0)
image_base64 = base64.b64encode(buf.read()).decode("utf-8")
plt.close()
return image_base64
image_base64 = await asyncio.to_thread(generate_plot)
return f"""
<html>
<head>
<title>交换机流量监控</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 20px; }}
.container {{ max-width: 900px; margin: 0 auto; }}
</style>
</head>
<body>
<div class="container">
<h1>交换机 {switch_ip} 接口 {interface} 流量监控</h1>
<img src="data:image/png;base64,{image_base64}" alt="流量图表">
<p>更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
</div>
</body>
</html>
"""
except Exception as e:
logger.error(f"生成流量图表失败: {str(e)}")
return HTMLResponse(content=f"<h1>错误</h1><p>{str(e)}</p>", status_code=500)
@router.get("/network_adapters", summary="获取网络适配器网段")
async def get_network_adapters():
try:
def sync_get_adapters():
net_if_addrs = psutil.net_if_addrs()
networks = []
for interface, addrs in net_if_addrs.items():
for addr in addrs:
if addr.family == socket.AF_INET:
ip = addr.address
netmask = addr.netmask
network = ipaddress.IPv4Network(f"{ip}/{netmask}", strict=False)
networks.append({
"adapter": interface,
@ -198,8 +439,9 @@ async def get_network_adapters():
"ip": ip,
"subnet_mask": netmask
})
return networks
networks = await asyncio.to_thread(sync_get_adapters)
return {"networks": networks}
except Exception as e:
return {"error": f"获取网络适配器信息失败: {str(e)}"}

View File

@ -1,7 +1,5 @@
from typing import Dict, Any, Coroutine
import httpx
from openai import OpenAI
from typing import Any
from openai import AsyncOpenAI
import json
from src.backend.app.utils.exceptions import SiliconFlowAPIException
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
@ -12,28 +10,44 @@ class AIService:
def __init__(self, api_key: str, api_url: str):
self.api_key = api_key
self.api_url = api_url
self.client = OpenAI(
self.client = AsyncOpenAI(
api_key=self.api_key,
base_url=self.api_url,
# timeout=httpx.Timeout(30.0)
)
async def parse_command(self, command: str) -> Any | None:
async def parse_command(self, command: str, vendor: str = "huawei") -> Any | None:
"""
调用硅基流动API解析中文命令
"""
prompt = """
你是一个网络设备配置专家精通各种类型的路由器的配置,请将以下用户的中文命令转换为网络设备配置JSON
但是请注意由于贪婪的人们追求极高的效率所以你必须严格按照 JSON 格式返回数据不要包含任何额外文本或 Markdown 代码块
vendor_prompts = {
"huawei": "华为交换机配置命令",
"cisco": "思科交换机配置命令",
"h3c": "H3C交换机配置命令",
"ruijie": "锐捷交换机配置命令",
"zte": "中兴交换机配置命令"
}
prompt = f"""
你是一个网络设备配置专家精通各种类型的路由器的配置请将以下用户的中文命令转换为{vendor_prompts.get(vendor, '网络设备')}配置JSON
但是请注意由于贪婪的人们追求极高的效率所以你必须严格按照 JSON 格式返回数据不要包含任何额外文本或 Markdown 代码块
返回格式要求
1. 必须包含'type'字段指明配置类型(vlan/interface/acl/route等)
2. 必须包含'commands'字段包含可直接执行的命令列表
3. 其他参数根据配置类型动态添加
4. 不要包含解释性文本步骤说明或注释
5. 要包含使用ssh连接交换机后的完整命令包括但不完全包括system-view退出保存等完整操作注意保存还需要输入Y
根据厂商{vendor}的不同命令格式如下
- 华为: system-view quit save Y
- 思科: enable configure terminal exit write memory
- H3C: system-view quit save
- 锐捷: enable configure terminal exit write
- 中兴: enable configure terminal exit write memory
示例命令'创建VLAN 100名称为TEST'
示例返回{"type": "vlan", "vlan_id": 100, "name": "TEST", "commands": ["system-view","vlan 100", "name TEST","quit","quit","save","Y"]}
注意这里生成的commands中需包含登录交换机和保存等所有操作命令我们使ssh连接交换机你不需要给出连接ssh的命令你只需要给出使用ssh连接到交换机后所输入的全部命令并且注意在system-view状态下是不能save的需要再quit到用户视图
华为示例返回{{"type": "vlan", "vlan_id": 100, "name": "TEST", "commands": ["system-view","vlan 100", "name TEST","quit","quit","save","Y"]}}
思科示例返回{{"type": "vlan", "vlan_id": 100, "name": "TEST", "commands": ["enable","configure terminal","vlan 100", "name TEST","exit","exit","write memory"]}}
"""
messages = [
@ -42,7 +56,7 @@ class AIService:
]
try:
response = self.client.chat.completions.create(
response = await self.client.chat.completions.create(
model="deepseek-ai/DeepSeek-V3",
messages=messages,
temperature=0.3,

View File

@ -9,13 +9,13 @@ class NetworkScanner:
self.cache_path = Path(cache_path)
self.nm = nmap.PortScanner()
async def scan_subnet(self, subnet: str = "192.168.1.0/24", ports: List[int] = [22, 23, 80]) -> List[Dict]:
def scan_subnet(self, subnet: str = "192.168.1.0/24", ports: List[int] = [22, 23, 80]) -> List[Dict]:
"""扫描指定子网的设备,获取设备信息和开放端口"""
logger.info(f"Scanning subnet: {subnet}")
devices = []
try:
await self.nm.scan(hosts=subnet, arguments=f'-p {",".join(map(str, ports))}')
self.nm.scan(hosts=subnet, arguments=f'-p {",".join(map(str, ports))}')
for host in self.nm.all_hosts():
ip = host
mac = self.nm[host]['addresses'].get('mac', 'N/A')
@ -33,16 +33,16 @@ class NetworkScanner:
except Exception as e:
logger.error(f"Error while scanning subnet: {e}")
await self._save_to_cache(devices)
self._save_to_cache(devices)
return devices
async def _save_to_cache(self, devices: List[Dict]):
def _save_to_cache(self, devices: List[Dict]):
"""保存扫描结果到本地文件"""
with open(self.cache_path, "w") as f:
json.dump(devices, f, indent=2)
logger.info(f"Saved {len(devices)} devices to cache")
async def load_cached_devices(self) -> List[Dict]:
def load_cached_devices(self) -> List[Dict]:
"""从缓存加载设备列表"""
if not self.cache_path.exists():
return []