"""HTTP and WebSocket API routes for the AI switch configuration system.

Covers natural-language command parsing, switch configuration (single, batch
and raw CLI), subnet scanning, host and switch traffic monitoring, traffic
plots, and discovery of the local machine's network adapters.
"""

import asyncio
import base64
import html
import io
import ipaddress
import socket
import threading
from datetime import datetime, timedelta
from typing import Iterable, List, Optional

import matplotlib.pyplot as plt  # noqa: F401  retained; plots below use matplotlib.figure.Figure
import psutil
from fastapi import (
    APIRouter,
    HTTPException,
    Response,
    WebSocket,
    WebSocketDisconnect,
)
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse
from matplotlib.figure import Figure
from pydantic import BaseModel

from ..services.switch_traffic_monitor import get_switch_monitor
from ..utils import logger
from ...app.services.ai_services import AIService
from ...app.api.network_config import SwitchConfigurator
from ...config import settings
from ..services.network_scanner import NetworkScanner
from ...app.services.traffic_monitor import traffic_monitor
from ...app.models.traffic_models import TrafficRecord, SwitchTrafficRecord
from src.backend.app.api.database import SessionLocal

router = APIRouter(prefix="", tags=["API"])
scanner = NetworkScanner()

# Scans run in worker threads (see scan_network). NetworkScanner's thread-safety
# is not established here, so this lock keeps scans one-at-a-time.
_scan_lock = threading.Lock()

# Pseudo-commands used to embed credentials in a CLI command list.
_CREDENTIAL_PREFIXES = ("!username=", "!password=")

# RFC 6455: a close frame's reason may hold at most 123 bytes of UTF-8.
_MAX_WS_CLOSE_REASON_BYTES = 123


def _tail(values: Iterable, limit: int) -> list:
    """Return the last ``limit`` items of ``values`` as a list.

    ``limit <= 0`` yields an empty list. A plain ``xs[-0:]`` would return
    the whole sequence instead.
    """
    if limit <= 0:
        return []
    return list(values)[-limit:]


def _ws_close_reason(message: str) -> str:
    """Trim ``message`` so its UTF-8 form fits in a WebSocket close frame."""
    encoded = message.encode("utf-8")[:_MAX_WS_CLOSE_REASON_BYTES]
    # errors="ignore" drops a multi-byte character cut in half by the slice.
    return encoded.decode("utf-8", errors="ignore")


@router.get("/", include_in_schema=False)
async def root():
    """Return a welcome message, documentation links and key endpoints."""
    return {
        "message": "欢迎使用AI交换机配置系统",
        "docs": f"{settings.API_PREFIX}/docs",
        "redoc": f"{settings.API_PREFIX}/redoc",
        "endpoints": [
            "/parse_command",
            "/apply_config",
            "/scan_network",
            "/list_devices",
            "/batch_apply_config",
            "/traffic/switch/current",
            "/traffic/switch/history",
        ],
    }


@router.get("/favicon.ico", include_in_schema=False)
async def favicon():
    """Answer browser favicon requests with an empty 204 response."""
    return Response(status_code=204)


class BatchConfigRequest(BaseModel):
    """Request body for applying one configuration to several switches."""

    config: dict
    switch_ips: List[str]
    username: Optional[str] = None
    password: Optional[str] = None
    timeout: Optional[int] = None


@router.post("/batch_apply_config")
async def batch_apply_config(request: BatchConfigRequest):
    """Apply ``request.config`` to every switch in ``request.switch_ips``.

    Switches are configured one after another. A failure on one switch is
    recorded as that switch's error message and does not stop the others.

    Returns:
        ``{"results": {ip: apply_config result or error message}}``
    """
    results = {}
    for ip in request.switch_ips:
        try:
            configurator = SwitchConfigurator(
                username=request.username,
                password=request.password,
                timeout=request.timeout,
            )
            results[ip] = await configurator.apply_config(ip, request.config)
        except Exception as e:
            logger.error(f"批量配置交换机 {ip} 失败: {str(e)}")
            results[ip] = str(e)
    return {"results": results}


@router.get("/test")
async def test_endpoint():
    """Trivial liveness endpoint."""
    return {"message": "Hello World"}


def _scan_subnet_serialized(subnet: str):
    """Run ``scanner.scan_subnet(subnet)`` while holding ``_scan_lock``."""
    with _scan_lock:
        return scanner.scan_subnet(subnet)


@router.get("/scan_network", summary="扫描网络中的交换机")
async def scan_network(subnet: str = "192.168.1.0/24"):
    """Scan ``subnet`` for switches.

    ``scanner.scan_subnet`` is called synchronously, so it runs in the worker
    threadpool. Running it directly would block the event loop, including any
    open traffic WebSockets, for the whole scan.

    Raises:
        HTTPException: 500 if the scan fails.
    """
    try:
        devices = await run_in_threadpool(_scan_subnet_serialized, subnet)
        return {
            "success": True,
            "devices": devices,
            "count": len(devices),
        }
    except Exception as e:
        raise HTTPException(500, f"扫描失败: {str(e)}") from e


@router.get("/list_devices", summary="列出已发现的交换机")
async def list_devices():
    """Return the devices cached by previous scans."""
    return {
        "devices": scanner.load_cached_devices()
    }


class CommandRequest(BaseModel):
    """Request body carrying one natural-language (Chinese) command."""

    command: str


class ConfigRequest(BaseModel):
    """Request body for applying a configuration to a single switch."""

    config: dict
    switch_ip: str
    username: Optional[str] = None
    password: Optional[str] = None
    timeout: Optional[int] = None


@router.post("/parse_command", response_model=dict)
async def parse_command(request: CommandRequest):
    """Parse a Chinese command with the AI service and return a JSON config.

    Raises:
        HTTPException: 400 if parsing fails for any reason.
    """
    try:
        ai_service = AIService(settings.SILICONFLOW_API_KEY, settings.SILICONFLOW_API_URL)
        config = await ai_service.parse_command(request.command)
        return {"success": True, "config": config}
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Failed to parse command: {str(e)}"
        ) from e


@router.post("/apply_config", response_model=dict)
async def apply_config(request: ConfigRequest):
    """Apply ``request.config`` to one switch via ``SwitchConfigurator.safe_apply``.

    Raises:
        HTTPException: 500 if applying the configuration fails.
    """
    try:
        configurator = SwitchConfigurator(
            username=request.username,
            password=request.password,
            timeout=request.timeout,
        )
        result = await configurator.safe_apply(request.switch_ip, request.config)
        return {"success": True, "result": result}
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to apply config: {str(e)}"
        ) from e


class CLICommandRequest(BaseModel):
    """Raw CLI commands to execute on one switch.

    Credentials may be embedded in ``commands`` as the pseudo-commands
    ``!username=<user>`` and ``!password=<pass>``. The endpoint strips them
    before anything is sent to the device.
    """

    switch_ip: str
    commands: List[str]
    is_ensp: bool = False

    def extract_credentials(self) -> tuple:
        """Return ``(username, password)`` taken from ``commands``.

        Each value is everything after the first ``=``, so passwords may
        themselves contain ``=``. A missing value is ``None``. If a key
        appears more than once, the last occurrence wins.
        """
        username = None
        password = None
        for cmd in self.commands:
            if cmd.startswith("!username="):
                username = cmd.split("=", 1)[1]
            elif cmd.startswith("!password="):
                password = cmd.split("=", 1)[1]
        return username, password

    def get_clean_commands(self) -> List[str]:
        """Return ``commands`` with the credential pseudo-commands removed."""
        return [cmd for cmd in self.commands if not cmd.startswith(_CREDENTIAL_PREFIXES)]


@router.post("/execute_cli_commands", response_model=dict)
async def execute_cli_commands(request: CLICommandRequest):
    """Execute CLI commands generated by the frontend on one switch.

    Credentials may be embedded in ``commands`` as:
        !username=admin
        !password=cisco123

    Raises:
        HTTPException: 500 if execution fails.
    """
    try:
        username, password = request.extract_credentials()
        clean_commands = request.get_clean_commands()
        configurator = SwitchConfigurator(
            username=username,
            password=password,
            timeout=settings.SWITCH_TIMEOUT,
            ensp_mode=request.is_ensp,
        )
        # Send only the cleaned list. request.commands still contains the
        # credential pseudo-commands, which must never reach the device.
        result = await configurator.execute_raw_commands(
            ip=request.switch_ip,
            commands=clean_commands,
        )
        return {
            "success": True,
            "output": result,
            "mode": "eNSP" if request.is_ensp else "SSH",
        }
    except Exception as e:
        raise HTTPException(500, detail=str(e)) from e


@router.get("/traffic/interfaces", summary="获取所有网络接口")
async def get_network_interfaces():
    """List the host interfaces known to ``traffic_monitor``."""
    return {
        "interfaces": traffic_monitor.get_interfaces()
    }


@router.get("/traffic/current", summary="获取当前流量数据")
async def get_current_traffic(interface: str = None):
    """Return the host's current traffic, optionally for one interface."""
    return traffic_monitor.get_current_traffic(interface)


@router.get("/traffic/history", summary="获取流量历史数据")
async def get_traffic_history(interface: str = None, limit: int = 100):
    """Return up to the last ``limit`` host traffic samples.

    ``limit <= 0`` returns empty series. ``time`` entries are ISO-8601 strings.
    """
    history = traffic_monitor.get_traffic_history(interface)
    return {
        "sent": _tail(history["sent"], limit),
        "recv": _tail(history["recv"], limit),
        "time": [t.isoformat() for t in _tail(history["time"], limit)],
    }


@router.get("/traffic/records", summary="获取流量记录")
async def get_traffic_records(interface: str = None, limit: int = 100):
    """Return the newest ``limit`` persisted host traffic records, newest first.

    NOTE: this is a synchronous database query executed on the event loop.
    """
    with SessionLocal() as session:
        query = session.query(TrafficRecord)
        if interface:
            query = query.filter(TrafficRecord.interface == interface)
        records = query.order_by(TrafficRecord.timestamp.desc()).limit(limit).all()
        return [record.to_dict() for record in records]


@router.websocket("/ws/traffic")
async def websocket_traffic(websocket: WebSocket):
    """Push the host's current traffic to the client once per second."""
    await websocket.accept()
    try:
        while True:
            traffic_data = traffic_monitor.get_current_traffic()
            await websocket.send_json(traffic_data)
            await asyncio.sleep(1)
    except WebSocketDisconnect:
        logger.info("客户端断开连接")


@router.get("/traffic/switch/interfaces", summary="获取交换机的网络接口")
async def get_switch_interfaces(switch_ip: str):
    """List the interfaces configured in the switch's traffic monitor.

    Raises:
        HTTPException: 500 if the monitor cannot be obtained.
    """
    try:
        monitor = get_switch_monitor(switch_ip)
        interfaces = list(monitor.interface_oids.keys())
        return {
            "switch_ip": switch_ip,
            "interfaces": interfaces,
        }
    except Exception as e:
        logger.error(f"获取交换机接口失败: {str(e)}")
        raise HTTPException(500, f"获取接口失败: {str(e)}") from e


async def _collect_switch_traffic(switch_ip: str, monitor) -> dict:
    """Return ``{interface: latest traffic}`` for every interface of ``monitor``."""
    return {
        iface: await get_interface_current_traffic(switch_ip, iface)
        for iface in monitor.interface_oids
    }


@router.get("/traffic/switch/current", summary="获取交换机的当前流量数据")
async def get_switch_current_traffic(switch_ip: str, interface: str = None):
    """Return the latest stored traffic for one interface, or for all of them.

    Raises:
        HTTPException: 500 on failure. An HTTPException raised by an inner
            helper is propagated unchanged rather than wrapped a second time.
    """
    try:
        monitor = get_switch_monitor(switch_ip)
        if not interface:
            return {
                "switch_ip": switch_ip,
                "traffic": await _collect_switch_traffic(switch_ip, monitor),
            }
        return await get_interface_current_traffic(switch_ip, interface)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"获取交换机流量失败: {str(e)}")
        raise HTTPException(500, f"获取流量失败: {str(e)}") from e


async def get_interface_current_traffic(switch_ip: str, interface: str) -> dict:
    """Return the most recent stored traffic sample for one switch interface.

    Reads the newest ``SwitchTrafficRecord`` for ``(switch_ip, interface)``.
    If none exists yet, rates and byte counters are reported as zero.

    Raises:
        HTTPException: 500 if the database query fails.
    """
    try:
        with SessionLocal() as session:
            record = (
                session.query(SwitchTrafficRecord)
                .filter(
                    SwitchTrafficRecord.switch_ip == switch_ip,
                    SwitchTrafficRecord.interface == interface,
                )
                .order_by(SwitchTrafficRecord.timestamp.desc())
                .first()
            )
            if record is None:
                return {
                    "switch_ip": switch_ip,
                    "interface": interface,
                    "rate_in": 0.0,
                    "rate_out": 0.0,
                    "bytes_in": 0,
                    "bytes_out": 0,
                }
            return {
                "switch_ip": switch_ip,
                "interface": interface,
                "rate_in": record.rate_in,
                "rate_out": record.rate_out,
                "bytes_in": record.bytes_in,
                "bytes_out": record.bytes_out,
            }
    except Exception as e:
        logger.error(f"获取接口流量失败: {str(e)}")
        raise HTTPException(500, f"获取接口流量失败: {str(e)}") from e


@router.get("/traffic/switch/history", summary="获取交换机的流量历史数据")
async def get_switch_traffic_history(switch_ip: str, interface: str = None, minutes: int = 10):
    """Return switch traffic history.

    Without ``interface``, returns the monitor's own in-memory history. With
    ``interface``, returns the persisted records from the last ``minutes``
    minutes, oldest first, as parallel ``in`` / ``out`` / ``time`` lists.

    Raises:
        HTTPException: 500 on failure.
    """
    try:
        monitor = get_switch_monitor(switch_ip)
        if not interface:
            return {
                "switch_ip": switch_ip,
                "history": monitor.get_traffic_history(),
            }
        time_threshold = datetime.now() - timedelta(minutes=minutes)
        with SessionLocal() as session:
            records = (
                session.query(SwitchTrafficRecord)
                .filter(
                    SwitchTrafficRecord.switch_ip == switch_ip,
                    SwitchTrafficRecord.interface == interface,
                    SwitchTrafficRecord.timestamp >= time_threshold,
                )
                .order_by(SwitchTrafficRecord.timestamp.asc())
                .all()
            )
            history_data = {
                "in": [record.rate_in for record in records],
                "out": [record.rate_out for record in records],
                "time": [record.timestamp.isoformat() for record in records],
            }
        return {
            "switch_ip": switch_ip,
            "interface": interface,
            "history": history_data,
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"获取历史流量失败: {str(e)}")
        raise HTTPException(500, f"获取历史流量失败: {str(e)}") from e


@router.websocket("/ws/traffic/switch")
async def websocket_switch_traffic(websocket: WebSocket, switch_ip: str, interface: str = None):
    """Push a switch's latest traffic to the client once per second.

    With ``interface`` the payload is that interface's sample. Otherwise it is
    ``{"switch_ip": ..., "traffic": {iface: sample}}``. On an unexpected error
    the socket is closed with code 1011.
    """
    await websocket.accept()
    try:
        monitor = get_switch_monitor(switch_ip)
        while True:
            if interface:
                payload = await get_interface_current_traffic(switch_ip, interface)
            else:
                payload = {
                    "switch_ip": switch_ip,
                    "traffic": await _collect_switch_traffic(switch_ip, monitor),
                }
            await websocket.send_json(payload)
            await asyncio.sleep(1)
    except WebSocketDisconnect:
        logger.info(f"客户端断开交换机流量连接: {switch_ip}")
    except Exception as e:
        logger.error(f"交换机流量WebSocket错误: {str(e)}")
        try:
            # The reason must fit the close-frame size limit.
            await websocket.close(code=1011, reason=_ws_close_reason(str(e)))
        except Exception as close_error:
            # The connection may already be gone; nothing more can be sent.
            logger.info(f"关闭交换机流量WebSocket失败: {str(close_error)}")


def _render_traffic_png(switch_ip: str, interface: str, minutes: int,
                        time_points: list, in_rates: list, out_rates: list) -> str:
    """Plot inbound/outbound rates over time and return the PNG as base64 text.

    Uses a standalone ``Figure`` rather than pyplot, so no global figure
    state is shared between requests and nothing has to be closed afterwards,
    even when rendering fails.
    """
    fig = Figure(figsize=(12, 6))
    ax = fig.subplots()
    ax.plot(time_points, in_rates, label="流入流量 (B/s)")
    ax.plot(time_points, out_rates, label="流出流量 (B/s)")
    ax.set_title(f"交换机 {switch_ip} 接口 {interface} 流量监控 - 最近 {minutes} 分钟")
    ax.set_xlabel("时间")
    ax.set_ylabel("流量 (字节/秒)")
    ax.legend()
    ax.grid(True)
    ax.tick_params(axis="x", labelrotation=45)
    fig.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    return base64.b64encode(buf.getvalue()).decode("utf-8")


@router.get("/traffic/switch/plot", response_class=HTMLResponse, summary="交换机流量可视化")
async def plot_switch_traffic(switch_ip: str, interface: str, minutes: int = 10):
    """Render an HTML page embedding a PNG chart of the interface's recent traffic.

    On failure, returns an HTML error page with status 500.
    """
    try:
        history = await get_switch_traffic_history(switch_ip, interface, minutes)
        history_data = history["history"]
        time_points = [datetime.fromisoformat(t) for t in history_data["time"]]
        image_base64 = _render_traffic_png(
            switch_ip, interface, minutes,
            time_points, history_data["in"], history_data["out"],
        )
        # switch_ip / interface come straight from the query string: escape
        # them before placing them in HTML.
        safe_ip = html.escape(switch_ip)
        safe_iface = html.escape(interface)
        updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        return f"""<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>交换机流量监控</title>
</head>
<body>
    <h1>交换机 {safe_ip} 接口 {safe_iface} 流量监控</h1>
    <img src="data:image/png;base64,{image_base64}" alt="流量图表">
    <p>更新时间: {updated_at}</p>
</body>
</html>
"""
    except Exception as e:
        logger.error(f"生成流量图表失败: {str(e)}")
        return HTMLResponse(
            content=(
                "<html><head><meta charset=\"utf-8\"></head><body>"
                f"<h1>错误</h1><p>{html.escape(str(e))}</p>"
                "</body></html>"
            ),
            status_code=500,
        )


@router.get("/network_adapters", summary="获取网络适配器网段")
async def get_network_adapters():
    """List the IPv4 networks of this host's network adapters.

    Returns:
        ``{"networks": [{"adapter", "network", "ip", "subnet_mask"}, ...]}``.
        On failure the response is ``{"error": ...}`` with HTTP 200, unlike
        the other endpoints here, which raise HTTPException.
    """
    try:
        networks = []
        for adapter, addrs in psutil.net_if_addrs().items():
            for addr in addrs:
                # psutil may report netmask as None; such an address has no
                # derivable network and would otherwise abort the whole listing.
                if addr.family != socket.AF_INET or not addr.netmask:
                    continue
                network = ipaddress.IPv4Network(f"{addr.address}/{addr.netmask}", strict=False)
                networks.append({
                    "adapter": adapter,
                    "network": str(network),
                    "ip": addr.address,
                    "subnet_mask": addr.netmask,
                })
        return {"networks": networks}
    except Exception as e:
        logger.error(f"获取网络适配器信息失败: {str(e)}")
        return {"error": f"获取网络适配器信息失败: {str(e)}"}