From d77a0dacadc33e989f1a470067b352a29c4964ac Mon Sep 17 00:00:00 2001 From: Jerry Date: Tue, 12 Aug 2025 00:30:15 +0800 Subject: [PATCH 1/7] =?UTF-8?q?=E5=88=A0=E6=8E=89=E4=B8=80=E4=BA=9B?= =?UTF-8?q?=E4=B8=9C=E8=A5=BF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/app/api/endpoints.py | 294 +------------------- src/backend/app/services/ai_services.py | 2 +- src/backend/app/services/network_scanner.py | 2 +- src/backend/app/services/traffic_monitor.py | 5 +- 4 files changed, 8 insertions(+), 295 deletions(-) diff --git a/src/backend/app/api/endpoints.py b/src/backend/app/api/endpoints.py index 1cf0fca..1e75556 100644 --- a/src/backend/app/api/endpoints.py +++ b/src/backend/app/api/endpoints.py @@ -1,29 +1,17 @@ import socket -from datetime import datetime, timedelta -from fastapi import (APIRouter, HTTPException, Response, WebSocket, WebSocketDisconnect) +from fastapi import (APIRouter, HTTPException, Response) from typing import List from pydantic import BaseModel -import asyncio from fastapi.responses import HTMLResponse -import matplotlib.pyplot as plt -import io -import base64 import psutil import ipaddress - -from ..services.switch_traffic_monitor import get_switch_monitor -from ..utils import logger from ...app.services.ai_services import AIService from ...app.api.network_config import SwitchConfigurator from ...config import settings from ..services.network_scanner import NetworkScanner -from ...app.services.traffic_monitor import traffic_monitor -from ...app.models.traffic_models import TrafficRecord, SwitchTrafficRecord -from src.backend.app.api.database import SessionLocal from ..services.network_visualizer import NetworkVisualizer from ..services.config_validator import ConfigValidator from ..services.report_generator import ReportGenerator -from fastapi.responses import JSONResponse @@ -59,22 +47,6 @@ class BatchConfigRequest(BaseModel): password: str = None timeout: int = None -@router.post("/batch_apply_config") -async def batch_apply_config(request: BatchConfigRequest): - results = {} - for ip in request.switch_ips: - try: - configurator = SwitchConfigurator( - username=request.username, - password=request.password, - timeout=request.timeout ) - results[ip] = await configurator.apply_config(ip, request.config) - except Exception as e: - results[ip] = str(e) - return {"results": results} - - - @router.get("/test") async def test_endpoint(): return {"message": "Hello World"} @@ -196,47 +168,6 @@ async def execute_cli_commands(request: CLICommandRequest): except Exception as e: raise HTTPException(500, detail=str(e)) -@router.get("/traffic/interfaces", summary="获取所有网络接口") -async def get_network_interfaces(): - return { - "interfaces": traffic_monitor.get_interfaces() - } - -@router.get("/traffic/current", summary="获取当前流量数据") -async def get_current_traffic(interface: str = None): - return traffic_monitor.get_current_traffic(interface) - -@router.get("/traffic/history", summary="获取流量历史数据") -async def get_traffic_history(interface: str = None, limit: int = 100): - history = traffic_monitor.get_traffic_history(interface) - return { - "sent": history["sent"][-limit:], - "recv": history["recv"][-limit:], - "time": [t.isoformat() for t in history["time"]][-limit:] - } - -@router.get("/traffic/records", summary="获取流量记录") -async def get_traffic_records(interface: str = None, limit: int = 100): - with SessionLocal() as session: - query = session.query(TrafficRecord) - if interface: - query = query.filter(TrafficRecord.interface == interface) - records = query.order_by(TrafficRecord.timestamp.desc()).limit(limit).all() - return [record.to_dict() for record in records] - -@router.websocket("/ws/traffic") -async def websocket_traffic(websocket: WebSocket): - """实时流量WebSocket""" - await websocket.accept() - try: - while True: - traffic_data = traffic_monitor.get_current_traffic() - await websocket.send_json(traffic_data) - await asyncio.sleep(1) - except WebSocketDisconnect: - print("客户端断开连接") - - @router.get("/", include_in_schema=False) async def root(): return { @@ -253,195 +184,6 @@ async def root(): "/traffic/switch/history" ] } - - - - -@router.get("/traffic/switch/interfaces", summary="获取交换机的网络接口") -async def get_switch_interfaces(switch_ip: str): - """获取指定交换机的所有接口""" - try: - monitor = get_switch_monitor(switch_ip) - interfaces = list(monitor.interface_oids.keys()) - return { - "switch_ip": switch_ip, - "interfaces": interfaces - } - except Exception as e: - logger.error(f"获取交换机接口失败: {str(e)}") - raise HTTPException(500, f"获取接口失败: {str(e)}") - - -@router.get("/traffic/switch/current", summary="获取交换机的当前流量数据") -async def get_switch_current_traffic(switch_ip: str, interface: str = None): - """获取交换机的当前流量数据""" - try: - monitor = get_switch_monitor(switch_ip) - - - if not interface: - traffic_data = {} - for iface in monitor.interface_oids: - traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface) - return { - "switch_ip": switch_ip, - "traffic": traffic_data - } - - return await get_interface_current_traffic(switch_ip, interface) - except Exception as e: - logger.error(f"获取交换机流量失败: {str(e)}") - raise HTTPException(500, f"获取流量失败: {str(e)}") - - -async def get_interface_current_traffic(switch_ip: str, interface: str) -> dict: - """获取指定交换机接口的当前流量数据""" - try: - with SessionLocal() as session: - - record = session.query(SwitchTrafficRecord).filter( - SwitchTrafficRecord.switch_ip == switch_ip, - SwitchTrafficRecord.interface == interface - ).order_by(SwitchTrafficRecord.timestamp.desc()).first() - - if not record: - return { - "switch_ip": switch_ip, - "interface": interface, - "rate_in": 0.0, - "rate_out": 0.0, - "bytes_in": 0, - "bytes_out": 0 - } - - return { - "switch_ip": switch_ip, - "interface": interface, - "rate_in": record.rate_in, - "rate_out": record.rate_out, - "bytes_in": record.bytes_in, - "bytes_out": record.bytes_out - } - except Exception as e: - logger.error(f"获取接口流量失败: {str(e)}") - raise HTTPException(500, f"获取接口流量失败: {str(e)}") - - -@router.get("/traffic/switch/history", summary="获取交换机的流量历史数据") -async def get_switch_traffic_history(switch_ip: str, interface: str = None, minutes: int = 10): - """获取交换机的流量历史数据""" - try: - monitor = get_switch_monitor(switch_ip) - - if not interface: - return { - "switch_ip": switch_ip, - "history": monitor.get_traffic_history() - } - - with SessionLocal() as session: - time_threshold = datetime.now() - timedelta(minutes=minutes) - - records = session.query(SwitchTrafficRecord).filter( - SwitchTrafficRecord.switch_ip == switch_ip, - SwitchTrafficRecord.interface == interface, - SwitchTrafficRecord.timestamp >= time_threshold - ).order_by(SwitchTrafficRecord.timestamp.asc()).all() - - history_data = { - "in": [record.rate_in for record in records], - "out": [record.rate_out for record in records], - "time": [record.timestamp.isoformat() for record in records] - } - - return { - "switch_ip": switch_ip, - "interface": interface, - "history": history_data - } - except Exception as e: - logger.error(f"获取历史流量失败: {str(e)}") - raise HTTPException(500, f"获取历史流量失败: {str(e)}") - - -@router.websocket("/ws/traffic/switch") -async def websocket_switch_traffic(websocket: WebSocket, switch_ip: str, interface: str = None): - """交换机实时流量WebSocket""" - await websocket.accept() - try: - monitor = get_switch_monitor(switch_ip) - - while True: - if interface: - traffic_data = await get_interface_current_traffic(switch_ip, interface) - await websocket.send_json(traffic_data) - else: - traffic_data = {} - for iface in monitor.interface_oids: - traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface) - - await websocket.send_json({ - "switch_ip": switch_ip, - "traffic": traffic_data - }) - - await asyncio.sleep(1) - except WebSocketDisconnect: - logger.info(f"客户端断开交换机流量连接: {switch_ip}") - except Exception as e: - logger.error(f"交换机流量WebSocket错误: {str(e)}") - await websocket.close(code=1011, reason=str(e)) - -@router.get("/traffic/switch/plot", response_class=HTMLResponse, summary="交换机流量可视化") -async def plot_switch_traffic(switch_ip: str, interface: str, minutes: int = 10): - """生成交换机流量图表""" - try: - history = await get_switch_traffic_history(switch_ip, interface, minutes) - history_data = history["history"] - - time_points = [datetime.fromisoformat(t) for t in history_data["time"]] - in_rates = history_data["in"] - out_rates = history_data["out"] - - plt.figure(figsize=(12, 6)) - plt.plot(time_points, in_rates, label="流入流量 (B/s)") - plt.plot(time_points, out_rates, label="流出流量 (B/s)") - plt.title(f"交换机 {switch_ip} 接口 {interface} 流量监控 - 最近 {minutes} 分钟") - plt.xlabel("时间") - plt.ylabel("流量 (字节/秒)") - plt.legend() - plt.grid(True) - plt.xticks(rotation=45) - plt.tight_layout() - buf = io.BytesIO() - plt.savefig(buf, format="png") - buf.seek(0) - image_base64 = base64.b64encode(buf.read()).decode("utf-8") - plt.close() - - return f""" - - - 交换机流量监控 - - - -
-

交换机 {switch_ip} 接口 {interface} 流量监控

- 流量图表 -

更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

-
- - - """ - except Exception as e: - logger.error(f"生成流量图表失败: {str(e)}") - return HTMLResponse(content=f"

错误

{str(e)}

", status_code=500) - - @router.get("/network_adapters", summary="获取网络适配器网段") async def get_network_adapters(): try: @@ -475,7 +217,7 @@ report_gen = ReportGenerator() async def visualize_topology(): """获取网络拓扑可视化图""" try: - devices = await list_devices() # 复用现有的设备列表接口 + devices = await list_devices() visualizer.update_topology(devices["devices"]) image_data = visualizer.generate_topology_image() @@ -500,34 +242,4 @@ async def validate_config(config: dict): "valid": is_valid, "errors": errors, "has_security_risks": len(ConfigValidator.check_security_risks(config.get("commands", []))) > 0 - } - - -@router.get("/reports/traffic/{ip}") -async def get_traffic_report(ip: str, days: int = 1): - """获取流量分析报告""" - try: - report = report_gen.generate_traffic_report(ip, days) - return JSONResponse(content=report) - except Exception as e: - raise HTTPException(500, detail=str(e)) - - -@router.get("/reports/traffic") -async def get_local_traffic_report(days: int = 1): - """获取本地网络流量报告""" - try: - report = report_gen.generate_traffic_report(days=days) - return JSONResponse(content=report) - except Exception as e: - raise HTTPException(500, detail=str(e)) - - -@router.get("/topology/traffic_heatmap") -async def get_traffic_heatmap(minutes: int = 10): - """获取流量热力图数据""" - try: - heatmap = visualizer.get_traffic_heatmap(minutes) - return {"heatmap": heatmap} - except Exception as e: - raise HTTPException(500, detail=str(e)) \ No newline at end of file + } \ No newline at end of file diff --git a/src/backend/app/services/ai_services.py b/src/backend/app/services/ai_services.py index 80e14e8..d315477 100644 --- a/src/backend/app/services/ai_services.py +++ b/src/backend/app/services/ai_services.py @@ -42,7 +42,7 @@ class AIService: ] try: - response = await self.client.chat.completions.create( + response = self.client.chat.completions.create( model="deepseek-ai/DeepSeek-V3", messages=messages, temperature=0.3, diff --git a/src/backend/app/services/network_scanner.py b/src/backend/app/services/network_scanner.py index b6021d3..77ac5ff 100644 --- a/src/backend/app/services/network_scanner.py +++ b/src/backend/app/services/network_scanner.py @@ -15,7 +15,7 @@ class NetworkScanner: devices = [] try: - self.nm.scan(hosts=subnet, arguments=f'-p {",".join(map(str, ports))}') + await self.nm.scan(hosts=subnet, arguments=f'-p {",".join(map(str, ports))}') for host in self.nm.all_hosts(): ip = host mac = self.nm[host]['addresses'].get('mac', 'N/A') diff --git a/src/backend/app/services/traffic_monitor.py b/src/backend/app/services/traffic_monitor.py index 54adb11..1c498b3 100644 --- a/src/backend/app/services/traffic_monitor.py +++ b/src/backend/app/services/traffic_monitor.py @@ -8,6 +8,7 @@ from typing import Dict, Optional, List from ..models.traffic_models import TrafficRecord from src.backend.app.api.database import SessionLocal +from ..utils.logger import logger class TrafficMonitor: @@ -33,7 +34,7 @@ class TrafficMonitor: if not self.running: self.running = True self.task = asyncio.create_task(self._monitor_loop()) - print("流量监控已启动") + logger.info("流量监控已启动") async def stop_monitoring(self): """停止流量监控""" @@ -44,7 +45,7 @@ class TrafficMonitor: await self.task except asyncio.CancelledError: pass - print("流量监控已停止") + logger.info("流量监控已停止") async def _monitor_loop(self): """监控主循环""" From 59c8604cda19f38df7648fcdf43ad7351c95d06c Mon Sep 17 00:00:00 2001 From: Jerry Date: Tue, 12 Aug 2025 00:31:15 +0800 Subject: [PATCH 2/7] =?UTF-8?q?=E5=86=8D=E5=88=A0=E6=8E=89=E4=B8=80?= =?UTF-8?q?=E4=BA=9B=E4=B8=9C=E8=A5=BF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/app/api/endpoints.py | 42 +------------------------------- 1 file changed, 1 insertion(+), 41 deletions(-) diff --git a/src/backend/app/api/endpoints.py b/src/backend/app/api/endpoints.py index 1e75556..4bcd616 100644 --- a/src/backend/app/api/endpoints.py +++ b/src/backend/app/api/endpoints.py @@ -2,17 +2,12 @@ import socket from fastapi import (APIRouter, HTTPException, Response) from typing import List from pydantic import BaseModel -from fastapi.responses import HTMLResponse import psutil import ipaddress from ...app.services.ai_services import AIService from ...app.api.network_config import SwitchConfigurator from ...config import settings from ..services.network_scanner import NetworkScanner -from ..services.network_visualizer import NetworkVisualizer -from ..services.config_validator import ConfigValidator -from ..services.report_generator import ReportGenerator - @@ -207,39 +202,4 @@ async def get_network_adapters(): return {"networks": networks} except Exception as e: - return {"error": f"获取网络适配器信息失败: {str(e)}"} - - -visualizer = NetworkVisualizer() -report_gen = ReportGenerator() - -@router.get("/topology/visualize", response_class=HTMLResponse) -async def visualize_topology(): - """获取网络拓扑可视化图""" - try: - devices = await list_devices() - visualizer.update_topology(devices["devices"]) - image_data = visualizer.generate_topology_image() - - return f""" - - Network Topology - -

Network Topology

- Network Topology - - - """ - except Exception as e: - raise HTTPException(500, detail=str(e)) - - -@router.post("/config/validate") -async def validate_config(config: dict): - """验证配置有效性""" - is_valid, errors = ConfigValidator.validate_full_config(config) - return { - "valid": is_valid, - "errors": errors, - "has_security_risks": len(ConfigValidator.check_security_risks(config.get("commands", []))) > 0 - } \ No newline at end of file + return {"error": f"获取网络适配器信息失败: {str(e)}"} \ No newline at end of file From 29f016bdab96a77548718dffe58c4e5884334a89 Mon Sep 17 00:00:00 2001 From: Jerry Date: Tue, 12 Aug 2025 00:33:12 +0800 Subject: [PATCH 3/7] =?UTF-8?q?=E5=86=8D=E5=86=8D=E5=88=A0=E6=8E=89?= =?UTF-8?q?=E4=B8=80=E4=BA=9B=E4=B8=9C=E8=A5=BF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/app/services/config_validator.py | 85 ------- src/backend/app/services/failover_manager.py | 129 ---------- src/backend/app/services/network_optimizer.py | 220 ------------------ .../app/services/network_visualizer.py | 104 --------- src/backend/app/services/report_generator.py | 120 ---------- 5 files changed, 658 deletions(-) delete mode 100644 src/backend/app/services/config_validator.py delete mode 100644 src/backend/app/services/failover_manager.py delete mode 100644 src/backend/app/services/network_optimizer.py delete mode 100644 src/backend/app/services/network_visualizer.py delete mode 100644 src/backend/app/services/report_generator.py diff --git a/src/backend/app/services/config_validator.py b/src/backend/app/services/config_validator.py deleted file mode 100644 index 5be5839..0000000 --- a/src/backend/app/services/config_validator.py +++ /dev/null @@ -1,85 +0,0 @@ -import re -from typing import Dict, List, Tuple -from ..utils.exceptions import SwitchConfigException - - -class ConfigValidator: - @staticmethod - def validate_vlan_config(config: Dict) -> Tuple[bool, str]: - """验证VLAN配置""" - if 'vlan_id' not in config: - return False, "Missing VLAN ID" - - vlan_id = config['vlan_id'] - if not (1 <= vlan_id <= 4094): - return False, f"Invalid VLAN ID {vlan_id}. Must be 1-4094" - - if 'name' in config and len(config['name']) > 32: - return False, "VLAN name too long (max 32 chars)" - - return True, "Valid VLAN config" - - @staticmethod - def validate_interface_config(config: Dict) -> Tuple[bool, str]: - """验证接口配置""" - required_fields = ['interface', 'ip_address'] - for field in required_fields: - if field not in config: - return False, f"Missing required field: {field}" - - # 验证IP地址格式 - ip_pattern = r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}/\d{1,2}$' - if not re.match(ip_pattern, config['ip_address']): - return False, "Invalid IP address format" - - # 验证接口名称格式 - interface_pattern = r'^(GigabitEthernet|FastEthernet|Eth)\d+/\d+/\d+$' - if not re.match(interface_pattern, config['interface']): - return False, "Invalid interface name format" - - return True, "Valid interface config" - - @staticmethod - def check_security_risks(commands: List[str]) -> List[str]: - """检查潜在安全风险""" - risky_commands = [] - dangerous_patterns = [ - r'no\s+aaa', # 禁用认证 - r'enable\s+password', # 明文密码 - r'service\s+password-encryption', # 弱加密 - r'ip\s+http\s+server', # 启用HTTP服务 - r'no\s+ip\s+http\s+secure-server' # 禁用HTTPS - ] - - for cmd in commands: - for pattern in dangerous_patterns: - if re.search(pattern, cmd, re.IGNORECASE): - risky_commands.append(cmd) - break - - return risky_commands - - @staticmethod - def validate_full_config(config: Dict) -> Tuple[bool, List[str]]: - """全面验证配置""" - errors = [] - - if 'type' not in config: - errors.append("Missing configuration type") - return False, errors - - if config['type'] == 'vlan': - valid, msg = ConfigValidator.validate_vlan_config(config) - if not valid: - errors.append(msg) - elif config['type'] == 'interface': - valid, msg = ConfigValidator.validate_interface_config(config) - if not valid: - errors.append(msg) - - if 'commands' in config: - risks = ConfigValidator.check_security_risks(config['commands']) - if risks: - errors.append(f"Potential security risks detected: {', '.join(risks)}") - - return len(errors) == 0, errors \ No newline at end of file diff --git a/src/backend/app/services/failover_manager.py b/src/backend/app/services/failover_manager.py deleted file mode 100644 index 84c31d9..0000000 --- a/src/backend/app/services/failover_manager.py +++ /dev/null @@ -1,129 +0,0 @@ -import asyncio -from datetime import datetime, timedelta -from pathlib import Path -from typing import Dict, List - -import networkx as nx - -from ..models.traffic_models import SwitchTrafficRecord -from src.backend.app.api.database import SessionLocal -from ..utils.exceptions import SwitchConfigException -from ..services.network_scanner import NetworkScanner - - -class FailoverManager: - def __init__(self, config_backup_dir: str = "config_backups"): - self.scanner = NetworkScanner() - self.backup_dir = Path(config_backup_dir) - self.backup_dir.mkdir(exist_ok=True) - - async def detect_failures(self, threshold: float = 0.9) -> List[Dict]: - """检测可能的网络故障""" - with SessionLocal() as session: - # 获取最近5分钟的所有交换机流量记录 - records = session.query(SwitchTrafficRecord).filter( - SwitchTrafficRecord.timestamp >= datetime.now() - timedelta(minutes=5) - ).all() - - # 分析流量异常 - abnormal_devices = [] - device_stats = {} - - for record in records: - if record.switch_ip not in device_stats: - device_stats[record.switch_ip] = { - 'interfaces': {}, - 'max_in': 0, - 'max_out': 0 - } - - if record.interface not in device_stats[record.switch_ip]['interfaces']: - device_stats[record.switch_ip]['interfaces'][record.interface] = { - 'in': [], - 'out': [] - } - - device_stats[record.switch_ip]['interfaces'][record.interface]['in'].append(record.rate_in) - device_stats[record.switch_ip]['interfaces'][record.interface]['out'].append(record.rate_out) - - # 更新设备最大流量 - if record.rate_in > device_stats[record.switch_ip]['max_in']: - device_stats[record.switch_ip]['max_in'] = record.rate_in - if record.rate_out > device_stats[record.switch_ip]['max_out']: - device_stats[record.switch_ip]['max_out'] = record.rate_out - - # 检测异常 - for ip, stats in device_stats.items(): - for iface, traffic in stats['interfaces'].items(): - avg_in = sum(traffic['in']) / len(traffic['in']) - avg_out = sum(traffic['out']) / len(traffic['out']) - - # 如果平均流量超过最大流量的90%,认为可能有问题 - if avg_in > stats['max_in'] * threshold or \ - avg_out > stats['max_out'] * threshold: - abnormal_devices.append({ - 'ip': ip, - 'interface': iface, - 'avg_in': avg_in, - 'avg_out': avg_out, - 'max_in': stats['max_in'], - 'max_out': stats['max_out'] - }) - - return abnormal_devices - - async def automatic_recovery(self, failed_device_ip: str) -> bool: - """自动故障恢复""" - try: - # 1. 检查设备是否在线 - devices = self.scanner.scan_subnet(failed_device_ip + "/32") - if not devices: - raise SwitchConfigException(f"Device {failed_device_ip} is offline") - - # 2. 查找最近的配置备份 - backup_files = sorted(self.backup_dir.glob(f"{failed_device_ip}_*.cfg")) - if not backup_files: - raise SwitchConfigException(f"No backup found for {failed_device_ip}") - - latest_backup = backup_files[-1] - - # 3. 恢复配置 - with open(latest_backup) as f: - config_commands = f.read().splitlines() - - # 使用SSH执行恢复命令 - # (这里需要实现SSH连接和执行命令的逻辑) - - return True - - except Exception as e: - raise SwitchConfigException(f"Recovery failed: {str(e)}") - - async def redundancy_check(self, critical_devices: List[str]) -> Dict: - """检查关键设备的冗余配置""" - results = {} - topology = self.scanner.get_current_topology() - - for device_ip in critical_devices: - # 检查是否有备用路径 - try: - paths = list(nx.all_shortest_paths( - topology, source=device_ip, target="core_switch")) - - if len(paths) > 1: - results[device_ip] = { - 'status': 'redundant', - 'path_count': len(paths) - } - else: - results[device_ip] = { - 'status': 'single_point_of_failure', - 'recommendation': 'Add redundant links' - } - except: - results[device_ip] = { - 'status': 'disconnected', - 'recommendation': 'Check physical connection' - } - - return results \ No newline at end of file diff --git a/src/backend/app/services/network_optimizer.py b/src/backend/app/services/network_optimizer.py deleted file mode 100644 index 39c14f7..0000000 --- a/src/backend/app/services/network_optimizer.py +++ /dev/null @@ -1,220 +0,0 @@ -import networkx as nx -import numpy as np -from scipy.optimize import linprog -from typing import Dict, List, Optional, Tuple, Any -from dataclasses import dataclass - - -@dataclass -class FlowDemand: - source: str - destination: str - bandwidth: float - priority: float = 1.0 - - -class NetworkOptimizer: - def __init__(self, devices: List[Dict[str, Any]]): - """基于图论的网络优化模型""" - self.graph = self.build_topology_graph(devices) - self.traffic_matrix: Optional[np.ndarray] = None - self._initialize_capacities() - - def _initialize_capacities(self) -> Dict[Tuple[str, str], float]: - """初始化链路容量字典""" - self.remaining_capacity = { - (u, v): data['bandwidth'] - for u, v, data in self.graph.edges(data=True) - } - return self.remaining_capacity - - def build_topology_graph(self, devices: List[Dict[str, Any]]) -> nx.Graph: - """构建带权重的网络拓扑图""" - G = nx.Graph() - - for device in devices: - G.add_node(device['ip'], - type=device['type'], - capacity=device.get('capacity', 1000)) - - # 添加接口作为子节点 - for interface in device.get('interfaces', []): - interface_id = f"{device['ip']}_{interface['name']}" - G.add_node(interface_id, - type='interface', - capacity=interface.get('bandwidth', 1000)) - G.add_edge(device['ip'], interface_id, - bandwidth=interface.get('bandwidth', 1000), - latency=interface.get('latency', 1)) - - # 添加设备间连接 - self._connect_devices(G, devices) - return G - - @staticmethod - def _connect_devices(graph: nx.Graph, devices: List[Dict[str, Any]]): - """自动连接设备""" - for i in range(len(devices) - 1): - graph.add_edge( - devices[i]['ip'], - devices[i + 1]['ip'], - bandwidth=1000, - latency=5 - ) - - @staticmethod - def _build_flow_conservation(nodes: List[str]) -> Tuple[np.ndarray, np.ndarray]: - """构建流量守恒约束矩阵""" - num_nodes = len(nodes) - A_eq: List[np.ndarray] = [] # 明确指定列表元素类型 - b_eq: List[float] = [] # 明确指定float列表 - - for i in range(num_nodes): - for j in range(num_nodes): - if i != j: - constraint = np.zeros((num_nodes, num_nodes)) - constraint[i][j] = 1 - constraint[j][i] = -1 - A_eq.append(constraint.flatten()) - b_eq.append(0.0) # 明确使用float类型 - - return np.array(A_eq, dtype=np.float64), np.array(b_eq, dtype=np.float64) - def optimize_bandwidth_allocation( - self, - demands: List[FlowDemand] - ) -> Dict[str, Dict[str, float]]: - """ - 基于线性规划的带宽分配优化 - 返回: {源IP: {目标IP: 分配带宽}} - """ - demand_dict = {(d.source, d.destination): float(d.bandwidth) for d in demands} # 确保float类型 - return self._optimize_with_highs(demand_dict) - - def _optimize_with_highs(self, demands: Dict[Tuple[str, str], float]) -> Dict[str, Dict[str, float]]: - """使用HiGHS求解器实现""" - nodes = list(self.graph.nodes()) - node_index = {node: i for i, node in enumerate(nodes)} - - # 构建流量矩阵 - self.traffic_matrix = np.zeros((len(nodes), len(nodes))) - for (src, dst), bw in demands.items(): - if src in node_index and dst in node_index: - self.traffic_matrix[node_index[src]][node_index[dst]] = bw - - # 构建约束 - c = np.ones(len(nodes) ** 2) # 最小化总流量 - A_ub, b_ub = self._build_capacity_constraints(nodes, node_index) - A_eq, b_eq = self._build_flow_conservation(nodes) - - # 求解线性规划 - res = linprog( - c=c, - A_ub=A_ub, - b_ub=b_ub, - A_eq=A_eq, - b_eq=b_eq, - bounds=(0, None), - method='highs' - ) - - if not res.success: - raise ValueError(f"Optimization failed: {res.message}") - - return self._format_results(res.x, nodes, node_index) - - def _build_capacity_constraints(self, nodes: List[str], node_index: Dict[str, int]) -> Tuple[ - np.ndarray, np.ndarray]: - """构建容量约束矩阵""" - A_ub = [] - b_ub = [] - - for u, v, data in self.graph.edges(data=True): - capacity = float(data.get('bandwidth', 1000)) # 确保float类型 - b_ub.append(capacity) - - constraint = np.zeros((len(nodes), len(nodes))) - for src, dst in [(u, v), (v, u)]: - if src in node_index and dst in node_index: - i, j = node_index[src], node_index[dst] - constraint[i][j] += 1 - A_ub.append(constraint.flatten()) - - return np.array(A_ub), np.array(b_ub) - - @staticmethod - def _format_results(solution: np.ndarray, nodes: List[str], node_index: Dict[str, int]) -> Dict[ - str, Dict[str, float]]: - """格式化优化结果""" - flows = solution.reshape((len(nodes), len(nodes))) - return { - nodes[i]: { - nodes[j]: float(flows[i][j]) # 明确转换为float - for j in range(len(nodes)) - if flows[i][j] > 0.001 - } - for i in range(len(nodes)) - } - - def optimize_qos(self, flows: List[FlowDemand]) -> Dict[str, Dict[str, Any]]: - """ - 带QoS的流量优化 - 返回: { - "源-目标": { - "path": 路径列表, - "allocated": 分配带宽, - "priority": 优先级 - } - } - """ - sorted_flows = sorted(flows, key=lambda x: x.priority, reverse=True) - results = {} - - # 使用实例变量而非局部变量 - self._initialize_capacities() # 重置剩余带宽 - - for flow in sorted_flows: - path = self.find_optimal_path(flow.source, flow.destination, flow.bandwidth) - if not path: - continue - - min_bw = min(self.graph[u][v]['bandwidth'] for u, v in zip(path[:-1], path[1:])) - allocated = min(min_bw, flow.bandwidth) - - # 更新剩余带宽 - for u, v in zip(path[:-1], path[1:]): - self.remaining_capacity[(u, v)] -= allocated - if self.remaining_capacity[(u, v)] <= 0: - self.graph[u][v]['bandwidth'] = 0 - - results[f"{flow.source}-{flow.destination}"] = { - "path": path, - "allocated": float(allocated), # 确保float类型 - "priority": float(flow.priority) - } - - return results - - def find_optimal_path(self, source: str, target: str, - bandwidth: float = 1.0) -> Optional[List[str]]: - """ - 改进的最优路径查找 - """ - try: - paths = nx.shortest_simple_paths( - self.graph, source, target, - weight=lambda u, v, d: 1 / max(1, d['bandwidth']) + d['latency'] # 避免除零 - ) - return next( - (path for path in paths - if self._path_has_sufficient_bandwidth(path, bandwidth)), - None - ) - except (nx.NetworkXNoPath, nx.NodeNotFound): - return None - - def _path_has_sufficient_bandwidth(self, path: List[str], bw: float) -> bool: - """检查路径带宽是否满足要求""" - return all( - self.graph[u][v]['bandwidth'] >= bw - for u, v in zip(path[:-1], path[1:]) - ) \ No newline at end of file diff --git a/src/backend/app/services/network_visualizer.py b/src/backend/app/services/network_visualizer.py deleted file mode 100644 index f9c28fd..0000000 --- a/src/backend/app/services/network_visualizer.py +++ /dev/null @@ -1,104 +0,0 @@ -from datetime import datetime, timedelta - -import networkx as nx -import matplotlib.pyplot as plt -from io import BytesIO -import base64 -from typing import Dict, List -from ..models.traffic_models import SwitchTrafficRecord -from src.backend.app.api.database import SessionLocal - - -class NetworkVisualizer: - def __init__(self): - self.graph = nx.Graph() - - def update_topology(self, devices: List[Dict]): - """更新网络拓扑图""" - self.graph.clear() - - # 添加节点 - for device in devices: - self.graph.add_node( - device['ip'], - type='switch', - label=f"Switch\n{device['ip']}" - ) - - # 添加连接(简化版,实际应根据扫描结果) - if len(devices) > 1: - for i in range(len(devices) - 1): - self.graph.add_edge( - devices[i]['ip'], - devices[i + 1]['ip'], - bandwidth=1000, - label="1Gbps" - ) - - def generate_topology_image(self) -> str: - """生成拓扑图并返回base64编码""" - plt.figure(figsize=(10, 8)) - pos = nx.spring_layout(self.graph) - - # 绘制节点 - node_colors = [] - for node in self.graph.nodes(): - if self.graph.nodes[node]['type'] == 'switch': - node_colors.append('lightblue') - - nx.draw_networkx_nodes( - self.graph, pos, - node_size=2000, - node_color=node_colors - ) - - # 绘制边 - nx.draw_networkx_edges( - self.graph, pos, - width=2, - alpha=0.5 - ) - - # 绘制标签 - node_labels = nx.get_node_attributes(self.graph, 'label') - nx.draw_networkx_labels( - self.graph, pos, - labels=node_labels, - font_size=8 - ) - - edge_labels = nx.get_edge_attributes(self.graph, 'label') - nx.draw_networkx_edge_labels( - self.graph, pos, - edge_labels=edge_labels, - font_size=8 - ) - - plt.title("Network Topology") - plt.axis('off') - - # 转换为base64 - buf = BytesIO() - plt.savefig(buf, format='png', bbox_inches='tight') - plt.close() - buf.seek(0) - return base64.b64encode(buf.read()).decode('utf-8') - - @staticmethod - def get_traffic_heatmap(minutes: int = 10) -> Dict: - """获取流量热力图数据""" - with SessionLocal() as session: - records = session.query(SwitchTrafficRecord).filter( - SwitchTrafficRecord.timestamp >= datetime.now() - timedelta(minutes=minutes) - ).all() - - heatmap_data = {} - for record in records: - if record.switch_ip not in heatmap_data: - heatmap_data[record.switch_ip] = {} - heatmap_data[record.switch_ip][record.interface] = { - 'in': record.rate_in, - 'out': record.rate_out - } - - return heatmap_data \ No newline at end of file diff --git a/src/backend/app/services/report_generator.py b/src/backend/app/services/report_generator.py deleted file mode 100644 index d6742f4..0000000 --- a/src/backend/app/services/report_generator.py +++ /dev/null @@ -1,120 +0,0 @@ -import pandas as pd -import matplotlib.pyplot as plt -from io import BytesIO -import base64 -from datetime import datetime, timedelta -from typing import Dict, List -from ..models.traffic_models import TrafficRecord, SwitchTrafficRecord -from src.backend.app.api.database import SessionLocal - - -class ReportGenerator: - @staticmethod - def generate_traffic_report(ip: str = None, days: int = 7) -> Dict: - """生成流量分析报告""" - with SessionLocal() as session: - time_threshold = datetime.now() - timedelta(days=days) - - if ip: - # 交换机流量报告 - query = session.query(SwitchTrafficRecord).filter( - SwitchTrafficRecord.switch_ip == ip, - SwitchTrafficRecord.timestamp >= time_threshold - ) - df = pd.read_sql(query.statement, session.bind) - - if df.empty: - return {"error": "No data found"} - - # 按接口分组分析 - interface_stats = df.groupby('interface').agg({ - 'rate_in': ['mean', 'max', 'min'], - 'rate_out': ['mean', 'max', 'min'] - }).reset_index() - - # 生成趋势图 - trend_fig = ReportGenerator._plot_traffic_trend(df, f"Switch {ip} Traffic Trend") - - return { - "switch_ip": ip, - "period": f"Last {days} days", - "interface_stats": interface_stats.to_dict('records'), - "trend_chart": trend_fig - } - else: - # 本地网络流量报告 - query = session.query(TrafficRecord).filter( - TrafficRecord.timestamp >= time_threshold - ) - df = pd.read_sql(query.statement, session.bind) - - if df.empty: - return {"error": "No data found"} - - # 按接口分组分析 - interface_stats = df.groupby('interface').agg({ - 'bytes_sent': ['sum', 'mean', 'max'], - 'bytes_recv': ['sum', 'mean', 'max'] - }).reset_index() - - # 生成趋势图 - trend_fig = ReportGenerator._plot_traffic_trend(df, "Local Network Traffic Trend") - - return { - "report_type": "local_network", - "period": f"Last {days} days", - "interface_stats": interface_stats.to_dict('records'), - "trend_chart": trend_fig - } - - @staticmethod - def _plot_traffic_trend(df: pd.DataFrame, title: str) -> str: - """生成流量趋势图""" - plt.figure(figsize=(12, 6)) - - if 'rate_in' in df.columns: # 交换机数据 - df_grouped = df.groupby(pd.Grouper(key='timestamp', freq='1H')).agg({ - 'rate_in': 'mean', - 'rate_out': 'mean' - }) - plt.plot(df_grouped.index, df_grouped['rate_in'], label='Inbound Traffic') - plt.plot(df_grouped.index, df_grouped['rate_out'], label='Outbound Traffic') - plt.ylabel("Traffic Rate (bytes/sec)") - else: # 本地网络数据 - df_grouped = df.groupby(pd.Grouper(key='timestamp', freq='1H')).agg({ - 'bytes_sent': 'sum', - 'bytes_recv': 'sum' - }) - plt.plot(df_grouped.index, df_grouped['bytes_sent'], label='Bytes Sent') - plt.plot(df_grouped.index, df_grouped['bytes_recv'], label='Bytes Received') - plt.ylabel("Bytes") - - plt.title(title) - plt.xlabel("Time") - plt.legend() - plt.grid(True) - - # 转换为base64 - buf = BytesIO() - plt.savefig(buf, format='png', bbox_inches='tight') - plt.close() - buf.seek(0) - return base64.b64encode(buf.read()).decode('utf-8') - - @staticmethod - def generate_config_history_report(days: int = 30) -> Dict: - """生成配置变更历史报告""" - # 需要实现配置历史记录功能 - pass - - @staticmethod - def generate_security_report() -> Dict: - """生成安全评估报告""" - # 需要实现安全扫描功能 - pass - - @staticmethod - def generate_performance_report() -> Dict: - """生成性能评估报告""" - # 需要实现性能基准测试功能 - pass \ No newline at end of file From 3260af32fcf508e3613a956680ecbd7d7e907d47 Mon Sep 17 00:00:00 2001 From: Jerry Date: Tue, 12 Aug 2025 17:15:34 +0800 Subject: [PATCH 4/7] =?UTF-8?q?=E8=A7=A3=E5=86=B3=E5=BC=82=E6=AD=A5?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/README.md | 1 - src/backend/app/api/command_parser.py | 14 - src/backend/app/api/endpoints.py | 344 +++++++++++++++++--- src/backend/app/services/ai_services.py | 40 ++- src/backend/app/services/network_scanner.py | 12 +- 5 files changed, 326 insertions(+), 85 deletions(-) delete mode 100644 src/backend/app/api/command_parser.py diff --git a/src/backend/README.md b/src/backend/README.md index 0a34790..d0a536e 100644 --- a/src/backend/README.md +++ b/src/backend/README.md @@ -10,7 +10,6 @@ src/backend/ │ ├── __init__.py # 创建 Flask 应用实例 │ ├── api/ # API 路由模块 │ │ ├—── __init__.py # 注册 API 蓝图 -│ │ ├── command_parser.py # /api/parse_command 接口 │ │ └── network_config.py # /api/apply_config 接口 │ └── services/ # 核心服务逻辑 │ └── ai_services.py # 调用外部 AI 服务生成配置 diff --git a/src/backend/app/api/command_parser.py b/src/backend/app/api/command_parser.py deleted file mode 100644 index f198d10..0000000 --- a/src/backend/app/api/command_parser.py +++ /dev/null @@ -1,14 +0,0 @@ -from typing import Dict, Any -from src.backend.app.services.ai_services import AIService -from src.backend.config import settings - - -class CommandParser: - def __init__(self): - self.ai_service = AIService(settings.SILICONFLOW_API_KEY, settings.SILICONFLOW_API_URL) - - async def parse(self, command: str) -> Dict[str, Any]: - """ - 解析中文命令并返回配置 - """ - return await self.ai_service.parse_command(command) diff --git a/src/backend/app/api/endpoints.py b/src/backend/app/api/endpoints.py index 4bcd616..20ed723 100644 --- a/src/backend/app/api/endpoints.py +++ b/src/backend/app/api/endpoints.py @@ -1,19 +1,30 @@ import socket -from fastapi import (APIRouter, HTTPException, Response) +from datetime import datetime, timedelta +from fastapi import (APIRouter, HTTPException, Response, WebSocket, WebSocketDisconnect) from typing import List from pydantic import BaseModel +import asyncio +from fastapi.responses import HTMLResponse, JSONResponse +import matplotlib.pyplot as plt +import io +import base64 import psutil import ipaddress + +from ..services.switch_traffic_monitor import get_switch_monitor +from ..utils import logger from ...app.services.ai_services import AIService from ...app.api.network_config import SwitchConfigurator from ...config import settings from ..services.network_scanner import NetworkScanner - - +from ...app.services.traffic_monitor import traffic_monitor +from ...app.models.traffic_models import TrafficRecord, SwitchTrafficRecord +from src.backend.app.api.database import SessionLocal router = APIRouter(prefix="", tags=["API"]) scanner = NetworkScanner() + @router.get("/", include_in_schema=False) async def root(): return { @@ -25,16 +36,18 @@ async def root(): "/apply_config", "/scan_network", "/list_devices", - "/batch_apply_config" + "/batch_apply_config", "/traffic/switch/current", "/traffic/switch/history" ] } + @router.get("/favicon.ico", include_in_schema=False) async def favicon(): return Response(status_code=204) + class BatchConfigRequest(BaseModel): config: dict switch_ips: List[str] @@ -42,14 +55,31 @@ class BatchConfigRequest(BaseModel): password: str = None timeout: int = None + +@router.post("/batch_apply_config") +async def batch_apply_config(request: BatchConfigRequest): + results = {} + for ip in request.switch_ips: + try: + configurator = SwitchConfigurator( + username=request.username, + password=request.password, + timeout=request.timeout) + results[ip] = await configurator.apply_config(ip, request.config) + except Exception as e: + results[ip] = str(e) + return {"results": results} + + @router.get("/test") async def test_endpoint(): return {"message": "Hello World"} + @router.get("/scan_network", summary="扫描网络中的交换机") async def scan_network(subnet: str = "192.168.1.0/24"): try: - devices = await scanner.scan_subnet(subnet) + devices = await asyncio.to_thread(scanner.scan_subnet, subnet) return { "success": True, "devices": devices, @@ -58,14 +88,18 @@ async def scan_network(subnet: str = "192.168.1.0/24"): except Exception as e: raise HTTPException(500, f"扫描失败: {str(e)}") + @router.get("/list_devices", summary="列出已发现的交换机") async def list_devices(): return { - "devices": await scanner.load_cached_devices() + "devices": await asyncio.to_thread(scanner.load_cached_devices) } + class CommandRequest(BaseModel): command: str + vendor: str = "huawei" + class ConfigRequest(BaseModel): config: dict @@ -73,15 +107,15 @@ class ConfigRequest(BaseModel): username: str = None password: str = None timeout: int = None + vendor: str = "huawei" + @router.post("/parse_command", response_model=dict) async def parse_command(request: CommandRequest): - """ - 解析中文命令并返回JSON配置 - """ + """解析中文命令并返回JSON配置""" try: ai_service = AIService(settings.SILICONFLOW_API_KEY, settings.SILICONFLOW_API_URL) - config = await ai_service.parse_command(request.command) + config = await ai_service.parse_command(request.command, request.vendor) return {"success": True, "config": config} except Exception as e: raise HTTPException( @@ -89,16 +123,16 @@ async def parse_command(request: CommandRequest): detail=f"Failed to parse command: {str(e)}" ) + @router.post("/apply_config", response_model=dict) async def apply_config(request: ConfigRequest): - """ - 应用配置到交换机(弃用) - """ + """应用配置到交换机""" try: configurator = SwitchConfigurator( username=request.username, password=request.password, - timeout=request.timeout + timeout=request.timeout, + vendor=request.vendor ) result = await configurator.safe_apply(request.switch_ip, request.config) return {"success": True, "result": result} @@ -132,14 +166,10 @@ class CLICommandRequest(BaseModel): return [cmd for cmd in self.commands if not (cmd.startswith("!username=") or cmd.startswith("!password="))] + @router.post("/execute_cli_commands", response_model=dict) async def execute_cli_commands(request: CLICommandRequest): - """ - 执行前端生成的CLI命令 - 支持在commands中嵌入凭据: - !username=admin - !password=cisco123 - """ + """执行前端生成的CLI命令""" try: username, password = request.extract_credentials() clean_commands = request.get_clean_commands() @@ -163,43 +193,255 @@ async def execute_cli_commands(request: CLICommandRequest): except Exception as e: raise HTTPException(500, detail=str(e)) -@router.get("/", include_in_schema=False) -async def root(): + +@router.get("/traffic/interfaces", summary="获取所有网络接口") +async def get_network_interfaces(): return { - "message": "欢迎使用AI交换机配置系统", - "docs": f"{settings.API_PREFIX}/docs", - "redoc": f"{settings.API_PREFIX}/redoc", - "endpoints": [ - "/parse_command", - "/apply_config", - "/scan_network", - "/list_devices", - "/batch_apply_config", - "/traffic/switch/current", - "/traffic/switch/history" - ] + "interfaces": await asyncio.to_thread(traffic_monitor.get_interfaces) } + + +@router.get("/traffic/current", summary="获取当前流量数据") +async def get_current_traffic(interface: str = None): + return await asyncio.to_thread(traffic_monitor.get_current_traffic, interface) + + +@router.get("/traffic/history", summary="获取流量历史数据") +async def get_traffic_history(interface: str = None, limit: int = 100): + history = await asyncio.to_thread(traffic_monitor.get_traffic_history, interface) + return { + "sent": history["sent"][-limit:], + "recv": history["recv"][-limit:], + "time": [t.isoformat() for t in history["time"]][-limit:] + } + + +@router.get("/traffic/records", summary="获取流量记录") +async def get_traffic_records(interface: str = None, limit: int = 100): + def sync_get_records(): + with SessionLocal() as session: + query = session.query(TrafficRecord) + if interface: + query = query.filter(TrafficRecord.interface == interface) + records = query.order_by(TrafficRecord.timestamp.desc()).limit(limit).all() + return [record.to_dict() for record in records] + + return await asyncio.to_thread(sync_get_records) + + +@router.websocket("/ws/traffic") +async def websocket_traffic(websocket: WebSocket): + await websocket.accept() + try: + while True: + traffic_data = await asyncio.to_thread(traffic_monitor.get_current_traffic) + await websocket.send_json(traffic_data) + await asyncio.sleep(1) + except WebSocketDisconnect: + print("客户端断开连接") + + +@router.get("/traffic/switch/interfaces", summary="获取交换机的网络接口") +async def get_switch_interfaces(switch_ip: str): + try: + monitor = get_switch_monitor(switch_ip) + interfaces = list(monitor.interface_oids.keys()) + return { + "switch_ip": switch_ip, + "interfaces": interfaces + } + except Exception as e: + logger.error(f"获取交换机接口失败: {str(e)}") + raise HTTPException(500, f"获取接口失败: {str(e)}") + + +async def get_interface_current_traffic(switch_ip: str, interface: str) -> dict: + """获取指定交换机接口的当前流量数据""" + try: + def sync_get_record(): + with SessionLocal() as session: + record = session.query(SwitchTrafficRecord).filter( + SwitchTrafficRecord.switch_ip == switch_ip, + SwitchTrafficRecord.interface == interface + ).order_by(SwitchTrafficRecord.timestamp.desc()).first() + + if not record: + return { + "switch_ip": switch_ip, + "interface": interface, + "rate_in": 0.0, + "rate_out": 0.0, + "bytes_in": 0, + "bytes_out": 0 + } + + return { + "switch_ip": switch_ip, + "interface": interface, + "rate_in": record.rate_in, + "rate_out": record.rate_out, + "bytes_in": record.bytes_in, + "bytes_out": record.bytes_out + } + + return await asyncio.to_thread(sync_get_record) + except Exception as e: + logger.error(f"获取接口流量失败: {str(e)}") + raise HTTPException(500, f"获取接口流量失败: {str(e)}") + + +@router.get("/traffic/switch/current", summary="获取交换机的当前流量数据") +async def get_switch_current_traffic(switch_ip: str, interface: str = None): + try: + monitor = get_switch_monitor(switch_ip) + if not interface: + traffic_data = {} + for iface in monitor.interface_oids: + traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface) + return { + "switch_ip": switch_ip, + "traffic": traffic_data + } + return await get_interface_current_traffic(switch_ip, interface) + except Exception as e: + logger.error(f"获取交换机流量失败: {str(e)}") + raise HTTPException(500, f"获取流量失败: {str(e)}") + + +@router.get("/traffic/switch/history", summary="获取交换机的流量历史数据") +async def get_switch_traffic_history(switch_ip: str, interface: str = None, minutes: int = 10): + try: + monitor = get_switch_monitor(switch_ip) + if not interface: + return { + "switch_ip": switch_ip, + "history": await asyncio.to_thread(monitor.get_traffic_history) + } + + def sync_get_history(): + with SessionLocal() as session: + time_threshold = datetime.now() - timedelta(minutes=minutes) + records = session.query(SwitchTrafficRecord).filter( + SwitchTrafficRecord.switch_ip == switch_ip, + SwitchTrafficRecord.interface == interface, + SwitchTrafficRecord.timestamp >= time_threshold + ).order_by(SwitchTrafficRecord.timestamp.asc()).all() + + history_data = { + "in": [record.rate_in for record in records], + "out": [record.rate_out for record in records], + "time": [record.timestamp.isoformat() for record in records] + } + return history_data + + history_data = await asyncio.to_thread(sync_get_history) + return { + "switch_ip": switch_ip, + "interface": interface, + "history": history_data + } + except Exception as e: + logger.error(f"获取历史流量失败: {str(e)}") + raise HTTPException(500, f"获取历史流量失败: {str(e)}") + + +@router.websocket("/ws/traffic/switch") +async def websocket_switch_traffic(websocket: WebSocket, switch_ip: str, interface: str = None): + await websocket.accept() + try: + monitor = get_switch_monitor(switch_ip) + while True: + if interface: + traffic_data = await get_interface_current_traffic(switch_ip, interface) + await websocket.send_json(traffic_data) + else: + traffic_data = {} + for iface in monitor.interface_oids: + traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface) + await websocket.send_json({ + "switch_ip": switch_ip, + "traffic": traffic_data + }) + await asyncio.sleep(1) + except WebSocketDisconnect: + logger.info(f"客户端断开交换机流量连接: {switch_ip}") + except Exception as e: + logger.error(f"交换机流量WebSocket错误: {str(e)}") + await websocket.close(code=1011, reason=str(e)) + + +@router.get("/traffic/switch/plot", response_class=HTMLResponse, summary="交换机流量可视化") +async def plot_switch_traffic(switch_ip: str, interface: str, minutes: int = 10): + try: + history = await get_switch_traffic_history(switch_ip, interface, minutes) + history_data = history["history"] + time_points = [datetime.fromisoformat(t) for t in history_data["time"]] + in_rates = history_data["in"] + out_rates = history_data["out"] + + def generate_plot(): + plt.figure(figsize=(12, 6)) + plt.plot(time_points, in_rates, label="流入流量 (B/s)") + plt.plot(time_points, out_rates, label="流出流量 (B/s)") + plt.title(f"交换机 {switch_ip} 接口 {interface} 流量监控 - 最近 {minutes} 分钟") + plt.xlabel("时间") + plt.ylabel("流量 (字节/秒)") + plt.legend() + plt.grid(True) + plt.xticks(rotation=45) + plt.tight_layout() + buf = io.BytesIO() + plt.savefig(buf, format="png") + buf.seek(0) + image_base64 = base64.b64encode(buf.read()).decode("utf-8") + plt.close() + return image_base64 + + image_base64 = await asyncio.to_thread(generate_plot) + return f""" + + + 交换机流量监控 + + + +
+

交换机 {switch_ip} 接口 {interface} 流量监控

+ 流量图表 +

更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

+
+ + + """ + except Exception as e: + logger.error(f"生成流量图表失败: {str(e)}") + return HTMLResponse(content=f"

错误

{str(e)}

", status_code=500) + + @router.get("/network_adapters", summary="获取网络适配器网段") async def get_network_adapters(): try: - net_if_addrs = psutil.net_if_addrs() - - networks = [] - for interface, addrs in net_if_addrs.items(): - for addr in addrs: - if addr.family == socket.AF_INET: - ip = addr.address - netmask = addr.netmask - - network = ipaddress.IPv4Network(f"{ip}/{netmask}", strict=False) - networks.append({ - "adapter": interface, - "network": str(network), - "ip": ip, - "subnet_mask": netmask - }) + def sync_get_adapters(): + net_if_addrs = psutil.net_if_addrs() + networks = [] + for interface, addrs in net_if_addrs.items(): + for addr in addrs: + if addr.family == socket.AF_INET: + ip = addr.address + netmask = addr.netmask + network = ipaddress.IPv4Network(f"{ip}/{netmask}", strict=False) + networks.append({ + "adapter": interface, + "network": str(network), + "ip": ip, + "subnet_mask": netmask + }) + return networks + networks = await asyncio.to_thread(sync_get_adapters) return {"networks": networks} - except Exception as e: return {"error": f"获取网络适配器信息失败: {str(e)}"} \ No newline at end of file diff --git a/src/backend/app/services/ai_services.py b/src/backend/app/services/ai_services.py index d315477..2987bd0 100644 --- a/src/backend/app/services/ai_services.py +++ b/src/backend/app/services/ai_services.py @@ -1,7 +1,5 @@ -from typing import Dict, Any, Coroutine - -import httpx -from openai import OpenAI +from typing import Any +from openai import AsyncOpenAI import json from src.backend.app.utils.exceptions import SiliconFlowAPIException from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam @@ -12,28 +10,44 @@ class AIService: def __init__(self, api_key: str, api_url: str): self.api_key = api_key self.api_url = api_url - self.client = OpenAI( + self.client = AsyncOpenAI( api_key=self.api_key, base_url=self.api_url, # timeout=httpx.Timeout(30.0) ) - async def parse_command(self, command: str) -> Any | None: + async def parse_command(self, command: str, vendor: str = "huawei") -> Any | None: """ 调用硅基流动API解析中文命令 """ - prompt = """ - 你是一个网络设备配置专家,精通各种类型的路由器的配置,请将以下用户的中文命令转换为网络设备配置JSON - 但是请注意,由于贪婪的人们追求极高的效率,所以你必须严格按照 JSON 格式返回数据,不要包含任何额外文本或 Markdown 代码块 + vendor_prompts = { + "huawei": "华为交换机配置命令", + "cisco": "思科交换机配置命令", + "h3c": "H3C交换机配置命令", + "ruijie": "锐捷交换机配置命令", + "zte": "中兴交换机配置命令" + } + + prompt = f""" + 你是一个网络设备配置专家,精通各种类型的路由器的配置,请将以下用户的中文命令转换为{vendor_prompts.get(vendor, '网络设备')}配置JSON。 + 但是请注意,由于贪婪的人们追求极高的效率,所以你必须严格按照 JSON 格式返回数据,不要包含任何额外文本或 Markdown 代码块。 返回格式要求: 1. 必须包含'type'字段指明配置类型(vlan/interface/acl/route等) 2. 必须包含'commands'字段,包含可直接执行的命令列表 3. 其他参数根据配置类型动态添加 4. 不要包含解释性文本、步骤说明或注释 - 5.要包含使用ssh连接交换机后的完整命令包括但不完全包括system-view,退出,保存等完整操作,注意保存还需要输入Y + 5. 要包含使用ssh连接交换机后的完整命令包括但不完全包括system-view,退出,保存等完整操作,注意保存还需要输入Y + + 根据厂商{vendor}的不同,命令格式如下: + - 华为: system-view → quit → save Y + - 思科: enable → configure terminal → exit → write memory + - H3C: system-view → quit → save + - 锐捷: enable → configure terminal → exit → write + - 中兴: enable → configure terminal → exit → write memory + 示例命令:'创建VLAN 100,名称为TEST' - 示例返回:{"type": "vlan", "vlan_id": 100, "name": "TEST", "commands": ["system-view","vlan 100", "name TEST","quit","quit","save","Y"]} - 注意:这里生成的commands中需包含登录交换机和保存等所有操作命令,我们使ssh连接交换机,你不需要给出连接ssh的命令,你只需要给出使用ssh连接到交换机后所输入的全部命令,并且注意在system-view状态下是不能save的,需要再quit到用户视图 + 华为示例返回:{{"type": "vlan", "vlan_id": 100, "name": "TEST", "commands": ["system-view","vlan 100", "name TEST","quit","quit","save","Y"]}} + 思科示例返回:{{"type": "vlan", "vlan_id": 100, "name": "TEST", "commands": ["enable","configure terminal","vlan 100", "name TEST","exit","exit","write memory"]}} """ messages = [ @@ -42,7 +56,7 @@ class AIService: ] try: - response = self.client.chat.completions.create( + response = await self.client.chat.completions.create( model="deepseek-ai/DeepSeek-V3", messages=messages, temperature=0.3, diff --git a/src/backend/app/services/network_scanner.py b/src/backend/app/services/network_scanner.py index 77ac5ff..bdb7f0b 100644 --- a/src/backend/app/services/network_scanner.py +++ b/src/backend/app/services/network_scanner.py @@ -9,13 +9,13 @@ class NetworkScanner: self.cache_path = Path(cache_path) self.nm = nmap.PortScanner() - async def scan_subnet(self, subnet: str = "192.168.1.0/24", ports: List[int] = [22, 23, 80]) -> List[Dict]: + def scan_subnet(self, subnet: str = "192.168.1.0/24", ports: List[int] = [22, 23, 80]) -> List[Dict]: """扫描指定子网的设备,获取设备信息和开放端口""" logger.info(f"Scanning subnet: {subnet}") devices = [] try: - await self.nm.scan(hosts=subnet, arguments=f'-p {",".join(map(str, ports))}') + self.nm.scan(hosts=subnet, arguments=f'-p {",".join(map(str, ports))}') for host in self.nm.all_hosts(): ip = host mac = self.nm[host]['addresses'].get('mac', 'N/A') @@ -33,19 +33,19 @@ class NetworkScanner: except Exception as e: logger.error(f"Error while scanning subnet: {e}") - await self._save_to_cache(devices) + self._save_to_cache(devices) return devices - async def _save_to_cache(self, devices: List[Dict]): + def _save_to_cache(self, devices: List[Dict]): """保存扫描结果到本地文件""" with open(self.cache_path, "w") as f: json.dump(devices, f, indent=2) logger.info(f"Saved {len(devices)} devices to cache") - async def load_cached_devices(self) -> List[Dict]: + def load_cached_devices(self) -> List[Dict]: """从缓存加载设备列表""" if not self.cache_path.exists(): return [] with open(self.cache_path) as f: - return json.load(f) + return json.load(f) \ No newline at end of file From 7fee04bc0a43694310274549093bb298a1c1e5e8 Mon Sep 17 00:00:00 2001 From: Jerry <129190939+Jerryplusy@users.noreply.github.com> Date: Thu, 14 Aug 2025 22:26:06 +0800 Subject: [PATCH 5/7] =?UTF-8?q?=F0=9F=91=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 3cf6031..59082e7 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,9 @@ # 基于人工智能实现的交换机自动或半自动配置 +这只是一个用于应付实验论文的奇葩储存库 + +猎奇卡顿 · 龟速更新 · 随时跑路 + ### 技术栈 - **Python3** - Flask @@ -11,10 +15,6 @@ - Framer-motion - chakra-ui - HTML5 -### 项目分工 -- **后端api,人工智能算法** : `3`(主要) & `log_out` & `Jerry`(maybe) 使用python -- **前端管理后台设计**:`Jerry`使用react -- **论文撰写**:`log_out` ### 各部分说明 @@ -22,18 +22,3 @@ [逻辑处理后端](https://github.com/Jerryplusy/AI-powered-switches/blob/main/src/backend/README.md) -### 贡献流程 - -- **后端api**: - - 对于`3`:直接推送到`main`分支 - - 对于`Jerry`&`log_out`:新建额外的`feat`或`fix`分支,提交推送到自己的分支,然后提交`pullrequest`到`main`分支并指定`3`审核 - -- **前端管理后台**: - - 对于`Jerry`:直接推送更新到`main`分支 - - 对于`3`&`log_out`:新建额外的`feat`或`fix`分支,提交推送到自己的分支,然后提交`pullrequest`到`main`分支并指定`Jerry`审核 - -- **论文(thesis)**: - - 提交`pullrequest`并指定`log_out`审核 -### 项目活动时间 -2025 6 - 8月 - From eb6aeb52162d47fcb2f4851cdb6ad8f48a6897d0 Mon Sep 17 00:00:00 2001 From: Jerry Date: Thu, 21 Aug 2025 19:17:14 +0800 Subject: [PATCH 6/7] =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=89=B9=E9=87=8F?= =?UTF-8?q?=E6=93=8D=E4=BD=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/app/api/endpoints.py | 76 +---- src/backend/app/api/network_config.py | 317 ++++-------------- src/backend/app/models/requests.py | 26 ++ src/backend/app/services/ai_services.py | 88 ++--- .../pages/config/DeviceConfigModal.jsx | 56 +++- src/frontend/src/pages/ConfigPage.jsx | 243 +++++++------- src/frontend/src/services/api/api.js | 24 +- 7 files changed, 313 insertions(+), 517 deletions(-) create mode 100644 src/backend/app/models/requests.py diff --git a/src/backend/app/api/endpoints.py b/src/backend/app/api/endpoints.py index 20ed723..697e687 100644 --- a/src/backend/app/api/endpoints.py +++ b/src/backend/app/api/endpoints.py @@ -11,6 +11,7 @@ import base64 import psutil import ipaddress +from ..models.requests import CLICommandRequest, ConfigRequest from ..services.switch_traffic_monitor import get_switch_monitor from ..utils import logger from ...app.services.ai_services import AIService @@ -55,22 +56,6 @@ class BatchConfigRequest(BaseModel): password: str = None timeout: int = None - -@router.post("/batch_apply_config") -async def batch_apply_config(request: BatchConfigRequest): - results = {} - for ip in request.switch_ips: - try: - configurator = SwitchConfigurator( - username=request.username, - password=request.password, - timeout=request.timeout) - results[ip] = await configurator.apply_config(ip, request.config) - except Exception as e: - results[ip] = str(e) - return {"results": results} - - @router.get("/test") async def test_endpoint(): return {"message": "Hello World"} @@ -96,27 +81,29 @@ async def list_devices(): } +class DeviceItem(BaseModel): + name: str + ip: str + vendor: str + class CommandRequest(BaseModel): command: str - vendor: str = "huawei" - - -class ConfigRequest(BaseModel): - config: dict - switch_ip: str - username: str = None - password: str = None - timeout: int = None - vendor: str = "huawei" - + devices: List[DeviceItem] @router.post("/parse_command", response_model=dict) async def parse_command(request: CommandRequest): - """解析中文命令并返回JSON配置""" + """解析中文命令并返回每台设备的配置 JSON""" + missing_vendor = [d for d in request.devices if not d.vendor or d.vendor.strip() == ""] + if missing_vendor: + names = ", ".join([d.name for d in missing_vendor]) + raise HTTPException( + status_code=400, + detail=f"以下设备未配置厂商: {names}" + ) try: ai_service = AIService(settings.SILICONFLOW_API_KEY, settings.SILICONFLOW_API_URL) - config = await ai_service.parse_command(request.command, request.vendor) - return {"success": True, "config": config} + config = await ai_service.parse_command(request.command, [d.dict() for d in request.devices]) + return {"success": True, "config": config.get("results", [])} except Exception as e: raise HTTPException( status_code=400, @@ -141,44 +128,16 @@ async def apply_config(request: ConfigRequest): status_code=500, detail=f"Failed to apply config: {str(e)}" ) - - -class CLICommandRequest(BaseModel): - switch_ip: str - commands: List[str] - is_ensp: bool = False - - def extract_credentials(self) -> tuple: - """从commands中提取用户名和密码""" - username = None - password = None - - for cmd in self.commands: - if cmd.startswith("!username="): - username = cmd.split("=")[1] - elif cmd.startswith("!password="): - password = cmd.split("=")[1] - - return username, password - - def get_clean_commands(self) -> List[str]: - """获取去除凭据后的实际命令""" - return [cmd for cmd in self.commands - if not (cmd.startswith("!username=") or cmd.startswith("!password="))] - - @router.post("/execute_cli_commands", response_model=dict) async def execute_cli_commands(request: CLICommandRequest): """执行前端生成的CLI命令""" try: username, password = request.extract_credentials() - clean_commands = request.get_clean_commands() configurator = SwitchConfigurator( username=username, password=password, timeout=settings.SWITCH_TIMEOUT, - ensp_mode=request.is_ensp ) result = await configurator.execute_raw_commands( @@ -188,7 +147,6 @@ async def execute_cli_commands(request: CLICommandRequest): return { "success": True, "output": result, - "mode": "eNSP" if request.is_ensp else "SSH" } except Exception as e: raise HTTPException(500, detail=str(e)) diff --git a/src/backend/app/api/network_config.py b/src/backend/app/api/network_config.py index bd0c225..b512273 100644 --- a/src/backend/app/api/network_config.py +++ b/src/backend/app/api/network_config.py @@ -1,19 +1,14 @@ import asyncio import logging import telnetlib3 -from datetime import datetime from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional -import aiofiles -import asyncssh from pydantic import BaseModel -from tenacity import retry, stop_after_attempt, wait_exponential from src.backend.app.utils.logger import logger from src.backend.config import settings - # ---------------------- # 数据模型 # ---------------------- @@ -25,36 +20,31 @@ class SwitchConfig(BaseModel): ip_address: Optional[str] = None vlan: Optional[int] = None - # ---------------------- # 异常类 # ---------------------- class SwitchConfigException(Exception): pass - class EnspConnectionException(SwitchConfigException): pass - -class SSHConnectionException(SwitchConfigException): - pass - - # ---------------------- -# 核心配置器(完整双模式) +# 核心配置器 # ---------------------- class SwitchConfigurator: + connection_pool: Dict[str, tuple] = {} + def __init__( - self, - username: str = None, - password: str = None, - timeout: int = None, - max_workers: int = 5, - ensp_mode: bool = False, - ensp_port: int = 2000, - ensp_command_delay: float = 0.5, - **ssh_options + self, + username: str = None, + password: str = None, + timeout: int = None, + max_workers: int = 5, + ensp_mode: bool = False, + ensp_port: int = 2000, + ensp_command_delay: float = 0.5, + **ssh_options ): self.username = username if username is not None else settings.SWITCH_USERNAME self.password = password if password is not None else settings.SWITCH_PASSWORD @@ -67,253 +57,66 @@ class SwitchConfigurator: self.ensp_delay = ensp_command_delay self.ssh_options = ssh_options - async def apply_config(self, ip: str, config: Union[Dict, SwitchConfig]) -> str: - """实际配置逻辑""" - if isinstance(config, dict): - config = SwitchConfig(**config) - - commands = ( - self._generate_ensp_commands(config) - if self.ensp_mode - else self._generate_standard_commands(config) - ) - return await self._send_commands(ip, commands) - - async def _send_commands(self, ip: str, commands: List[str]) -> str: - """双模式命令发送""" - return ( - await self._send_ensp_commands(ip, commands) - ) - - async def _send_ensp_commands(self, ip: str, commands: List[str]) -> str: + async def _get_or_create_connection(self, ip: str): """ - 通过 Telnet 协议连接 eNSP 设备 + 从连接池获取连接,如果没有则新建 Telnet 连接 + """ + if ip in self.connection_pool: + logger.debug(f"复用已有连接: {ip}") + return self.connection_pool[ip] + + logger.info(f"建立新连接: {ip}") + reader, writer = await telnetlib3.open_connection(host=ip, port=23) + + try: + if self.username != 'NONE' : + await asyncio.wait_for(reader.readuntil(b"Username:"), timeout=self.timeout) + writer.write(f"{self.username}\n") + + await asyncio.wait_for(reader.readuntil(b"Password:"), timeout=self.timeout) + writer.write(f"{self.password}\n") + + await asyncio.sleep(1) + except asyncio.TimeoutError: + writer.close() + raise EnspConnectionException("登录超时,未收到用户名或密码提示") + except Exception as e: + writer.close() + raise EnspConnectionException(f"登录异常: {e}") + + self.connection_pool[ip] = (reader, writer) + return reader, writer + + async def _send_ensp_commands(self, ip: str, commands: List[str]) -> bool: + """ + 通过 Telnet 协议发送命令 """ try: - logger.info(f"连接设备 {ip},端口23") - reader, writer = await telnetlib3.open_connection(host=ip, port=23) - logger.debug("连接成功,开始登录流程") + reader, writer = await self._get_or_create_connection(ip) - try: - if self.username != 'NONE': - await asyncio.wait_for(reader.readuntil(b"Username:"), timeout=self.timeout) - logger.debug("收到 'Username:' 提示,发送用户名") - writer.write(f"{self.username}\n") - - await asyncio.wait_for(reader.readuntil(b"Password:"), timeout=self.timeout) - logger.debug("收到 'Password:' 提示,发送密码") - writer.write(f"{self.password}\n") - - await asyncio.sleep(1) - except asyncio.TimeoutError: - raise EnspConnectionException("登录超时,未收到用户名或密码提示") - - output = "" for cmd in commands: if cmd.startswith("!"): logger.debug(f"跳过特殊命令: {cmd}") continue - logger.info(f"发送命令: {cmd}") + logger.info(f"[{ip}] 发送命令: {cmd}") writer.write(f"{cmd}\n") await writer.drain() + await asyncio.sleep(self.ensp_delay) - command_output = "" - try: - while True: - data = await asyncio.wait_for(reader.read(1024), timeout=1) - if not data: - logger.debug("读取到空数据,结束当前命令读取") - break - command_output += data - logger.debug(f"收到数据: {repr(data)}") - except asyncio.TimeoutError: - logger.debug("命令输出读取超时,继续执行下一条命令") - - output += f"\n[命令: {cmd} 输出开始]\n{command_output}\n[命令: {cmd} 输出结束]\n" - - logger.info("所有命令执行完毕,关闭连接") - writer.close() - - return output + logger.info(f"[{ip}] 所有命令发送完成") + return True except asyncio.TimeoutError as e: - logger.error(f"连接或读取超时: {e}") - raise EnspConnectionException(f"eNSP连接超时: {e}") + logger.error(f"[{ip}] 连接或读取超时: {e}") + return False except Exception as e: - logger.error(f"连接或执行异常: {e}", exc_info=True) - raise EnspConnectionException(f"eNSP连接失败: {e}") - - @staticmethod - def _generate_ensp_commands(config: SwitchConfig) -> List[str]: - """生成eNSP命令序列""" - commands = ["system-view"] - if config.type == "vlan": - commands.extend([ - f"vlan {config.vlan_id}", - f"description {config.name or ''}" - ]) - elif config.type == "interface": - commands.extend([ - f"interface {config.interface}", - "port link-type access", - f"port default vlan {config.vlan}" if config.vlan else "", - f"ip address {config.ip_address}" if config.ip_address else "" - ]) - commands.append("return") - return [c for c in commands if c.strip()] - - async def _send_ssh_commands(self, ip: str, commands: List[str]) -> str: - """AsyncSSH执行命令""" - async with self.semaphore: - try: - async with asyncssh.connect( - host=ip, - username=self.username, - password=self.password, - connect_timeout=self.timeout, - **self.ssh_options - ) as conn: - results = [] - for cmd in commands: - result = await conn.run(cmd, check=True) - results.append(result.stdout) - return "\n".join(results) - except asyncssh.Error as e: - raise SSHConnectionException(f"SSH操作失败: {str(e)}") - except Exception as e: - raise SSHConnectionException(f"连接异常: {str(e)}") - - async def execute_raw_commands(self, ip: str, commands: List[str]) -> str: - """ - 执行原始CLI命令 - """ - return await self._send_commands(ip, commands) - - - @staticmethod - def _generate_standard_commands(config: SwitchConfig) -> List[str]: - """生成标准CLI命令""" - commands = [] - if config.type == "vlan": - commands.extend([ - f"vlan {config.vlan_id}", - f"name {config.name or ''}" - ]) - elif config.type == "interface": - commands.extend([ - f"interface {config.interface}", - f"switchport access vlan {config.vlan}" if config.vlan else "", - f"ip address {config.ip_address}" if config.ip_address else "" - ]) - return commands - - async def _validate_config(self, ip: str, config: SwitchConfig) -> bool: - """验证配置是否生效""" - current = await self._get_current_config(ip) - if config.type == "vlan": - return f"vlan {config.vlan_id}" in current - elif config.type == "interface" and config.vlan: - return f"switchport access vlan {config.vlan}" in current - return True - - async def _get_current_config(self, ip: str) -> str: - """获取当前配置""" - commands = ( - ["display current-configuration"] - if self.ensp_mode - else ["show running-config"] - ) - try: - return await self._send_commands(ip, commands) - except (EnspConnectionException, SSHConnectionException) as e: - raise SwitchConfigException(f"配置获取失败: {str(e)}") - - async def _backup_config(self, ip: str) -> Path: - """备份配置到文件""" - backup_path = self.backup_dir / f"{ip}_{datetime.now().isoformat()}.cfg" - config = await self._get_current_config(ip) - async with aiofiles.open(backup_path, "w") as f: - await f.write(config) - return backup_path - - async def _restore_config(self, ip: str, backup_path: Path) -> bool: - """从备份恢复配置""" - try: - async with aiofiles.open(backup_path) as f: - config = await f.read() - commands = ( - ["system-view", config, "return"] - if self.ensp_mode - else [f"configure terminal\n{config}\nend"] - ) - await self._send_commands(ip, commands) - return True - except Exception as e: - logging.error(f"恢复失败: {str(e)}") + logger.error(f"[{ip}] 命令发送异常: {e}", exc_info=True) return False - - @retry( - stop=stop_after_attempt(2), - wait=wait_exponential(multiplier=1, min=4, max=10) - ) - async def safe_apply( - self, - ip: str, - config: Union[Dict, SwitchConfig] - ) -> Dict[str, Union[str, bool, Path]]: - """安全配置应用(自动回滚)""" - backup_path = await self._backup_config(ip) - try: - result = await self.apply_config(ip, config) - if not await self._validate_config(ip, config): - raise SwitchConfigException("配置验证失败") - return { - "status": "success", - "output": result, - "backup_path": str(backup_path) - } - except (EnspConnectionException, SSHConnectionException, SwitchConfigException) as e: - restore_status = await self._restore_config(ip, backup_path) - return { - "status": "failed", - "error": str(e), - "backup_path": str(backup_path), - "restore_success": restore_status - } - - -# ---------------------- -# 使用示例 -# ---------------------- -async def demo(): - ensp_configurator = SwitchConfigurator( - ensp_mode=True, - ensp_port=2000, - username="admin", - password="admin", - timeout=15 - ) - ensp_result = await ensp_configurator.safe_apply("127.0.0.1", { - "type": "interface", - "interface": "GigabitEthernet0/0/1", - "vlan": 100, - "ip_address": "192.168.1.2 255.255.255.0" - }) - print("eNSP配置结果:", ensp_result) - - ssh_configurator = SwitchConfigurator( - username="cisco", - password="cisco123", - timeout=15 - ) - ssh_result = await ssh_configurator.safe_apply("192.168.1.1", { - "type": "vlan", - "vlan_id": 200, - "name": "Production" - }) - print("SSH配置结果:", ssh_result) - - -if __name__ == "__main__": - asyncio.run(demo()) + async def execute_raw_commands(self, ip: str, commands: List[str]) -> bool: + """ + 对外接口:单台交换机执行命令 + """ + async with self.semaphore: + success = await self._send_ensp_commands(ip, commands) + return success diff --git a/src/backend/app/models/requests.py b/src/backend/app/models/requests.py new file mode 100644 index 0000000..4c696f1 --- /dev/null +++ b/src/backend/app/models/requests.py @@ -0,0 +1,26 @@ +from typing import List, Optional +from pydantic import BaseModel + +class BatchConfigRequest(BaseModel): + config: dict + switch_ips: List[str] + username: Optional[str] = None + password: Optional[str] = None + timeout: Optional[int] = None + +class ConfigRequest(BaseModel): + config: dict + switch_ip: str + username: Optional[str] = None + password: Optional[str] = None + timeout: Optional[int] = None + vendor: str = "huawei" + +class CLICommandRequest(BaseModel): + switch_ip: str + commands: List[str] + username: Optional[str] = None + password: Optional[str] = None + + def extract_credentials(self): + return self.username or "NONE", self.password or "NONE" diff --git a/src/backend/app/services/ai_services.py b/src/backend/app/services/ai_services.py index 2987bd0..74e48c9 100644 --- a/src/backend/app/services/ai_services.py +++ b/src/backend/app/services/ai_services.py @@ -1,54 +1,47 @@ -from typing import Any +from typing import Any, List, Dict from openai import AsyncOpenAI import json from src.backend.app.utils.exceptions import SiliconFlowAPIException from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam -from src.backend.app.utils.logger import logger - class AIService: def __init__(self, api_key: str, api_url: str): - self.api_key = api_key - self.api_url = api_url - self.client = AsyncOpenAI( - api_key=self.api_key, - base_url=self.api_url, - # timeout=httpx.Timeout(30.0) - ) + self.client = AsyncOpenAI(api_key=api_key, base_url=api_url) - async def parse_command(self, command: str, vendor: str = "huawei") -> Any | None: + async def parse_command(self, command: str, devices: List[Dict]) -> Dict[str, Any]: """ - 调用硅基流动API解析中文命令 + 针对一组设备和一条自然语言命令,生成每台设备的配置 JSON """ - vendor_prompts = { - "huawei": "华为交换机配置命令", - "cisco": "思科交换机配置命令", - "h3c": "H3C交换机配置命令", - "ruijie": "锐捷交换机配置命令", - "zte": "中兴交换机配置命令" - } + devices_str = json.dumps(devices, ensure_ascii=False, indent=2) + + example = """[{"device": {"name": "sw1","ip": "192.168.1.10","vendor": "huawei","username": "NONE", "password": "Huawei"},"config": {"type": "vlan","vlan_id": 300,"name": "Sales","commands": ["system-view","vlan 300","name Sales","quit","quit","save","Y"]}}]""" prompt = f""" - 你是一个网络设备配置专家,精通各种类型的路由器的配置,请将以下用户的中文命令转换为{vendor_prompts.get(vendor, '网络设备')}配置JSON。 - 但是请注意,由于贪婪的人们追求极高的效率,所以你必须严格按照 JSON 格式返回数据,不要包含任何额外文本或 Markdown 代码块。 - 返回格式要求: - 1. 必须包含'type'字段指明配置类型(vlan/interface/acl/route等) - 2. 必须包含'commands'字段,包含可直接执行的命令列表 - 3. 其他参数根据配置类型动态添加 - 4. 不要包含解释性文本、步骤说明或注释 - 5. 要包含使用ssh连接交换机后的完整命令包括但不完全包括system-view,退出,保存等完整操作,注意保存还需要输入Y +你是一个网络设备配置专家。现在有以下设备: +{devices_str} - 根据厂商{vendor}的不同,命令格式如下: - - 华为: system-view → quit → save Y - - 思科: enable → configure terminal → exit → write memory - - H3C: system-view → quit → save - - 锐捷: enable → configure terminal → exit → write - - 中兴: enable → configure terminal → exit → write memory +用户输入了一条命令:{command} - 示例命令:'创建VLAN 100,名称为TEST' - 华为示例返回:{{"type": "vlan", "vlan_id": 100, "name": "TEST", "commands": ["system-view","vlan 100", "name TEST","quit","quit","save","Y"]}} - 思科示例返回:{{"type": "vlan", "vlan_id": 100, "name": "TEST", "commands": ["enable","configure terminal","vlan 100", "name TEST","exit","exit","write memory"]}} - """ +你的任务: +- 为每台设备分别生成配置 +- 输出一个 JSON 数组,每个元素对应一台设备 +- 每个对象必须包含: + - device: 原始设备信息 (name, ip, vendor,username,password) + - config: 配置详情 + - type: 配置类型 (如 vlan/interface/acl/route) + - commands: 可直接执行的命令数组 (必须包含进入配置、退出、保存命令) + - 其他字段: 根据配置类型动态添加 +- 严格返回 JSON,不要包含解释说明或 markdown + +各厂商保存命令规则: +- 华为: system-view → quit → save Y +- 思科: enable → configure terminal → exit → write memory +- H3C: system-view → quit → save +- 锐捷: enable → configure terminal → exit → write +- 中兴: enable → configure terminal → exit → write memory + +返回示例(仅作为格式参考,不要照抄 VLAN ID 和命令内容,请根据实际命令生成):{example} +""" messages = [ ChatCompletionSystemMessageParam(role="system", content=prompt), @@ -59,29 +52,18 @@ class AIService: response = await self.client.chat.completions.create( model="deepseek-ai/DeepSeek-V3", messages=messages, - temperature=0.3, - max_tokens=1000, + temperature=0.2, + max_tokens=1500, response_format={"type": "json_object"} ) - logger.debug(response) - config_str = response.choices[0].message.content.strip() + configs = json.loads(config_str) - try: - config = json.loads(config_str) - return config - except json.JSONDecodeError: - if config_str.startswith("```json"): - config_str = config_str[7:-3].strip() - return json.loads(config_str) - raise SiliconFlowAPIException("Invalid JSON format returned from AI") - except KeyError: - logger.error(KeyError) - raise SiliconFlowAPIException("errrrrrrro") + return {"success": True, "results": configs} except Exception as e: raise SiliconFlowAPIException( - detail=f"API请求失败: {str(e)}", + detail=f"AI 解析配置失败: {str(e)}", status_code=getattr(e, "status_code", 500) ) diff --git a/src/frontend/src/components/pages/config/DeviceConfigModal.jsx b/src/frontend/src/components/pages/config/DeviceConfigModal.jsx index 744ad81..cd495d0 100644 --- a/src/frontend/src/components/pages/config/DeviceConfigModal.jsx +++ b/src/frontend/src/components/pages/config/DeviceConfigModal.jsx @@ -12,29 +12,29 @@ import { Field, Input, Stack, + Portal, + Select, } from '@chakra-ui/react'; import { motion } from 'framer-motion'; import { FiCheck } from 'react-icons/fi'; -import Notification from '@/libs/system/Notification'; +import { createListCollection } from '@chakra-ui/react'; const MotionBox = motion(Box); -/** - * 设备配置弹窗 - * @param isOpen 是否打开 - * @param onClose 关闭弹窗 - * @param onSave 保存修改 - * @param device 当前设备 - * @returns {JSX.Element} - * @constructor - */ +const vendors = ['huawei', 'cisco', 'h3c', 'ruijie', 'zte']; + const DeviceConfigModal = ({ isOpen, onClose, onSave, device }) => { const [username, setUsername] = useState(device.username || ''); const [password, setPassword] = useState(device.password || ''); + const [vendor, setVendor] = useState(device.vendor || ''); const [saved, setSaved] = useState(false); + const vendorCollection = createListCollection({ + items: vendors.map((v) => ({ label: v.toUpperCase(), value: v })), + }); + const handleSave = () => { - const updatedDevice = { ...device, username, password }; + const updatedDevice = { ...device, username, password, vendor }; onSave(updatedDevice); setSaved(true); setTimeout(() => { @@ -82,6 +82,40 @@ const DeviceConfigModal = ({ isOpen, onClose, onSave, device }) => { type={'password'} /> + + + 交换机厂商 + setVendor(value[0] || '')} + placeholder={'请选择厂商'} + size={'sm'} + colorPalette={'teal'} + > + + + + + + + + + + + + + + {vendorCollection.items.map((item) => ( + + {item.label} + + ))} + + + + + diff --git a/src/frontend/src/pages/ConfigPage.jsx b/src/frontend/src/pages/ConfigPage.jsx index 8591ce9..4a0ff67 100644 --- a/src/frontend/src/pages/ConfigPage.jsx +++ b/src/frontend/src/pages/ConfigPage.jsx @@ -22,22 +22,16 @@ import ConfigTool from '@/libs/config/ConfigTool'; import { api } from '@/services/api/api'; import Notification from '@/libs/system/Notification'; import Common from '@/libs/common'; -import configEffect from '@/libs/script/configPage/configEffect'; const testMode = ConfigTool.load().testMode; const ConfigPage = () => { const [devices, setDevices] = useState([]); - const [selectedDevice, setSelectedDevice] = useState(''); - const [selectedDeviceConfig, setSelectedDeviceConfig] = useState(''); + const [selectedDevices, setSelectedDevices] = useState([]); + const [deviceConfigs, setDeviceConfigs] = useState({}); const [inputText, setInputText] = useState(''); - const [parsedConfig, setParsedConfig] = useState(''); - const [editableConfig, setEditableConfig] = useState(''); const [applying, setApplying] = useState(false); const [hasParsed, setHasParsed] = useState(false); - const [isPeizhi, setisPeizhi] = useState(false); - const [isApplying, setIsApplying] = useState(false); - const [applyStatus, setApplyStatus] = useState([]); const deviceCollection = createListCollection({ items: devices.map((device) => ({ @@ -52,18 +46,30 @@ const ConfigPage = () => { }, []); const handleParse = async () => { - if (!selectedDevice || !inputText.trim()) { + if (selectedDevices.length === 0 || !inputText.trim()) { Notification.error({ title: '操作失败', - description: '请选择设备并输入配置指令', + description: '请选择至少一个设备并输入配置指令', + }); + return; + } + + const selectedConfigs = devices.filter((device) => selectedDevices.includes(device.ip)); + const deviceWithoutVendor = selectedConfigs.find((d) => !d.vendor || d.vendor.trim() === ''); + if (deviceWithoutVendor) { + Notification.error({ + title: '操作失败', + description: `设备 ${deviceWithoutVendor.name} 暂未配置厂商,请先配置厂商`, }); return; } try { - const performParse = async () => { - return await api.parseCommand(inputText); - }; + const performParse = async () => + await api.parseCommand({ + command: inputText, + devices: selectedConfigs, + }); const resultWrapper = await Notification.promise({ promise: performParse(), @@ -82,11 +88,15 @@ const ConfigPage = () => { }); let result = await resultWrapper.unwrap(); - if (result?.data) { - setParsedConfig(JSON.stringify(result.data)); - setEditableConfig(JSON.stringify(result.data)); + if (result?.data?.config) { + const configMap = {}; + result.data.config.forEach((item) => { + if (item.device?.ip) { + configMap[item.device.ip] = item; + } + }); + setDeviceConfigs(configMap); setHasParsed(true); - setisPeizhi(true); } } catch (error) { console.error('配置解析异常:', error); @@ -98,62 +108,80 @@ const ConfigPage = () => { }; const handleApply = async () => { - if (!editableConfig) { + if (!hasParsed) { Notification.warn({ - title: '配置为空', - description: '请先解析或编辑有效配置', + title: '未解析配置', + description: '请先解析配置再应用', }); return; } setApplying(true); - setIsApplying(true); try { const applyOperation = async () => { if (testMode) { - Common.sleep(1000).then(() => ({ success: true })); - } else { - let commands = JSON.parse(editableConfig)?.config?.commands; - console.log(`commands:${JSON.stringify(commands)}`); - const deviceConfig = JSON.parse(selectedDeviceConfig); - console.log(`deviceConfig:${JSON.stringify(deviceConfig)}`); - if (!deviceConfig.password) { + await Common.sleep(1000); + Notification.success({ + title: '测试模式成功', + description: '配置已模拟应用', + }); + return; + } + const applyPromises = selectedDevices.map(async (ip) => { + const deviceItem = deviceConfigs[ip]; + if (!deviceItem) return; + + const deviceConfig = deviceItem.config; + + if (!deviceItem.device.password) { Notification.warn({ - title: '所选交换机暂未配置用户名(可选)和密码', - description: '请前往交换机设备处配置username和password', + title: `交换机 ${deviceItem.device.name} 暂未配置密码`, + description: '请前往交换机设备处配置用户名和密码', }); - return false; + console.log(JSON.stringify(deviceItem)); + return; } - if (deviceConfig.username || deviceConfig.username.toString() !== '') { - commands.push(`!username=${deviceConfig.username.toString()}`); - } else { - commands.push(`!username=NONE`); + + if (!deviceItem.device.username) { + Notification.warn({ + title: `交换机 ${deviceItem.device.name} 暂未配置用户名,将使用NONE作为用户名`, + }); + deviceItem.device.username = 'NONE'; } - commands.push(`!password=${deviceConfig.password.toString()}`); - const res = await api.applyConfig(selectedDevice, commands); - if (res) { + + const commands = [...deviceConfig.commands]; + + try { + const res = await api.applyConfig( + ip, + commands, + deviceItem.device.username, + deviceItem.device.password + ); Notification.success({ - title: '配置完毕', + title: `配置完毕 - ${deviceItem.device.name}`, description: JSON.stringify(res), }); - } else { + } catch (err) { Notification.error({ - title: '配置过程出现错误', - description: '请检查API提示', + title: `配置过程出现错误 - ${deviceItem.device.name}`, + description: err.message || '请检查API提示', }); } - } + }); + + await Promise.all(applyPromises); }; await Notification.promise({ - promise: applyOperation, + promise: applyOperation(), loading: { title: '配置应用中', description: '正在推送配置到设备...', }, success: { - title: '应用成功', - description: '配置已成功生效', + title: '应用完成', + description: '所有设备配置已推送', }, error: { title: '应用失败', @@ -174,19 +202,16 @@ const ConfigPage = () => { 交换机配置中心 + 选择交换机设备 { - const selectedIp = value[0] ?? ''; - setSelectedDevice(selectedIp); - const fullDeviceConfig = devices.find((device) => device.ip === selectedIp); - setSelectedDeviceConfig(JSON.stringify(fullDeviceConfig)); - }} + value={selectedDevices} + onValueChange={({ value }) => setSelectedDevices(value)} placeholder={'请选择交换机设备'} size={'sm'} colorPalette={'teal'} @@ -202,7 +227,7 @@ const ConfigPage = () => { - + {deviceCollection.items.map((item) => ( @@ -214,13 +239,14 @@ const ConfigPage = () => { + 配置指令输入