解决异步问题

This commit is contained in:
Jerry 2025-08-12 17:15:34 +08:00
parent 29f016bdab
commit 3260af32fc
5 changed files with 326 additions and 85 deletions

View File

@ -10,7 +10,6 @@ src/backend/
│ ├── __init__.py # 创建 Flask 应用实例 │ ├── __init__.py # 创建 Flask 应用实例
│ ├── api/ # API 路由模块 │ ├── api/ # API 路由模块
│ │ ├—── __init__.py # 注册 API 蓝图 │ │ ├—── __init__.py # 注册 API 蓝图
│ │ ├── command_parser.py # /api/parse_command 接口
│ │ └── network_config.py # /api/apply_config 接口 │ │ └── network_config.py # /api/apply_config 接口
│ └── services/ # 核心服务逻辑 │ └── services/ # 核心服务逻辑
│ └── ai_services.py # 调用外部 AI 服务生成配置 │ └── ai_services.py # 调用外部 AI 服务生成配置

View File

@ -1,14 +0,0 @@
from typing import Dict, Any
from src.backend.app.services.ai_services import AIService
from src.backend.config import settings
class CommandParser:
def __init__(self):
self.ai_service = AIService(settings.SILICONFLOW_API_KEY, settings.SILICONFLOW_API_URL)
async def parse(self, command: str) -> Dict[str, Any]:
"""
解析中文命令并返回配置
"""
return await self.ai_service.parse_command(command)

View File

@ -1,19 +1,30 @@
import socket import socket
from fastapi import (APIRouter, HTTPException, Response) from datetime import datetime, timedelta
from fastapi import (APIRouter, HTTPException, Response, WebSocket, WebSocketDisconnect)
from typing import List from typing import List
from pydantic import BaseModel from pydantic import BaseModel
import asyncio
from fastapi.responses import HTMLResponse, JSONResponse
import matplotlib.pyplot as plt
import io
import base64
import psutil import psutil
import ipaddress import ipaddress
from ..services.switch_traffic_monitor import get_switch_monitor
from ..utils import logger
from ...app.services.ai_services import AIService from ...app.services.ai_services import AIService
from ...app.api.network_config import SwitchConfigurator from ...app.api.network_config import SwitchConfigurator
from ...config import settings from ...config import settings
from ..services.network_scanner import NetworkScanner from ..services.network_scanner import NetworkScanner
from ...app.services.traffic_monitor import traffic_monitor
from ...app.models.traffic_models import TrafficRecord, SwitchTrafficRecord
from src.backend.app.api.database import SessionLocal
router = APIRouter(prefix="", tags=["API"]) router = APIRouter(prefix="", tags=["API"])
scanner = NetworkScanner() scanner = NetworkScanner()
@router.get("/", include_in_schema=False) @router.get("/", include_in_schema=False)
async def root(): async def root():
return { return {
@ -25,16 +36,18 @@ async def root():
"/apply_config", "/apply_config",
"/scan_network", "/scan_network",
"/list_devices", "/list_devices",
"/batch_apply_config" "/batch_apply_config",
"/traffic/switch/current", "/traffic/switch/current",
"/traffic/switch/history" "/traffic/switch/history"
] ]
} }
@router.get("/favicon.ico", include_in_schema=False) @router.get("/favicon.ico", include_in_schema=False)
async def favicon(): async def favicon():
return Response(status_code=204) return Response(status_code=204)
class BatchConfigRequest(BaseModel): class BatchConfigRequest(BaseModel):
config: dict config: dict
switch_ips: List[str] switch_ips: List[str]
@ -42,14 +55,31 @@ class BatchConfigRequest(BaseModel):
password: str = None password: str = None
timeout: int = None timeout: int = None
@router.post("/batch_apply_config")
async def batch_apply_config(request: BatchConfigRequest):
results = {}
for ip in request.switch_ips:
try:
configurator = SwitchConfigurator(
username=request.username,
password=request.password,
timeout=request.timeout)
results[ip] = await configurator.apply_config(ip, request.config)
except Exception as e:
results[ip] = str(e)
return {"results": results}
@router.get("/test") @router.get("/test")
async def test_endpoint(): async def test_endpoint():
return {"message": "Hello World"} return {"message": "Hello World"}
@router.get("/scan_network", summary="扫描网络中的交换机") @router.get("/scan_network", summary="扫描网络中的交换机")
async def scan_network(subnet: str = "192.168.1.0/24"): async def scan_network(subnet: str = "192.168.1.0/24"):
try: try:
devices = await scanner.scan_subnet(subnet) devices = await asyncio.to_thread(scanner.scan_subnet, subnet)
return { return {
"success": True, "success": True,
"devices": devices, "devices": devices,
@ -58,14 +88,18 @@ async def scan_network(subnet: str = "192.168.1.0/24"):
except Exception as e: except Exception as e:
raise HTTPException(500, f"扫描失败: {str(e)}") raise HTTPException(500, f"扫描失败: {str(e)}")
@router.get("/list_devices", summary="列出已发现的交换机") @router.get("/list_devices", summary="列出已发现的交换机")
async def list_devices(): async def list_devices():
return { return {
"devices": await scanner.load_cached_devices() "devices": await asyncio.to_thread(scanner.load_cached_devices)
} }
class CommandRequest(BaseModel): class CommandRequest(BaseModel):
command: str command: str
vendor: str = "huawei"
class ConfigRequest(BaseModel): class ConfigRequest(BaseModel):
config: dict config: dict
@ -73,15 +107,15 @@ class ConfigRequest(BaseModel):
username: str = None username: str = None
password: str = None password: str = None
timeout: int = None timeout: int = None
vendor: str = "huawei"
@router.post("/parse_command", response_model=dict) @router.post("/parse_command", response_model=dict)
async def parse_command(request: CommandRequest): async def parse_command(request: CommandRequest):
""" """解析中文命令并返回JSON配置"""
解析中文命令并返回JSON配置
"""
try: try:
ai_service = AIService(settings.SILICONFLOW_API_KEY, settings.SILICONFLOW_API_URL) ai_service = AIService(settings.SILICONFLOW_API_KEY, settings.SILICONFLOW_API_URL)
config = await ai_service.parse_command(request.command) config = await ai_service.parse_command(request.command, request.vendor)
return {"success": True, "config": config} return {"success": True, "config": config}
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
@ -89,16 +123,16 @@ async def parse_command(request: CommandRequest):
detail=f"Failed to parse command: {str(e)}" detail=f"Failed to parse command: {str(e)}"
) )
@router.post("/apply_config", response_model=dict) @router.post("/apply_config", response_model=dict)
async def apply_config(request: ConfigRequest): async def apply_config(request: ConfigRequest):
""" """应用配置到交换机"""
应用配置到交换机弃用
"""
try: try:
configurator = SwitchConfigurator( configurator = SwitchConfigurator(
username=request.username, username=request.username,
password=request.password, password=request.password,
timeout=request.timeout timeout=request.timeout,
vendor=request.vendor
) )
result = await configurator.safe_apply(request.switch_ip, request.config) result = await configurator.safe_apply(request.switch_ip, request.config)
return {"success": True, "result": result} return {"success": True, "result": result}
@ -132,14 +166,10 @@ class CLICommandRequest(BaseModel):
return [cmd for cmd in self.commands return [cmd for cmd in self.commands
if not (cmd.startswith("!username=") or cmd.startswith("!password="))] if not (cmd.startswith("!username=") or cmd.startswith("!password="))]
@router.post("/execute_cli_commands", response_model=dict) @router.post("/execute_cli_commands", response_model=dict)
async def execute_cli_commands(request: CLICommandRequest): async def execute_cli_commands(request: CLICommandRequest):
""" """执行前端生成的CLI命令"""
执行前端生成的CLI命令
支持在commands中嵌入凭据:
!username=admin
!password=cisco123
"""
try: try:
username, password = request.extract_credentials() username, password = request.extract_credentials()
clean_commands = request.get_clean_commands() clean_commands = request.get_clean_commands()
@ -163,43 +193,255 @@ async def execute_cli_commands(request: CLICommandRequest):
except Exception as e: except Exception as e:
raise HTTPException(500, detail=str(e)) raise HTTPException(500, detail=str(e))
@router.get("/", include_in_schema=False)
async def root(): @router.get("/traffic/interfaces", summary="获取所有网络接口")
async def get_network_interfaces():
return { return {
"message": "欢迎使用AI交换机配置系统", "interfaces": await asyncio.to_thread(traffic_monitor.get_interfaces)
"docs": f"{settings.API_PREFIX}/docs",
"redoc": f"{settings.API_PREFIX}/redoc",
"endpoints": [
"/parse_command",
"/apply_config",
"/scan_network",
"/list_devices",
"/batch_apply_config",
"/traffic/switch/current",
"/traffic/switch/history"
]
} }
@router.get("/traffic/current", summary="获取当前流量数据")
async def get_current_traffic(interface: str = None):
return await asyncio.to_thread(traffic_monitor.get_current_traffic, interface)
@router.get("/traffic/history", summary="获取流量历史数据")
async def get_traffic_history(interface: str = None, limit: int = 100):
history = await asyncio.to_thread(traffic_monitor.get_traffic_history, interface)
return {
"sent": history["sent"][-limit:],
"recv": history["recv"][-limit:],
"time": [t.isoformat() for t in history["time"]][-limit:]
}
@router.get("/traffic/records", summary="获取流量记录")
async def get_traffic_records(interface: str = None, limit: int = 100):
def sync_get_records():
with SessionLocal() as session:
query = session.query(TrafficRecord)
if interface:
query = query.filter(TrafficRecord.interface == interface)
records = query.order_by(TrafficRecord.timestamp.desc()).limit(limit).all()
return [record.to_dict() for record in records]
return await asyncio.to_thread(sync_get_records)
@router.websocket("/ws/traffic")
async def websocket_traffic(websocket: WebSocket):
await websocket.accept()
try:
while True:
traffic_data = await asyncio.to_thread(traffic_monitor.get_current_traffic)
await websocket.send_json(traffic_data)
await asyncio.sleep(1)
except WebSocketDisconnect:
print("客户端断开连接")
@router.get("/traffic/switch/interfaces", summary="获取交换机的网络接口")
async def get_switch_interfaces(switch_ip: str):
try:
monitor = get_switch_monitor(switch_ip)
interfaces = list(monitor.interface_oids.keys())
return {
"switch_ip": switch_ip,
"interfaces": interfaces
}
except Exception as e:
logger.error(f"获取交换机接口失败: {str(e)}")
raise HTTPException(500, f"获取接口失败: {str(e)}")
async def get_interface_current_traffic(switch_ip: str, interface: str) -> dict:
"""获取指定交换机接口的当前流量数据"""
try:
def sync_get_record():
with SessionLocal() as session:
record = session.query(SwitchTrafficRecord).filter(
SwitchTrafficRecord.switch_ip == switch_ip,
SwitchTrafficRecord.interface == interface
).order_by(SwitchTrafficRecord.timestamp.desc()).first()
if not record:
return {
"switch_ip": switch_ip,
"interface": interface,
"rate_in": 0.0,
"rate_out": 0.0,
"bytes_in": 0,
"bytes_out": 0
}
return {
"switch_ip": switch_ip,
"interface": interface,
"rate_in": record.rate_in,
"rate_out": record.rate_out,
"bytes_in": record.bytes_in,
"bytes_out": record.bytes_out
}
return await asyncio.to_thread(sync_get_record)
except Exception as e:
logger.error(f"获取接口流量失败: {str(e)}")
raise HTTPException(500, f"获取接口流量失败: {str(e)}")
@router.get("/traffic/switch/current", summary="获取交换机的当前流量数据")
async def get_switch_current_traffic(switch_ip: str, interface: str = None):
try:
monitor = get_switch_monitor(switch_ip)
if not interface:
traffic_data = {}
for iface in monitor.interface_oids:
traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface)
return {
"switch_ip": switch_ip,
"traffic": traffic_data
}
return await get_interface_current_traffic(switch_ip, interface)
except Exception as e:
logger.error(f"获取交换机流量失败: {str(e)}")
raise HTTPException(500, f"获取流量失败: {str(e)}")
@router.get("/traffic/switch/history", summary="获取交换机的流量历史数据")
async def get_switch_traffic_history(switch_ip: str, interface: str = None, minutes: int = 10):
try:
monitor = get_switch_monitor(switch_ip)
if not interface:
return {
"switch_ip": switch_ip,
"history": await asyncio.to_thread(monitor.get_traffic_history)
}
def sync_get_history():
with SessionLocal() as session:
time_threshold = datetime.now() - timedelta(minutes=minutes)
records = session.query(SwitchTrafficRecord).filter(
SwitchTrafficRecord.switch_ip == switch_ip,
SwitchTrafficRecord.interface == interface,
SwitchTrafficRecord.timestamp >= time_threshold
).order_by(SwitchTrafficRecord.timestamp.asc()).all()
history_data = {
"in": [record.rate_in for record in records],
"out": [record.rate_out for record in records],
"time": [record.timestamp.isoformat() for record in records]
}
return history_data
history_data = await asyncio.to_thread(sync_get_history)
return {
"switch_ip": switch_ip,
"interface": interface,
"history": history_data
}
except Exception as e:
logger.error(f"获取历史流量失败: {str(e)}")
raise HTTPException(500, f"获取历史流量失败: {str(e)}")
@router.websocket("/ws/traffic/switch")
async def websocket_switch_traffic(websocket: WebSocket, switch_ip: str, interface: str = None):
await websocket.accept()
try:
monitor = get_switch_monitor(switch_ip)
while True:
if interface:
traffic_data = await get_interface_current_traffic(switch_ip, interface)
await websocket.send_json(traffic_data)
else:
traffic_data = {}
for iface in monitor.interface_oids:
traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface)
await websocket.send_json({
"switch_ip": switch_ip,
"traffic": traffic_data
})
await asyncio.sleep(1)
except WebSocketDisconnect:
logger.info(f"客户端断开交换机流量连接: {switch_ip}")
except Exception as e:
logger.error(f"交换机流量WebSocket错误: {str(e)}")
await websocket.close(code=1011, reason=str(e))
@router.get("/traffic/switch/plot", response_class=HTMLResponse, summary="交换机流量可视化")
async def plot_switch_traffic(switch_ip: str, interface: str, minutes: int = 10):
try:
history = await get_switch_traffic_history(switch_ip, interface, minutes)
history_data = history["history"]
time_points = [datetime.fromisoformat(t) for t in history_data["time"]]
in_rates = history_data["in"]
out_rates = history_data["out"]
def generate_plot():
plt.figure(figsize=(12, 6))
plt.plot(time_points, in_rates, label="流入流量 (B/s)")
plt.plot(time_points, out_rates, label="流出流量 (B/s)")
plt.title(f"交换机 {switch_ip} 接口 {interface} 流量监控 - 最近 {minutes} 分钟")
plt.xlabel("时间")
plt.ylabel("流量 (字节/秒)")
plt.legend()
plt.grid(True)
plt.xticks(rotation=45)
plt.tight_layout()
buf = io.BytesIO()
plt.savefig(buf, format="png")
buf.seek(0)
image_base64 = base64.b64encode(buf.read()).decode("utf-8")
plt.close()
return image_base64
image_base64 = await asyncio.to_thread(generate_plot)
return f"""
<html>
<head>
<title>交换机流量监控</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 20px; }}
.container {{ max-width: 900px; margin: 0 auto; }}
</style>
</head>
<body>
<div class="container">
<h1>交换机 {switch_ip} 接口 {interface} 流量监控</h1>
<img src="data:image/png;base64,{image_base64}" alt="流量图表">
<p>更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
</div>
</body>
</html>
"""
except Exception as e:
logger.error(f"生成流量图表失败: {str(e)}")
return HTMLResponse(content=f"<h1>错误</h1><p>{str(e)}</p>", status_code=500)
@router.get("/network_adapters", summary="获取网络适配器网段") @router.get("/network_adapters", summary="获取网络适配器网段")
async def get_network_adapters(): async def get_network_adapters():
try: try:
net_if_addrs = psutil.net_if_addrs() def sync_get_adapters():
net_if_addrs = psutil.net_if_addrs()
networks = [] networks = []
for interface, addrs in net_if_addrs.items(): for interface, addrs in net_if_addrs.items():
for addr in addrs: for addr in addrs:
if addr.family == socket.AF_INET: if addr.family == socket.AF_INET:
ip = addr.address ip = addr.address
netmask = addr.netmask netmask = addr.netmask
network = ipaddress.IPv4Network(f"{ip}/{netmask}", strict=False)
network = ipaddress.IPv4Network(f"{ip}/{netmask}", strict=False) networks.append({
networks.append({ "adapter": interface,
"adapter": interface, "network": str(network),
"network": str(network), "ip": ip,
"ip": ip, "subnet_mask": netmask
"subnet_mask": netmask })
}) return networks
networks = await asyncio.to_thread(sync_get_adapters)
return {"networks": networks} return {"networks": networks}
except Exception as e: except Exception as e:
return {"error": f"获取网络适配器信息失败: {str(e)}"} return {"error": f"获取网络适配器信息失败: {str(e)}"}

View File

@ -1,7 +1,5 @@
from typing import Dict, Any, Coroutine from typing import Any
from openai import AsyncOpenAI
import httpx
from openai import OpenAI
import json import json
from src.backend.app.utils.exceptions import SiliconFlowAPIException from src.backend.app.utils.exceptions import SiliconFlowAPIException
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
@ -12,28 +10,44 @@ class AIService:
def __init__(self, api_key: str, api_url: str): def __init__(self, api_key: str, api_url: str):
self.api_key = api_key self.api_key = api_key
self.api_url = api_url self.api_url = api_url
self.client = OpenAI( self.client = AsyncOpenAI(
api_key=self.api_key, api_key=self.api_key,
base_url=self.api_url, base_url=self.api_url,
# timeout=httpx.Timeout(30.0) # timeout=httpx.Timeout(30.0)
) )
async def parse_command(self, command: str) -> Any | None: async def parse_command(self, command: str, vendor: str = "huawei") -> Any | None:
""" """
调用硅基流动API解析中文命令 调用硅基流动API解析中文命令
""" """
prompt = """ vendor_prompts = {
你是一个网络设备配置专家精通各种类型的路由器的配置,请将以下用户的中文命令转换为网络设备配置JSON "huawei": "华为交换机配置命令",
但是请注意由于贪婪的人们追求极高的效率所以你必须严格按照 JSON 格式返回数据不要包含任何额外文本或 Markdown 代码块 "cisco": "思科交换机配置命令",
"h3c": "H3C交换机配置命令",
"ruijie": "锐捷交换机配置命令",
"zte": "中兴交换机配置命令"
}
prompt = f"""
你是一个网络设备配置专家精通各种类型的路由器的配置请将以下用户的中文命令转换为{vendor_prompts.get(vendor, '网络设备')}配置JSON
但是请注意由于贪婪的人们追求极高的效率所以你必须严格按照 JSON 格式返回数据不要包含任何额外文本或 Markdown 代码块
返回格式要求 返回格式要求
1. 必须包含'type'字段指明配置类型(vlan/interface/acl/route等) 1. 必须包含'type'字段指明配置类型(vlan/interface/acl/route等)
2. 必须包含'commands'字段包含可直接执行的命令列表 2. 必须包含'commands'字段包含可直接执行的命令列表
3. 其他参数根据配置类型动态添加 3. 其他参数根据配置类型动态添加
4. 不要包含解释性文本步骤说明或注释 4. 不要包含解释性文本步骤说明或注释
5.要包含使用ssh连接交换机后的完整命令包括但不完全包括system-view退出保存等完整操作注意保存还需要输入Y 5. 要包含使用ssh连接交换机后的完整命令包括但不完全包括system-view退出保存等完整操作注意保存还需要输入Y
根据厂商{vendor}的不同命令格式如下
- 华为: system-view quit save Y
- 思科: enable configure terminal exit write memory
- H3C: system-view quit save
- 锐捷: enable configure terminal exit write
- 中兴: enable configure terminal exit write memory
示例命令'创建VLAN 100名称为TEST' 示例命令'创建VLAN 100名称为TEST'
示例返回{"type": "vlan", "vlan_id": 100, "name": "TEST", "commands": ["system-view","vlan 100", "name TEST","quit","quit","save","Y"]} 华为示例返回{{"type": "vlan", "vlan_id": 100, "name": "TEST", "commands": ["system-view","vlan 100", "name TEST","quit","quit","save","Y"]}}
注意这里生成的commands中需包含登录交换机和保存等所有操作命令我们使ssh连接交换机你不需要给出连接ssh的命令你只需要给出使用ssh连接到交换机后所输入的全部命令并且注意在system-view状态下是不能save的需要再quit到用户视图 思科示例返回{{"type": "vlan", "vlan_id": 100, "name": "TEST", "commands": ["enable","configure terminal","vlan 100", "name TEST","exit","exit","write memory"]}}
""" """
messages = [ messages = [
@ -42,7 +56,7 @@ class AIService:
] ]
try: try:
response = self.client.chat.completions.create( response = await self.client.chat.completions.create(
model="deepseek-ai/DeepSeek-V3", model="deepseek-ai/DeepSeek-V3",
messages=messages, messages=messages,
temperature=0.3, temperature=0.3,

View File

@ -9,13 +9,13 @@ class NetworkScanner:
self.cache_path = Path(cache_path) self.cache_path = Path(cache_path)
self.nm = nmap.PortScanner() self.nm = nmap.PortScanner()
async def scan_subnet(self, subnet: str = "192.168.1.0/24", ports: List[int] = [22, 23, 80]) -> List[Dict]: def scan_subnet(self, subnet: str = "192.168.1.0/24", ports: List[int] = [22, 23, 80]) -> List[Dict]:
"""扫描指定子网的设备,获取设备信息和开放端口""" """扫描指定子网的设备,获取设备信息和开放端口"""
logger.info(f"Scanning subnet: {subnet}") logger.info(f"Scanning subnet: {subnet}")
devices = [] devices = []
try: try:
await self.nm.scan(hosts=subnet, arguments=f'-p {",".join(map(str, ports))}') self.nm.scan(hosts=subnet, arguments=f'-p {",".join(map(str, ports))}')
for host in self.nm.all_hosts(): for host in self.nm.all_hosts():
ip = host ip = host
mac = self.nm[host]['addresses'].get('mac', 'N/A') mac = self.nm[host]['addresses'].get('mac', 'N/A')
@ -33,19 +33,19 @@ class NetworkScanner:
except Exception as e: except Exception as e:
logger.error(f"Error while scanning subnet: {e}") logger.error(f"Error while scanning subnet: {e}")
await self._save_to_cache(devices) self._save_to_cache(devices)
return devices return devices
async def _save_to_cache(self, devices: List[Dict]): def _save_to_cache(self, devices: List[Dict]):
"""保存扫描结果到本地文件""" """保存扫描结果到本地文件"""
with open(self.cache_path, "w") as f: with open(self.cache_path, "w") as f:
json.dump(devices, f, indent=2) json.dump(devices, f, indent=2)
logger.info(f"Saved {len(devices)} devices to cache") logger.info(f"Saved {len(devices)} devices to cache")
async def load_cached_devices(self) -> List[Dict]: def load_cached_devices(self) -> List[Dict]:
"""从缓存加载设备列表""" """从缓存加载设备列表"""
if not self.cache_path.exists(): if not self.cache_path.exists():
return [] return []
with open(self.cache_path) as f: with open(self.cache_path) as f:
return json.load(f) return json.load(f)