diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml.disabled similarity index 100% rename from .github/workflows/deploy.yml rename to .github/workflows/deploy.yml.disabled diff --git a/README.md b/README.md index 3cf6031..59082e7 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,9 @@ # 基于人工智能实现的交换机自动或半自动配置 +这只是一个用于应付实验论文的奇葩储存库 + +猎奇卡顿 · 龟速更新 · 随时跑路 + ### 技术栈 - **Python3** - Flask @@ -11,10 +15,6 @@ - Framer-motion - chakra-ui - HTML5 -### 项目分工 -- **后端api,人工智能算法** : `3`(主要) & `log_out` & `Jerry`(maybe) 使用python -- **前端管理后台设计**:`Jerry`使用react -- **论文撰写**:`log_out` ### 各部分说明 @@ -22,18 +22,3 @@ [逻辑处理后端](https://github.com/Jerryplusy/AI-powered-switches/blob/main/src/backend/README.md) -### 贡献流程 - -- **后端api**: - - 对于`3`:直接推送到`main`分支 - - 对于`Jerry`&`log_out`:新建额外的`feat`或`fix`分支,提交推送到自己的分支,然后提交`pullrequest`到`main`分支并指定`3`审核 - -- **前端管理后台**: - - 对于`Jerry`:直接推送更新到`main`分支 - - 对于`3`&`log_out`:新建额外的`feat`或`fix`分支,提交推送到自己的分支,然后提交`pullrequest`到`main`分支并指定`Jerry`审核 - -- **论文(thesis)**: - - 提交`pullrequest`并指定`log_out`审核 -### 项目活动时间 -2025 6 - 8月 - diff --git a/src/backend/README.md b/src/backend/README.md index 0a34790..d0a536e 100644 --- a/src/backend/README.md +++ b/src/backend/README.md @@ -10,7 +10,6 @@ src/backend/ │ ├── __init__.py # 创建 Flask 应用实例 │ ├── api/ # API 路由模块 │ │ ├—── __init__.py # 注册 API 蓝图 -│ │ ├── command_parser.py # /api/parse_command 接口 │ │ └── network_config.py # /api/apply_config 接口 │ └── services/ # 核心服务逻辑 │ └── ai_services.py # 调用外部 AI 服务生成配置 diff --git a/src/backend/app/api/command_parser.py b/src/backend/app/api/command_parser.py deleted file mode 100644 index f198d10..0000000 --- a/src/backend/app/api/command_parser.py +++ /dev/null @@ -1,14 +0,0 @@ -from typing import Dict, Any -from src.backend.app.services.ai_services import AIService -from src.backend.config import settings - - -class CommandParser: - def __init__(self): - self.ai_service = AIService(settings.SILICONFLOW_API_KEY, settings.SILICONFLOW_API_URL) - - async def parse(self, command: str) -> Dict[str, Any]: - """ - 解析中文命令并返回配置 - """ - return await self.ai_service.parse_command(command) diff --git a/src/backend/app/api/network_config.py b/src/backend/app/api/network_config.py index bd0c225..b512273 100644 --- a/src/backend/app/api/network_config.py +++ b/src/backend/app/api/network_config.py @@ -1,19 +1,14 @@ import asyncio import logging import telnetlib3 -from datetime import datetime from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional -import aiofiles -import asyncssh from pydantic import BaseModel -from tenacity import retry, stop_after_attempt, wait_exponential from src.backend.app.utils.logger import logger from src.backend.config import settings - # ---------------------- # 数据模型 # ---------------------- @@ -25,36 +20,31 @@ class SwitchConfig(BaseModel): ip_address: Optional[str] = None vlan: Optional[int] = None - # ---------------------- # 异常类 # ---------------------- class SwitchConfigException(Exception): pass - class EnspConnectionException(SwitchConfigException): pass - -class SSHConnectionException(SwitchConfigException): - pass - - # ---------------------- -# 核心配置器(完整双模式) +# 核心配置器 # ---------------------- class SwitchConfigurator: + connection_pool: Dict[str, tuple] = {} + def __init__( - self, - username: str = None, - password: str = None, - timeout: int = None, - max_workers: int = 5, - ensp_mode: bool = False, - ensp_port: int = 2000, - ensp_command_delay: float = 0.5, - **ssh_options + self, + username: str = None, + password: str = None, + timeout: int = None, + max_workers: int = 5, + ensp_mode: bool = False, + ensp_port: int = 2000, + ensp_command_delay: float = 0.5, + **ssh_options ): self.username = username if username is not None else settings.SWITCH_USERNAME self.password = password if password is not None else settings.SWITCH_PASSWORD @@ -67,253 +57,66 @@ class SwitchConfigurator: self.ensp_delay = ensp_command_delay self.ssh_options = ssh_options - async def apply_config(self, ip: str, config: Union[Dict, SwitchConfig]) -> str: - """实际配置逻辑""" - if isinstance(config, dict): - config = SwitchConfig(**config) - - commands = ( - self._generate_ensp_commands(config) - if self.ensp_mode - else self._generate_standard_commands(config) - ) - return await self._send_commands(ip, commands) - - async def _send_commands(self, ip: str, commands: List[str]) -> str: - """双模式命令发送""" - return ( - await self._send_ensp_commands(ip, commands) - ) - - async def _send_ensp_commands(self, ip: str, commands: List[str]) -> str: + async def _get_or_create_connection(self, ip: str): """ - 通过 Telnet 协议连接 eNSP 设备 + 从连接池获取连接,如果没有则新建 Telnet 连接 + """ + if ip in self.connection_pool: + logger.debug(f"复用已有连接: {ip}") + return self.connection_pool[ip] + + logger.info(f"建立新连接: {ip}") + reader, writer = await telnetlib3.open_connection(host=ip, port=23) + + try: + if self.username != 'NONE' : + await asyncio.wait_for(reader.readuntil(b"Username:"), timeout=self.timeout) + writer.write(f"{self.username}\n") + + await asyncio.wait_for(reader.readuntil(b"Password:"), timeout=self.timeout) + writer.write(f"{self.password}\n") + + await asyncio.sleep(1) + except asyncio.TimeoutError: + writer.close() + raise EnspConnectionException("登录超时,未收到用户名或密码提示") + except Exception as e: + writer.close() + raise EnspConnectionException(f"登录异常: {e}") + + self.connection_pool[ip] = (reader, writer) + return reader, writer + + async def _send_ensp_commands(self, ip: str, commands: List[str]) -> bool: + """ + 通过 Telnet 协议发送命令 """ try: - logger.info(f"连接设备 {ip},端口23") - reader, writer = await telnetlib3.open_connection(host=ip, port=23) - logger.debug("连接成功,开始登录流程") + reader, writer = await self._get_or_create_connection(ip) - try: - if self.username != 'NONE': - await asyncio.wait_for(reader.readuntil(b"Username:"), timeout=self.timeout) - logger.debug("收到 'Username:' 提示,发送用户名") - writer.write(f"{self.username}\n") - - await asyncio.wait_for(reader.readuntil(b"Password:"), timeout=self.timeout) - logger.debug("收到 'Password:' 提示,发送密码") - writer.write(f"{self.password}\n") - - await asyncio.sleep(1) - except asyncio.TimeoutError: - raise EnspConnectionException("登录超时,未收到用户名或密码提示") - - output = "" for cmd in commands: if cmd.startswith("!"): logger.debug(f"跳过特殊命令: {cmd}") continue - logger.info(f"发送命令: {cmd}") + logger.info(f"[{ip}] 发送命令: {cmd}") writer.write(f"{cmd}\n") await writer.drain() + await asyncio.sleep(self.ensp_delay) - command_output = "" - try: - while True: - data = await asyncio.wait_for(reader.read(1024), timeout=1) - if not data: - logger.debug("读取到空数据,结束当前命令读取") - break - command_output += data - logger.debug(f"收到数据: {repr(data)}") - except asyncio.TimeoutError: - logger.debug("命令输出读取超时,继续执行下一条命令") - - output += f"\n[命令: {cmd} 输出开始]\n{command_output}\n[命令: {cmd} 输出结束]\n" - - logger.info("所有命令执行完毕,关闭连接") - writer.close() - - return output + logger.info(f"[{ip}] 所有命令发送完成") + return True except asyncio.TimeoutError as e: - logger.error(f"连接或读取超时: {e}") - raise EnspConnectionException(f"eNSP连接超时: {e}") + logger.error(f"[{ip}] 连接或读取超时: {e}") + return False except Exception as e: - logger.error(f"连接或执行异常: {e}", exc_info=True) - raise EnspConnectionException(f"eNSP连接失败: {e}") - - @staticmethod - def _generate_ensp_commands(config: SwitchConfig) -> List[str]: - """生成eNSP命令序列""" - commands = ["system-view"] - if config.type == "vlan": - commands.extend([ - f"vlan {config.vlan_id}", - f"description {config.name or ''}" - ]) - elif config.type == "interface": - commands.extend([ - f"interface {config.interface}", - "port link-type access", - f"port default vlan {config.vlan}" if config.vlan else "", - f"ip address {config.ip_address}" if config.ip_address else "" - ]) - commands.append("return") - return [c for c in commands if c.strip()] - - async def _send_ssh_commands(self, ip: str, commands: List[str]) -> str: - """AsyncSSH执行命令""" - async with self.semaphore: - try: - async with asyncssh.connect( - host=ip, - username=self.username, - password=self.password, - connect_timeout=self.timeout, - **self.ssh_options - ) as conn: - results = [] - for cmd in commands: - result = await conn.run(cmd, check=True) - results.append(result.stdout) - return "\n".join(results) - except asyncssh.Error as e: - raise SSHConnectionException(f"SSH操作失败: {str(e)}") - except Exception as e: - raise SSHConnectionException(f"连接异常: {str(e)}") - - async def execute_raw_commands(self, ip: str, commands: List[str]) -> str: - """ - 执行原始CLI命令 - """ - return await self._send_commands(ip, commands) - - - @staticmethod - def _generate_standard_commands(config: SwitchConfig) -> List[str]: - """生成标准CLI命令""" - commands = [] - if config.type == "vlan": - commands.extend([ - f"vlan {config.vlan_id}", - f"name {config.name or ''}" - ]) - elif config.type == "interface": - commands.extend([ - f"interface {config.interface}", - f"switchport access vlan {config.vlan}" if config.vlan else "", - f"ip address {config.ip_address}" if config.ip_address else "" - ]) - return commands - - async def _validate_config(self, ip: str, config: SwitchConfig) -> bool: - """验证配置是否生效""" - current = await self._get_current_config(ip) - if config.type == "vlan": - return f"vlan {config.vlan_id}" in current - elif config.type == "interface" and config.vlan: - return f"switchport access vlan {config.vlan}" in current - return True - - async def _get_current_config(self, ip: str) -> str: - """获取当前配置""" - commands = ( - ["display current-configuration"] - if self.ensp_mode - else ["show running-config"] - ) - try: - return await self._send_commands(ip, commands) - except (EnspConnectionException, SSHConnectionException) as e: - raise SwitchConfigException(f"配置获取失败: {str(e)}") - - async def _backup_config(self, ip: str) -> Path: - """备份配置到文件""" - backup_path = self.backup_dir / f"{ip}_{datetime.now().isoformat()}.cfg" - config = await self._get_current_config(ip) - async with aiofiles.open(backup_path, "w") as f: - await f.write(config) - return backup_path - - async def _restore_config(self, ip: str, backup_path: Path) -> bool: - """从备份恢复配置""" - try: - async with aiofiles.open(backup_path) as f: - config = await f.read() - commands = ( - ["system-view", config, "return"] - if self.ensp_mode - else [f"configure terminal\n{config}\nend"] - ) - await self._send_commands(ip, commands) - return True - except Exception as e: - logging.error(f"恢复失败: {str(e)}") + logger.error(f"[{ip}] 命令发送异常: {e}", exc_info=True) return False - - @retry( - stop=stop_after_attempt(2), - wait=wait_exponential(multiplier=1, min=4, max=10) - ) - async def safe_apply( - self, - ip: str, - config: Union[Dict, SwitchConfig] - ) -> Dict[str, Union[str, bool, Path]]: - """安全配置应用(自动回滚)""" - backup_path = await self._backup_config(ip) - try: - result = await self.apply_config(ip, config) - if not await self._validate_config(ip, config): - raise SwitchConfigException("配置验证失败") - return { - "status": "success", - "output": result, - "backup_path": str(backup_path) - } - except (EnspConnectionException, SSHConnectionException, SwitchConfigException) as e: - restore_status = await self._restore_config(ip, backup_path) - return { - "status": "failed", - "error": str(e), - "backup_path": str(backup_path), - "restore_success": restore_status - } - - -# ---------------------- -# 使用示例 -# ---------------------- -async def demo(): - ensp_configurator = SwitchConfigurator( - ensp_mode=True, - ensp_port=2000, - username="admin", - password="admin", - timeout=15 - ) - ensp_result = await ensp_configurator.safe_apply("127.0.0.1", { - "type": "interface", - "interface": "GigabitEthernet0/0/1", - "vlan": 100, - "ip_address": "192.168.1.2 255.255.255.0" - }) - print("eNSP配置结果:", ensp_result) - - ssh_configurator = SwitchConfigurator( - username="cisco", - password="cisco123", - timeout=15 - ) - ssh_result = await ssh_configurator.safe_apply("192.168.1.1", { - "type": "vlan", - "vlan_id": 200, - "name": "Production" - }) - print("SSH配置结果:", ssh_result) - - -if __name__ == "__main__": - asyncio.run(demo()) + async def execute_raw_commands(self, ip: str, commands: List[str]) -> bool: + """ + 对外接口:单台交换机执行命令 + """ + async with self.semaphore: + success = await self._send_ensp_commands(ip, commands) + return success diff --git a/src/backend/app/models/requests.py b/src/backend/app/models/requests.py new file mode 100644 index 0000000..4c696f1 --- /dev/null +++ b/src/backend/app/models/requests.py @@ -0,0 +1,26 @@ +from typing import List, Optional +from pydantic import BaseModel + +class BatchConfigRequest(BaseModel): + config: dict + switch_ips: List[str] + username: Optional[str] = None + password: Optional[str] = None + timeout: Optional[int] = None + +class ConfigRequest(BaseModel): + config: dict + switch_ip: str + username: Optional[str] = None + password: Optional[str] = None + timeout: Optional[int] = None + vendor: str = "huawei" + +class CLICommandRequest(BaseModel): + switch_ip: str + commands: List[str] + username: Optional[str] = None + password: Optional[str] = None + + def extract_credentials(self): + return self.username or "NONE", self.password or "NONE" diff --git a/src/backend/app/services/config_validator.py b/src/backend/app/services/config_validator.py deleted file mode 100644 index 5be5839..0000000 --- a/src/backend/app/services/config_validator.py +++ /dev/null @@ -1,85 +0,0 @@ -import re -from typing import Dict, List, Tuple -from ..utils.exceptions import SwitchConfigException - - -class ConfigValidator: - @staticmethod - def validate_vlan_config(config: Dict) -> Tuple[bool, str]: - """验证VLAN配置""" - if 'vlan_id' not in config: - return False, "Missing VLAN ID" - - vlan_id = config['vlan_id'] - if not (1 <= vlan_id <= 4094): - return False, f"Invalid VLAN ID {vlan_id}. Must be 1-4094" - - if 'name' in config and len(config['name']) > 32: - return False, "VLAN name too long (max 32 chars)" - - return True, "Valid VLAN config" - - @staticmethod - def validate_interface_config(config: Dict) -> Tuple[bool, str]: - """验证接口配置""" - required_fields = ['interface', 'ip_address'] - for field in required_fields: - if field not in config: - return False, f"Missing required field: {field}" - - # 验证IP地址格式 - ip_pattern = r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}/\d{1,2}$' - if not re.match(ip_pattern, config['ip_address']): - return False, "Invalid IP address format" - - # 验证接口名称格式 - interface_pattern = r'^(GigabitEthernet|FastEthernet|Eth)\d+/\d+/\d+$' - if not re.match(interface_pattern, config['interface']): - return False, "Invalid interface name format" - - return True, "Valid interface config" - - @staticmethod - def check_security_risks(commands: List[str]) -> List[str]: - """检查潜在安全风险""" - risky_commands = [] - dangerous_patterns = [ - r'no\s+aaa', # 禁用认证 - r'enable\s+password', # 明文密码 - r'service\s+password-encryption', # 弱加密 - r'ip\s+http\s+server', # 启用HTTP服务 - r'no\s+ip\s+http\s+secure-server' # 禁用HTTPS - ] - - for cmd in commands: - for pattern in dangerous_patterns: - if re.search(pattern, cmd, re.IGNORECASE): - risky_commands.append(cmd) - break - - return risky_commands - - @staticmethod - def validate_full_config(config: Dict) -> Tuple[bool, List[str]]: - """全面验证配置""" - errors = [] - - if 'type' not in config: - errors.append("Missing configuration type") - return False, errors - - if config['type'] == 'vlan': - valid, msg = ConfigValidator.validate_vlan_config(config) - if not valid: - errors.append(msg) - elif config['type'] == 'interface': - valid, msg = ConfigValidator.validate_interface_config(config) - if not valid: - errors.append(msg) - - if 'commands' in config: - risks = ConfigValidator.check_security_risks(config['commands']) - if risks: - errors.append(f"Potential security risks detected: {', '.join(risks)}") - - return len(errors) == 0, errors \ No newline at end of file diff --git a/src/backend/app/services/failover_manager.py b/src/backend/app/services/failover_manager.py deleted file mode 100644 index 84c31d9..0000000 --- a/src/backend/app/services/failover_manager.py +++ /dev/null @@ -1,129 +0,0 @@ -import asyncio -from datetime import datetime, timedelta -from pathlib import Path -from typing import Dict, List - -import networkx as nx - -from ..models.traffic_models import SwitchTrafficRecord -from src.backend.app.api.database import SessionLocal -from ..utils.exceptions import SwitchConfigException -from ..services.network_scanner import NetworkScanner - - -class FailoverManager: - def __init__(self, config_backup_dir: str = "config_backups"): - self.scanner = NetworkScanner() - self.backup_dir = Path(config_backup_dir) - self.backup_dir.mkdir(exist_ok=True) - - async def detect_failures(self, threshold: float = 0.9) -> List[Dict]: - """检测可能的网络故障""" - with SessionLocal() as session: - # 获取最近5分钟的所有交换机流量记录 - records = session.query(SwitchTrafficRecord).filter( - SwitchTrafficRecord.timestamp >= datetime.now() - timedelta(minutes=5) - ).all() - - # 分析流量异常 - abnormal_devices = [] - device_stats = {} - - for record in records: - if record.switch_ip not in device_stats: - device_stats[record.switch_ip] = { - 'interfaces': {}, - 'max_in': 0, - 'max_out': 0 - } - - if record.interface not in device_stats[record.switch_ip]['interfaces']: - device_stats[record.switch_ip]['interfaces'][record.interface] = { - 'in': [], - 'out': [] - } - - device_stats[record.switch_ip]['interfaces'][record.interface]['in'].append(record.rate_in) - device_stats[record.switch_ip]['interfaces'][record.interface]['out'].append(record.rate_out) - - # 更新设备最大流量 - if record.rate_in > device_stats[record.switch_ip]['max_in']: - device_stats[record.switch_ip]['max_in'] = record.rate_in - if record.rate_out > device_stats[record.switch_ip]['max_out']: - device_stats[record.switch_ip]['max_out'] = record.rate_out - - # 检测异常 - for ip, stats in device_stats.items(): - for iface, traffic in stats['interfaces'].items(): - avg_in = sum(traffic['in']) / len(traffic['in']) - avg_out = sum(traffic['out']) / len(traffic['out']) - - # 如果平均流量超过最大流量的90%,认为可能有问题 - if avg_in > stats['max_in'] * threshold or \ - avg_out > stats['max_out'] * threshold: - abnormal_devices.append({ - 'ip': ip, - 'interface': iface, - 'avg_in': avg_in, - 'avg_out': avg_out, - 'max_in': stats['max_in'], - 'max_out': stats['max_out'] - }) - - return abnormal_devices - - async def automatic_recovery(self, failed_device_ip: str) -> bool: - """自动故障恢复""" - try: - # 1. 检查设备是否在线 - devices = self.scanner.scan_subnet(failed_device_ip + "/32") - if not devices: - raise SwitchConfigException(f"Device {failed_device_ip} is offline") - - # 2. 查找最近的配置备份 - backup_files = sorted(self.backup_dir.glob(f"{failed_device_ip}_*.cfg")) - if not backup_files: - raise SwitchConfigException(f"No backup found for {failed_device_ip}") - - latest_backup = backup_files[-1] - - # 3. 恢复配置 - with open(latest_backup) as f: - config_commands = f.read().splitlines() - - # 使用SSH执行恢复命令 - # (这里需要实现SSH连接和执行命令的逻辑) - - return True - - except Exception as e: - raise SwitchConfigException(f"Recovery failed: {str(e)}") - - async def redundancy_check(self, critical_devices: List[str]) -> Dict: - """检查关键设备的冗余配置""" - results = {} - topology = self.scanner.get_current_topology() - - for device_ip in critical_devices: - # 检查是否有备用路径 - try: - paths = list(nx.all_shortest_paths( - topology, source=device_ip, target="core_switch")) - - if len(paths) > 1: - results[device_ip] = { - 'status': 'redundant', - 'path_count': len(paths) - } - else: - results[device_ip] = { - 'status': 'single_point_of_failure', - 'recommendation': 'Add redundant links' - } - except: - results[device_ip] = { - 'status': 'disconnected', - 'recommendation': 'Check physical connection' - } - - return results \ No newline at end of file diff --git a/src/backend/app/services/network_optimizer.py b/src/backend/app/services/network_optimizer.py deleted file mode 100644 index 39c14f7..0000000 --- a/src/backend/app/services/network_optimizer.py +++ /dev/null @@ -1,220 +0,0 @@ -import networkx as nx -import numpy as np -from scipy.optimize import linprog -from typing import Dict, List, Optional, Tuple, Any -from dataclasses import dataclass - - -@dataclass -class FlowDemand: - source: str - destination: str - bandwidth: float - priority: float = 1.0 - - -class NetworkOptimizer: - def __init__(self, devices: List[Dict[str, Any]]): - """基于图论的网络优化模型""" - self.graph = self.build_topology_graph(devices) - self.traffic_matrix: Optional[np.ndarray] = None - self._initialize_capacities() - - def _initialize_capacities(self) -> Dict[Tuple[str, str], float]: - """初始化链路容量字典""" - self.remaining_capacity = { - (u, v): data['bandwidth'] - for u, v, data in self.graph.edges(data=True) - } - return self.remaining_capacity - - def build_topology_graph(self, devices: List[Dict[str, Any]]) -> nx.Graph: - """构建带权重的网络拓扑图""" - G = nx.Graph() - - for device in devices: - G.add_node(device['ip'], - type=device['type'], - capacity=device.get('capacity', 1000)) - - # 添加接口作为子节点 - for interface in device.get('interfaces', []): - interface_id = f"{device['ip']}_{interface['name']}" - G.add_node(interface_id, - type='interface', - capacity=interface.get('bandwidth', 1000)) - G.add_edge(device['ip'], interface_id, - bandwidth=interface.get('bandwidth', 1000), - latency=interface.get('latency', 1)) - - # 添加设备间连接 - self._connect_devices(G, devices) - return G - - @staticmethod - def _connect_devices(graph: nx.Graph, devices: List[Dict[str, Any]]): - """自动连接设备""" - for i in range(len(devices) - 1): - graph.add_edge( - devices[i]['ip'], - devices[i + 1]['ip'], - bandwidth=1000, - latency=5 - ) - - @staticmethod - def _build_flow_conservation(nodes: List[str]) -> Tuple[np.ndarray, np.ndarray]: - """构建流量守恒约束矩阵""" - num_nodes = len(nodes) - A_eq: List[np.ndarray] = [] # 明确指定列表元素类型 - b_eq: List[float] = [] # 明确指定float列表 - - for i in range(num_nodes): - for j in range(num_nodes): - if i != j: - constraint = np.zeros((num_nodes, num_nodes)) - constraint[i][j] = 1 - constraint[j][i] = -1 - A_eq.append(constraint.flatten()) - b_eq.append(0.0) # 明确使用float类型 - - return np.array(A_eq, dtype=np.float64), np.array(b_eq, dtype=np.float64) - def optimize_bandwidth_allocation( - self, - demands: List[FlowDemand] - ) -> Dict[str, Dict[str, float]]: - """ - 基于线性规划的带宽分配优化 - 返回: {源IP: {目标IP: 分配带宽}} - """ - demand_dict = {(d.source, d.destination): float(d.bandwidth) for d in demands} # 确保float类型 - return self._optimize_with_highs(demand_dict) - - def _optimize_with_highs(self, demands: Dict[Tuple[str, str], float]) -> Dict[str, Dict[str, float]]: - """使用HiGHS求解器实现""" - nodes = list(self.graph.nodes()) - node_index = {node: i for i, node in enumerate(nodes)} - - # 构建流量矩阵 - self.traffic_matrix = np.zeros((len(nodes), len(nodes))) - for (src, dst), bw in demands.items(): - if src in node_index and dst in node_index: - self.traffic_matrix[node_index[src]][node_index[dst]] = bw - - # 构建约束 - c = np.ones(len(nodes) ** 2) # 最小化总流量 - A_ub, b_ub = self._build_capacity_constraints(nodes, node_index) - A_eq, b_eq = self._build_flow_conservation(nodes) - - # 求解线性规划 - res = linprog( - c=c, - A_ub=A_ub, - b_ub=b_ub, - A_eq=A_eq, - b_eq=b_eq, - bounds=(0, None), - method='highs' - ) - - if not res.success: - raise ValueError(f"Optimization failed: {res.message}") - - return self._format_results(res.x, nodes, node_index) - - def _build_capacity_constraints(self, nodes: List[str], node_index: Dict[str, int]) -> Tuple[ - np.ndarray, np.ndarray]: - """构建容量约束矩阵""" - A_ub = [] - b_ub = [] - - for u, v, data in self.graph.edges(data=True): - capacity = float(data.get('bandwidth', 1000)) # 确保float类型 - b_ub.append(capacity) - - constraint = np.zeros((len(nodes), len(nodes))) - for src, dst in [(u, v), (v, u)]: - if src in node_index and dst in node_index: - i, j = node_index[src], node_index[dst] - constraint[i][j] += 1 - A_ub.append(constraint.flatten()) - - return np.array(A_ub), np.array(b_ub) - - @staticmethod - def _format_results(solution: np.ndarray, nodes: List[str], node_index: Dict[str, int]) -> Dict[ - str, Dict[str, float]]: - """格式化优化结果""" - flows = solution.reshape((len(nodes), len(nodes))) - return { - nodes[i]: { - nodes[j]: float(flows[i][j]) # 明确转换为float - for j in range(len(nodes)) - if flows[i][j] > 0.001 - } - for i in range(len(nodes)) - } - - def optimize_qos(self, flows: List[FlowDemand]) -> Dict[str, Dict[str, Any]]: - """ - 带QoS的流量优化 - 返回: { - "源-目标": { - "path": 路径列表, - "allocated": 分配带宽, - "priority": 优先级 - } - } - """ - sorted_flows = sorted(flows, key=lambda x: x.priority, reverse=True) - results = {} - - # 使用实例变量而非局部变量 - self._initialize_capacities() # 重置剩余带宽 - - for flow in sorted_flows: - path = self.find_optimal_path(flow.source, flow.destination, flow.bandwidth) - if not path: - continue - - min_bw = min(self.graph[u][v]['bandwidth'] for u, v in zip(path[:-1], path[1:])) - allocated = min(min_bw, flow.bandwidth) - - # 更新剩余带宽 - for u, v in zip(path[:-1], path[1:]): - self.remaining_capacity[(u, v)] -= allocated - if self.remaining_capacity[(u, v)] <= 0: - self.graph[u][v]['bandwidth'] = 0 - - results[f"{flow.source}-{flow.destination}"] = { - "path": path, - "allocated": float(allocated), # 确保float类型 - "priority": float(flow.priority) - } - - return results - - def find_optimal_path(self, source: str, target: str, - bandwidth: float = 1.0) -> Optional[List[str]]: - """ - 改进的最优路径查找 - """ - try: - paths = nx.shortest_simple_paths( - self.graph, source, target, - weight=lambda u, v, d: 1 / max(1, d['bandwidth']) + d['latency'] # 避免除零 - ) - return next( - (path for path in paths - if self._path_has_sufficient_bandwidth(path, bandwidth)), - None - ) - except (nx.NetworkXNoPath, nx.NodeNotFound): - return None - - def _path_has_sufficient_bandwidth(self, path: List[str], bw: float) -> bool: - """检查路径带宽是否满足要求""" - return all( - self.graph[u][v]['bandwidth'] >= bw - for u, v in zip(path[:-1], path[1:]) - ) \ No newline at end of file diff --git a/src/backend/app/services/network_visualizer.py b/src/backend/app/services/network_visualizer.py deleted file mode 100644 index f9c28fd..0000000 --- a/src/backend/app/services/network_visualizer.py +++ /dev/null @@ -1,104 +0,0 @@ -from datetime import datetime, timedelta - -import networkx as nx -import matplotlib.pyplot as plt -from io import BytesIO -import base64 -from typing import Dict, List -from ..models.traffic_models import SwitchTrafficRecord -from src.backend.app.api.database import SessionLocal - - -class NetworkVisualizer: - def __init__(self): - self.graph = nx.Graph() - - def update_topology(self, devices: List[Dict]): - """更新网络拓扑图""" - self.graph.clear() - - # 添加节点 - for device in devices: - self.graph.add_node( - device['ip'], - type='switch', - label=f"Switch\n{device['ip']}" - ) - - # 添加连接(简化版,实际应根据扫描结果) - if len(devices) > 1: - for i in range(len(devices) - 1): - self.graph.add_edge( - devices[i]['ip'], - devices[i + 1]['ip'], - bandwidth=1000, - label="1Gbps" - ) - - def generate_topology_image(self) -> str: - """生成拓扑图并返回base64编码""" - plt.figure(figsize=(10, 8)) - pos = nx.spring_layout(self.graph) - - # 绘制节点 - node_colors = [] - for node in self.graph.nodes(): - if self.graph.nodes[node]['type'] == 'switch': - node_colors.append('lightblue') - - nx.draw_networkx_nodes( - self.graph, pos, - node_size=2000, - node_color=node_colors - ) - - # 绘制边 - nx.draw_networkx_edges( - self.graph, pos, - width=2, - alpha=0.5 - ) - - # 绘制标签 - node_labels = nx.get_node_attributes(self.graph, 'label') - nx.draw_networkx_labels( - self.graph, pos, - labels=node_labels, - font_size=8 - ) - - edge_labels = nx.get_edge_attributes(self.graph, 'label') - nx.draw_networkx_edge_labels( - self.graph, pos, - edge_labels=edge_labels, - font_size=8 - ) - - plt.title("Network Topology") - plt.axis('off') - - # 转换为base64 - buf = BytesIO() - plt.savefig(buf, format='png', bbox_inches='tight') - plt.close() - buf.seek(0) - return base64.b64encode(buf.read()).decode('utf-8') - - @staticmethod - def get_traffic_heatmap(minutes: int = 10) -> Dict: - """获取流量热力图数据""" - with SessionLocal() as session: - records = session.query(SwitchTrafficRecord).filter( - SwitchTrafficRecord.timestamp >= datetime.now() - timedelta(minutes=minutes) - ).all() - - heatmap_data = {} - for record in records: - if record.switch_ip not in heatmap_data: - heatmap_data[record.switch_ip] = {} - heatmap_data[record.switch_ip][record.interface] = { - 'in': record.rate_in, - 'out': record.rate_out - } - - return heatmap_data \ No newline at end of file diff --git a/src/backend/app/services/report_generator.py b/src/backend/app/services/report_generator.py deleted file mode 100644 index d6742f4..0000000 --- a/src/backend/app/services/report_generator.py +++ /dev/null @@ -1,120 +0,0 @@ -import pandas as pd -import matplotlib.pyplot as plt -from io import BytesIO -import base64 -from datetime import datetime, timedelta -from typing import Dict, List -from ..models.traffic_models import TrafficRecord, SwitchTrafficRecord -from src.backend.app.api.database import SessionLocal - - -class ReportGenerator: - @staticmethod - def generate_traffic_report(ip: str = None, days: int = 7) -> Dict: - """生成流量分析报告""" - with SessionLocal() as session: - time_threshold = datetime.now() - timedelta(days=days) - - if ip: - # 交换机流量报告 - query = session.query(SwitchTrafficRecord).filter( - SwitchTrafficRecord.switch_ip == ip, - SwitchTrafficRecord.timestamp >= time_threshold - ) - df = pd.read_sql(query.statement, session.bind) - - if df.empty: - return {"error": "No data found"} - - # 按接口分组分析 - interface_stats = df.groupby('interface').agg({ - 'rate_in': ['mean', 'max', 'min'], - 'rate_out': ['mean', 'max', 'min'] - }).reset_index() - - # 生成趋势图 - trend_fig = ReportGenerator._plot_traffic_trend(df, f"Switch {ip} Traffic Trend") - - return { - "switch_ip": ip, - "period": f"Last {days} days", - "interface_stats": interface_stats.to_dict('records'), - "trend_chart": trend_fig - } - else: - # 本地网络流量报告 - query = session.query(TrafficRecord).filter( - TrafficRecord.timestamp >= time_threshold - ) - df = pd.read_sql(query.statement, session.bind) - - if df.empty: - return {"error": "No data found"} - - # 按接口分组分析 - interface_stats = df.groupby('interface').agg({ - 'bytes_sent': ['sum', 'mean', 'max'], - 'bytes_recv': ['sum', 'mean', 'max'] - }).reset_index() - - # 生成趋势图 - trend_fig = ReportGenerator._plot_traffic_trend(df, "Local Network Traffic Trend") - - return { - "report_type": "local_network", - "period": f"Last {days} days", - "interface_stats": interface_stats.to_dict('records'), - "trend_chart": trend_fig - } - - @staticmethod - def _plot_traffic_trend(df: pd.DataFrame, title: str) -> str: - """生成流量趋势图""" - plt.figure(figsize=(12, 6)) - - if 'rate_in' in df.columns: # 交换机数据 - df_grouped = df.groupby(pd.Grouper(key='timestamp', freq='1H')).agg({ - 'rate_in': 'mean', - 'rate_out': 'mean' - }) - plt.plot(df_grouped.index, df_grouped['rate_in'], label='Inbound Traffic') - plt.plot(df_grouped.index, df_grouped['rate_out'], label='Outbound Traffic') - plt.ylabel("Traffic Rate (bytes/sec)") - else: # 本地网络数据 - df_grouped = df.groupby(pd.Grouper(key='timestamp', freq='1H')).agg({ - 'bytes_sent': 'sum', - 'bytes_recv': 'sum' - }) - plt.plot(df_grouped.index, df_grouped['bytes_sent'], label='Bytes Sent') - plt.plot(df_grouped.index, df_grouped['bytes_recv'], label='Bytes Received') - plt.ylabel("Bytes") - - plt.title(title) - plt.xlabel("Time") - plt.legend() - plt.grid(True) - - # 转换为base64 - buf = BytesIO() - plt.savefig(buf, format='png', bbox_inches='tight') - plt.close() - buf.seek(0) - return base64.b64encode(buf.read()).decode('utf-8') - - @staticmethod - def generate_config_history_report(days: int = 30) -> Dict: - """生成配置变更历史报告""" - # 需要实现配置历史记录功能 - pass - - @staticmethod - def generate_security_report() -> Dict: - """生成安全评估报告""" - # 需要实现安全扫描功能 - pass - - @staticmethod - def generate_performance_report() -> Dict: - """生成性能评估报告""" - # 需要实现性能基准测试功能 - pass \ No newline at end of file diff --git a/src/backend/app/services/traffic_monitor.py b/src/backend/app/services/traffic_monitor.py index 54adb11..1c498b3 100644 --- a/src/backend/app/services/traffic_monitor.py +++ b/src/backend/app/services/traffic_monitor.py @@ -8,6 +8,7 @@ from typing import Dict, Optional, List from ..models.traffic_models import TrafficRecord from src.backend.app.api.database import SessionLocal +from ..utils.logger import logger class TrafficMonitor: @@ -33,7 +34,7 @@ class TrafficMonitor: if not self.running: self.running = True self.task = asyncio.create_task(self._monitor_loop()) - print("流量监控已启动") + logger.info("流量监控已启动") async def stop_monitoring(self): """停止流量监控""" @@ -44,7 +45,7 @@ class TrafficMonitor: await self.task except asyncio.CancelledError: pass - print("流量监控已停止") + logger.info("流量监控已停止") async def _monitor_loop(self): """监控主循环""" diff --git a/src/frontend/src/components/pages/config/DeviceConfigModal.jsx b/src/frontend/src/components/pages/config/DeviceConfigModal.jsx index 744ad81..cd495d0 100644 --- a/src/frontend/src/components/pages/config/DeviceConfigModal.jsx +++ b/src/frontend/src/components/pages/config/DeviceConfigModal.jsx @@ -12,29 +12,29 @@ import { Field, Input, Stack, + Portal, + Select, } from '@chakra-ui/react'; import { motion } from 'framer-motion'; import { FiCheck } from 'react-icons/fi'; -import Notification from '@/libs/system/Notification'; +import { createListCollection } from '@chakra-ui/react'; const MotionBox = motion(Box); -/** - * 设备配置弹窗 - * @param isOpen 是否打开 - * @param onClose 关闭弹窗 - * @param onSave 保存修改 - * @param device 当前设备 - * @returns {JSX.Element} - * @constructor - */ +const vendors = ['huawei', 'cisco', 'h3c', 'ruijie', 'zte']; + const DeviceConfigModal = ({ isOpen, onClose, onSave, device }) => { const [username, setUsername] = useState(device.username || ''); const [password, setPassword] = useState(device.password || ''); + const [vendor, setVendor] = useState(device.vendor || ''); const [saved, setSaved] = useState(false); + const vendorCollection = createListCollection({ + items: vendors.map((v) => ({ label: v.toUpperCase(), value: v })), + }); + const handleSave = () => { - const updatedDevice = { ...device, username, password }; + const updatedDevice = { ...device, username, password, vendor }; onSave(updatedDevice); setSaved(true); setTimeout(() => { @@ -82,6 +82,40 @@ const DeviceConfigModal = ({ isOpen, onClose, onSave, device }) => { type={'password'} /> + + + 交换机厂商 + setVendor(value[0] || '')} + placeholder={'请选择厂商'} + size={'sm'} + colorPalette={'teal'} + > + + + + + + + + + + + + + + {vendorCollection.items.map((item) => ( + + {item.label} + + ))} + + + + + diff --git a/src/frontend/src/pages/ConfigPage.jsx b/src/frontend/src/pages/ConfigPage.jsx index 8591ce9..4a0ff67 100644 --- a/src/frontend/src/pages/ConfigPage.jsx +++ b/src/frontend/src/pages/ConfigPage.jsx @@ -22,22 +22,16 @@ import ConfigTool from '@/libs/config/ConfigTool'; import { api } from '@/services/api/api'; import Notification from '@/libs/system/Notification'; import Common from '@/libs/common'; -import configEffect from '@/libs/script/configPage/configEffect'; const testMode = ConfigTool.load().testMode; const ConfigPage = () => { const [devices, setDevices] = useState([]); - const [selectedDevice, setSelectedDevice] = useState(''); - const [selectedDeviceConfig, setSelectedDeviceConfig] = useState(''); + const [selectedDevices, setSelectedDevices] = useState([]); + const [deviceConfigs, setDeviceConfigs] = useState({}); const [inputText, setInputText] = useState(''); - const [parsedConfig, setParsedConfig] = useState(''); - const [editableConfig, setEditableConfig] = useState(''); const [applying, setApplying] = useState(false); const [hasParsed, setHasParsed] = useState(false); - const [isPeizhi, setisPeizhi] = useState(false); - const [isApplying, setIsApplying] = useState(false); - const [applyStatus, setApplyStatus] = useState([]); const deviceCollection = createListCollection({ items: devices.map((device) => ({ @@ -52,18 +46,30 @@ const ConfigPage = () => { }, []); const handleParse = async () => { - if (!selectedDevice || !inputText.trim()) { + if (selectedDevices.length === 0 || !inputText.trim()) { Notification.error({ title: '操作失败', - description: '请选择设备并输入配置指令', + description: '请选择至少一个设备并输入配置指令', + }); + return; + } + + const selectedConfigs = devices.filter((device) => selectedDevices.includes(device.ip)); + const deviceWithoutVendor = selectedConfigs.find((d) => !d.vendor || d.vendor.trim() === ''); + if (deviceWithoutVendor) { + Notification.error({ + title: '操作失败', + description: `设备 ${deviceWithoutVendor.name} 暂未配置厂商,请先配置厂商`, }); return; } try { - const performParse = async () => { - return await api.parseCommand(inputText); - }; + const performParse = async () => + await api.parseCommand({ + command: inputText, + devices: selectedConfigs, + }); const resultWrapper = await Notification.promise({ promise: performParse(), @@ -82,11 +88,15 @@ const ConfigPage = () => { }); let result = await resultWrapper.unwrap(); - if (result?.data) { - setParsedConfig(JSON.stringify(result.data)); - setEditableConfig(JSON.stringify(result.data)); + if (result?.data?.config) { + const configMap = {}; + result.data.config.forEach((item) => { + if (item.device?.ip) { + configMap[item.device.ip] = item; + } + }); + setDeviceConfigs(configMap); setHasParsed(true); - setisPeizhi(true); } } catch (error) { console.error('配置解析异常:', error); @@ -98,62 +108,80 @@ const ConfigPage = () => { }; const handleApply = async () => { - if (!editableConfig) { + if (!hasParsed) { Notification.warn({ - title: '配置为空', - description: '请先解析或编辑有效配置', + title: '未解析配置', + description: '请先解析配置再应用', }); return; } setApplying(true); - setIsApplying(true); try { const applyOperation = async () => { if (testMode) { - Common.sleep(1000).then(() => ({ success: true })); - } else { - let commands = JSON.parse(editableConfig)?.config?.commands; - console.log(`commands:${JSON.stringify(commands)}`); - const deviceConfig = JSON.parse(selectedDeviceConfig); - console.log(`deviceConfig:${JSON.stringify(deviceConfig)}`); - if (!deviceConfig.password) { + await Common.sleep(1000); + Notification.success({ + title: '测试模式成功', + description: '配置已模拟应用', + }); + return; + } + const applyPromises = selectedDevices.map(async (ip) => { + const deviceItem = deviceConfigs[ip]; + if (!deviceItem) return; + + const deviceConfig = deviceItem.config; + + if (!deviceItem.device.password) { Notification.warn({ - title: '所选交换机暂未配置用户名(可选)和密码', - description: '请前往交换机设备处配置username和password', + title: `交换机 ${deviceItem.device.name} 暂未配置密码`, + description: '请前往交换机设备处配置用户名和密码', }); - return false; + console.log(JSON.stringify(deviceItem)); + return; } - if (deviceConfig.username || deviceConfig.username.toString() !== '') { - commands.push(`!username=${deviceConfig.username.toString()}`); - } else { - commands.push(`!username=NONE`); + + if (!deviceItem.device.username) { + Notification.warn({ + title: `交换机 ${deviceItem.device.name} 暂未配置用户名,将使用NONE作为用户名`, + }); + deviceItem.device.username = 'NONE'; } - commands.push(`!password=${deviceConfig.password.toString()}`); - const res = await api.applyConfig(selectedDevice, commands); - if (res) { + + const commands = [...deviceConfig.commands]; + + try { + const res = await api.applyConfig( + ip, + commands, + deviceItem.device.username, + deviceItem.device.password + ); Notification.success({ - title: '配置完毕', + title: `配置完毕 - ${deviceItem.device.name}`, description: JSON.stringify(res), }); - } else { + } catch (err) { Notification.error({ - title: '配置过程出现错误', - description: '请检查API提示', + title: `配置过程出现错误 - ${deviceItem.device.name}`, + description: err.message || '请检查API提示', }); } - } + }); + + await Promise.all(applyPromises); }; await Notification.promise({ - promise: applyOperation, + promise: applyOperation(), loading: { title: '配置应用中', description: '正在推送配置到设备...', }, success: { - title: '应用成功', - description: '配置已成功生效', + title: '应用完成', + description: '所有设备配置已推送', }, error: { title: '应用失败', @@ -174,19 +202,16 @@ const ConfigPage = () => { 交换机配置中心 + 选择交换机设备 { - const selectedIp = value[0] ?? ''; - setSelectedDevice(selectedIp); - const fullDeviceConfig = devices.find((device) => device.ip === selectedIp); - setSelectedDeviceConfig(JSON.stringify(fullDeviceConfig)); - }} + value={selectedDevices} + onValueChange={({ value }) => setSelectedDevices(value)} placeholder={'请选择交换机设备'} size={'sm'} colorPalette={'teal'} @@ -202,7 +227,7 @@ const ConfigPage = () => { - + {deviceCollection.items.map((item) => ( @@ -214,13 +239,14 @@ const ConfigPage = () => { + 配置指令输入