diff --git a/src/backend/app/api/endpoints.py b/src/backend/app/api/endpoints.py index 1cf0fca..fa2244c 100644 --- a/src/backend/app/api/endpoints.py +++ b/src/backend/app/api/endpoints.py @@ -1,15 +1,18 @@ +# File: D:\Python work\AI-powered-switches\src\backend\app\api\endpoints.py + import socket from datetime import datetime, timedelta from fastapi import (APIRouter, HTTPException, Response, WebSocket, WebSocketDisconnect) from typing import List from pydantic import BaseModel import asyncio -from fastapi.responses import HTMLResponse +from fastapi.responses import HTMLResponse, JSONResponse import matplotlib.pyplot as plt import io import base64 import psutil import ipaddress +import json from ..services.switch_traffic_monitor import get_switch_monitor from ..utils import logger @@ -23,14 +26,11 @@ from src.backend.app.api.database import SessionLocal from ..services.network_visualizer import NetworkVisualizer from ..services.config_validator import ConfigValidator from ..services.report_generator import ReportGenerator -from fastapi.responses import JSONResponse - - - router = APIRouter(prefix="", tags=["API"]) scanner = NetworkScanner() + @router.get("/", include_in_schema=False) async def root(): return { @@ -42,16 +42,18 @@ async def root(): "/apply_config", "/scan_network", "/list_devices", - "/batch_apply_config" + "/batch_apply_config", "/traffic/switch/current", "/traffic/switch/history" ] } + @router.get("/favicon.ico", include_in_schema=False) async def favicon(): return Response(status_code=204) + class BatchConfigRequest(BaseModel): config: dict switch_ips: List[str] @@ -59,6 +61,7 @@ class BatchConfigRequest(BaseModel): password: str = None timeout: int = None + @router.post("/batch_apply_config") async def batch_apply_config(request: BatchConfigRequest): results = {} @@ -67,22 +70,22 @@ async def batch_apply_config(request: BatchConfigRequest): configurator = SwitchConfigurator( username=request.username, password=request.password, - timeout=request.timeout ) + timeout=request.timeout) results[ip] = await configurator.apply_config(ip, request.config) except Exception as e: results[ip] = str(e) return {"results": results} - @router.get("/test") async def test_endpoint(): return {"message": "Hello World"} + @router.get("/scan_network", summary="扫描网络中的交换机") async def scan_network(subnet: str = "192.168.1.0/24"): try: - devices = await scanner.scan_subnet(subnet) + devices = await asyncio.to_thread(scanner.scan_subnet, subnet) return { "success": True, "devices": devices, @@ -91,14 +94,18 @@ async def scan_network(subnet: str = "192.168.1.0/24"): except Exception as e: raise HTTPException(500, f"扫描失败: {str(e)}") + @router.get("/list_devices", summary="列出已发现的交换机") async def list_devices(): return { - "devices": await scanner.load_cached_devices() + "devices": await asyncio.to_thread(scanner.load_cached_devices) } + class CommandRequest(BaseModel): command: str + vendor: str = "huawei" + class ConfigRequest(BaseModel): config: dict @@ -106,15 +113,15 @@ class ConfigRequest(BaseModel): username: str = None password: str = None timeout: int = None + vendor: str = "huawei" + @router.post("/parse_command", response_model=dict) async def parse_command(request: CommandRequest): - """ - 解析中文命令并返回JSON配置 - """ + """解析中文命令并返回JSON配置""" try: ai_service = AIService(settings.SILICONFLOW_API_KEY, settings.SILICONFLOW_API_URL) - config = await ai_service.parse_command(request.command) + config = await ai_service.parse_command(request.command, request.vendor) return {"success": True, "config": config} except Exception as e: raise HTTPException( @@ -122,16 +129,16 @@ async def parse_command(request: CommandRequest): detail=f"Failed to parse command: {str(e)}" ) + @router.post("/apply_config", response_model=dict) async def apply_config(request: ConfigRequest): - """ - 应用配置到交换机(弃用) - """ + """应用配置到交换机""" try: configurator = SwitchConfigurator( username=request.username, password=request.password, - timeout=request.timeout + timeout=request.timeout, + vendor=request.vendor ) result = await configurator.safe_apply(request.switch_ip, request.config) return {"success": True, "result": result} @@ -165,14 +172,10 @@ class CLICommandRequest(BaseModel): return [cmd for cmd in self.commands if not (cmd.startswith("!username=") or cmd.startswith("!password="))] + @router.post("/execute_cli_commands", response_model=dict) async def execute_cli_commands(request: CLICommandRequest): - """ - 执行前端生成的CLI命令 - 支持在commands中嵌入凭据: - !username=admin - !password=cisco123 - """ + """执行前端生成的CLI命令""" try: username, password = request.extract_credentials() clean_commands = request.get_clean_commands() @@ -196,70 +199,56 @@ async def execute_cli_commands(request: CLICommandRequest): except Exception as e: raise HTTPException(500, detail=str(e)) + @router.get("/traffic/interfaces", summary="获取所有网络接口") async def get_network_interfaces(): return { - "interfaces": traffic_monitor.get_interfaces() + "interfaces": await asyncio.to_thread(traffic_monitor.get_interfaces) } + @router.get("/traffic/current", summary="获取当前流量数据") async def get_current_traffic(interface: str = None): - return traffic_monitor.get_current_traffic(interface) + return await asyncio.to_thread(traffic_monitor.get_current_traffic, interface) + @router.get("/traffic/history", summary="获取流量历史数据") async def get_traffic_history(interface: str = None, limit: int = 100): - history = traffic_monitor.get_traffic_history(interface) + history = await asyncio.to_thread(traffic_monitor.get_traffic_history, interface) return { "sent": history["sent"][-limit:], "recv": history["recv"][-limit:], "time": [t.isoformat() for t in history["time"]][-limit:] } + @router.get("/traffic/records", summary="获取流量记录") async def get_traffic_records(interface: str = None, limit: int = 100): - with SessionLocal() as session: - query = session.query(TrafficRecord) - if interface: - query = query.filter(TrafficRecord.interface == interface) - records = query.order_by(TrafficRecord.timestamp.desc()).limit(limit).all() - return [record.to_dict() for record in records] + def sync_get_records(): + with SessionLocal() as session: + query = session.query(TrafficRecord) + if interface: + query = query.filter(TrafficRecord.interface == interface) + records = query.order_by(TrafficRecord.timestamp.desc()).limit(limit).all() + return [record.to_dict() for record in records] + + return await asyncio.to_thread(sync_get_records) + @router.websocket("/ws/traffic") async def websocket_traffic(websocket: WebSocket): - """实时流量WebSocket""" await websocket.accept() try: while True: - traffic_data = traffic_monitor.get_current_traffic() + traffic_data = await asyncio.to_thread(traffic_monitor.get_current_traffic) await websocket.send_json(traffic_data) await asyncio.sleep(1) except WebSocketDisconnect: print("客户端断开连接") -@router.get("/", include_in_schema=False) -async def root(): - return { - "message": "欢迎使用AI交换机配置系统", - "docs": f"{settings.API_PREFIX}/docs", - "redoc": f"{settings.API_PREFIX}/redoc", - "endpoints": [ - "/parse_command", - "/apply_config", - "/scan_network", - "/list_devices", - "/batch_apply_config", - "/traffic/switch/current", - "/traffic/switch/history" - ] - } - - - - @router.get("/traffic/switch/interfaces", summary="获取交换机的网络接口") async def get_switch_interfaces(switch_ip: str): - """获取指定交换机的所有接口""" try: monitor = get_switch_monitor(switch_ip) interfaces = list(monitor.interface_oids.keys()) @@ -272,13 +261,45 @@ async def get_switch_interfaces(switch_ip: str): raise HTTPException(500, f"获取接口失败: {str(e)}") +async def get_interface_current_traffic(switch_ip: str, interface: str) -> dict: + """获取指定交换机接口的当前流量数据""" + try: + def sync_get_record(): + with SessionLocal() as session: + record = session.query(SwitchTrafficRecord).filter( + SwitchTrafficRecord.switch_ip == switch_ip, + SwitchTrafficRecord.interface == interface + ).order_by(SwitchTrafficRecord.timestamp.desc()).first() + + if not record: + return { + "switch_ip": switch_ip, + "interface": interface, + "rate_in": 0.0, + "rate_out": 0.0, + "bytes_in": 0, + "bytes_out": 0 + } + + return { + "switch_ip": switch_ip, + "interface": interface, + "rate_in": record.rate_in, + "rate_out": record.rate_out, + "bytes_in": record.bytes_in, + "bytes_out": record.bytes_out + } + + return await asyncio.to_thread(sync_get_record) + except Exception as e: + logger.error(f"获取接口流量失败: {str(e)}") + raise HTTPException(500, f"获取接口流量失败: {str(e)}") + + @router.get("/traffic/switch/current", summary="获取交换机的当前流量数据") async def get_switch_current_traffic(switch_ip: str, interface: str = None): - """获取交换机的当前流量数据""" try: monitor = get_switch_monitor(switch_ip) - - if not interface: traffic_data = {} for iface in monitor.interface_oids: @@ -287,78 +308,44 @@ async def get_switch_current_traffic(switch_ip: str, interface: str = None): "switch_ip": switch_ip, "traffic": traffic_data } - return await get_interface_current_traffic(switch_ip, interface) except Exception as e: logger.error(f"获取交换机流量失败: {str(e)}") raise HTTPException(500, f"获取流量失败: {str(e)}") -async def get_interface_current_traffic(switch_ip: str, interface: str) -> dict: - """获取指定交换机接口的当前流量数据""" - try: - with SessionLocal() as session: - - record = session.query(SwitchTrafficRecord).filter( - SwitchTrafficRecord.switch_ip == switch_ip, - SwitchTrafficRecord.interface == interface - ).order_by(SwitchTrafficRecord.timestamp.desc()).first() - - if not record: - return { - "switch_ip": switch_ip, - "interface": interface, - "rate_in": 0.0, - "rate_out": 0.0, - "bytes_in": 0, - "bytes_out": 0 - } - - return { - "switch_ip": switch_ip, - "interface": interface, - "rate_in": record.rate_in, - "rate_out": record.rate_out, - "bytes_in": record.bytes_in, - "bytes_out": record.bytes_out - } - except Exception as e: - logger.error(f"获取接口流量失败: {str(e)}") - raise HTTPException(500, f"获取接口流量失败: {str(e)}") - - @router.get("/traffic/switch/history", summary="获取交换机的流量历史数据") async def get_switch_traffic_history(switch_ip: str, interface: str = None, minutes: int = 10): - """获取交换机的流量历史数据""" try: monitor = get_switch_monitor(switch_ip) - if not interface: return { "switch_ip": switch_ip, - "history": monitor.get_traffic_history() + "history": await asyncio.to_thread(monitor.get_traffic_history) } - with SessionLocal() as session: - time_threshold = datetime.now() - timedelta(minutes=minutes) + def sync_get_history(): + with SessionLocal() as session: + time_threshold = datetime.now() - timedelta(minutes=minutes) + records = session.query(SwitchTrafficRecord).filter( + SwitchTrafficRecord.switch_ip == switch_ip, + SwitchTrafficRecord.interface == interface, + SwitchTrafficRecord.timestamp >= time_threshold + ).order_by(SwitchTrafficRecord.timestamp.asc()).all() - records = session.query(SwitchTrafficRecord).filter( - SwitchTrafficRecord.switch_ip == switch_ip, - SwitchTrafficRecord.interface == interface, - SwitchTrafficRecord.timestamp >= time_threshold - ).order_by(SwitchTrafficRecord.timestamp.asc()).all() + history_data = { + "in": [record.rate_in for record in records], + "out": [record.rate_out for record in records], + "time": [record.timestamp.isoformat() for record in records] + } + return history_data - history_data = { - "in": [record.rate_in for record in records], - "out": [record.rate_out for record in records], - "time": [record.timestamp.isoformat() for record in records] - } - - return { - "switch_ip": switch_ip, - "interface": interface, - "history": history_data - } + history_data = await asyncio.to_thread(sync_get_history) + return { + "switch_ip": switch_ip, + "interface": interface, + "history": history_data + } except Exception as e: logger.error(f"获取历史流量失败: {str(e)}") raise HTTPException(500, f"获取历史流量失败: {str(e)}") @@ -366,11 +353,9 @@ async def get_switch_traffic_history(switch_ip: str, interface: str = None, minu @router.websocket("/ws/traffic/switch") async def websocket_switch_traffic(websocket: WebSocket, switch_ip: str, interface: str = None): - """交换机实时流量WebSocket""" await websocket.accept() try: monitor = get_switch_monitor(switch_ip) - while True: if interface: traffic_data = await get_interface_current_traffic(switch_ip, interface) @@ -379,12 +364,10 @@ async def websocket_switch_traffic(websocket: WebSocket, switch_ip: str, interfa traffic_data = {} for iface in monitor.interface_oids: traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface) - await websocket.send_json({ "switch_ip": switch_ip, "traffic": traffic_data }) - await asyncio.sleep(1) except WebSocketDisconnect: logger.info(f"客户端断开交换机流量连接: {switch_ip}") @@ -392,33 +375,35 @@ async def websocket_switch_traffic(websocket: WebSocket, switch_ip: str, interfa logger.error(f"交换机流量WebSocket错误: {str(e)}") await websocket.close(code=1011, reason=str(e)) + @router.get("/traffic/switch/plot", response_class=HTMLResponse, summary="交换机流量可视化") async def plot_switch_traffic(switch_ip: str, interface: str, minutes: int = 10): - """生成交换机流量图表""" try: history = await get_switch_traffic_history(switch_ip, interface, minutes) history_data = history["history"] - time_points = [datetime.fromisoformat(t) for t in history_data["time"]] in_rates = history_data["in"] out_rates = history_data["out"] - plt.figure(figsize=(12, 6)) - plt.plot(time_points, in_rates, label="流入流量 (B/s)") - plt.plot(time_points, out_rates, label="流出流量 (B/s)") - plt.title(f"交换机 {switch_ip} 接口 {interface} 流量监控 - 最近 {minutes} 分钟") - plt.xlabel("时间") - plt.ylabel("流量 (字节/秒)") - plt.legend() - plt.grid(True) - plt.xticks(rotation=45) - plt.tight_layout() - buf = io.BytesIO() - plt.savefig(buf, format="png") - buf.seek(0) - image_base64 = base64.b64encode(buf.read()).decode("utf-8") - plt.close() + def generate_plot(): + plt.figure(figsize=(12, 6)) + plt.plot(time_points, in_rates, label="流入流量 (B/s)") + plt.plot(time_points, out_rates, label="流出流量 (B/s)") + plt.title(f"交换机 {switch_ip} 接口 {interface} 流量监控 - 最近 {minutes} 分钟") + plt.xlabel("时间") + plt.ylabel("流量 (字节/秒)") + plt.legend() + plt.grid(True) + plt.xticks(rotation=45) + plt.tight_layout() + buf = io.BytesIO() + plt.savefig(buf, format="png") + buf.seek(0) + image_base64 = base64.b64encode(buf.read()).decode("utf-8") + plt.close() + return image_base64 + image_base64 = await asyncio.to_thread(generate_plot) return f"""
@@ -445,25 +430,25 @@ async def plot_switch_traffic(switch_ip: str, interface: str, minutes: int = 10) @router.get("/network_adapters", summary="获取网络适配器网段") async def get_network_adapters(): try: - net_if_addrs = psutil.net_if_addrs() - - networks = [] - for interface, addrs in net_if_addrs.items(): - for addr in addrs: - if addr.family == socket.AF_INET: - ip = addr.address - netmask = addr.netmask - - network = ipaddress.IPv4Network(f"{ip}/{netmask}", strict=False) - networks.append({ - "adapter": interface, - "network": str(network), - "ip": ip, - "subnet_mask": netmask - }) + def sync_get_adapters(): + net_if_addrs = psutil.net_if_addrs() + networks = [] + for interface, addrs in net_if_addrs.items(): + for addr in addrs: + if addr.family == socket.AF_INET: + ip = addr.address + netmask = addr.netmask + network = ipaddress.IPv4Network(f"{ip}/{netmask}", strict=False) + networks.append({ + "adapter": interface, + "network": str(network), + "ip": ip, + "subnet_mask": netmask + }) + return networks + networks = await asyncio.to_thread(sync_get_adapters) return {"networks": networks} - except Exception as e: return {"error": f"获取网络适配器信息失败: {str(e)}"} @@ -471,14 +456,14 @@ async def get_network_adapters(): visualizer = NetworkVisualizer() report_gen = ReportGenerator() + @router.get("/topology/visualize", response_class=HTMLResponse) async def visualize_topology(): """获取网络拓扑可视化图""" try: - devices = await list_devices() # 复用现有的设备列表接口 - visualizer.update_topology(devices["devices"]) - image_data = visualizer.generate_topology_image() - + devices = await list_devices() + await asyncio.to_thread(visualizer.update_topology, devices["devices"]) + image_data = await asyncio.to_thread(visualizer.generate_topology_image) return f"""