Merge remote-tracking branch 'origin/main'

# Conflicts:
#	src/backend/app/api/endpoints.py
#	src/backend/app/services/ai_services.py
#	src/backend/config.py
This commit is contained in:
3 2025-08-30 15:32:12 +08:00
commit 6a21cdef9a
15 changed files with 268 additions and 1099 deletions

View File

@ -1,5 +1,9 @@
# 基于人工智能实现的交换机自动或半自动配置
这只是一个用于应付实验论文的奇葩储存库
猎奇卡顿 · 龟速更新 · 随时跑路
### 技术栈
- **Python3**
- Flask
@ -11,10 +15,6 @@
- Framer-motion
- chakra-ui
- HTML5
### 项目分工
- **后端api人工智能算法** : `3`(主要) & `log_out` & `Jerry`(maybe) 使用python
- **前端管理后台设计**`Jerry`使用react
- **论文撰写**`log_out`
### 各部分说明
@ -22,18 +22,3 @@
[逻辑处理后端](https://github.com/Jerryplusy/AI-powered-switches/blob/main/src/backend/README.md)
### 贡献流程
- **后端api**:
- 对于`3`:直接推送到`main`分支
- 对于`Jerry`&`log_out`:新建额外的`feat``fix`分支,提交推送到自己的分支,然后提交`pullrequest``main`分支并指定`3`审核
- **前端管理后台**:
- 对于`Jerry`:直接推送更新到`main`分支
- 对于`3`&`log_out`:新建额外的`feat``fix`分支,提交推送到自己的分支,然后提交`pullrequest``main`分支并指定`Jerry`审核
- **论文(thesis)**
- 提交`pullrequest`并指定`log_out`审核
### 项目活动时间
2025 6 - 8月

View File

@ -10,7 +10,6 @@ src/backend/
│ ├── __init__.py # 创建 Flask 应用实例
│ ├── api/ # API 路由模块
│ │ ├—── __init__.py # 注册 API 蓝图
│ │ ├── command_parser.py # /api/parse_command 接口
│ │ └── network_config.py # /api/apply_config 接口
│ └── services/ # 核心服务逻辑
│ └── ai_services.py # 调用外部 AI 服务生成配置

View File

@ -1,14 +0,0 @@
from typing import Dict, Any
from src.backend.app.services.ai_services import AIService
from src.backend.config import settings
class CommandParser:
def __init__(self):
self.ai_service = AIService(settings.SILICONFLOW_API_KEY, settings.SILICONFLOW_API_URL)
async def parse(self, command: str) -> Dict[str, Any]:
"""
解析中文命令并返回配置
"""
return await self.ai_service.parse_command(command)

View File

@ -1,19 +1,14 @@
import asyncio
import logging
import telnetlib3
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional
import aiofiles
import asyncssh
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_exponential
from src.backend.app.utils.logger import logger
from src.backend.config import settings
# ----------------------
# 数据模型
# ----------------------
@ -25,26 +20,21 @@ class SwitchConfig(BaseModel):
ip_address: Optional[str] = None
vlan: Optional[int] = None
# ----------------------
# 异常类
# ----------------------
class SwitchConfigException(Exception):
pass
class EnspConnectionException(SwitchConfigException):
pass
class SSHConnectionException(SwitchConfigException):
pass
# ----------------------
# 核心配置器(完整双模式)
# 核心配置器
# ----------------------
class SwitchConfigurator:
connection_pool: Dict[str, tuple] = {}
def __init__(
self,
username: str = None,
@ -67,253 +57,66 @@ class SwitchConfigurator:
self.ensp_delay = ensp_command_delay
self.ssh_options = ssh_options
async def apply_config(self, ip: str, config: Union[Dict, SwitchConfig]) -> str:
"""实际配置逻辑"""
if isinstance(config, dict):
config = SwitchConfig(**config)
commands = (
self._generate_ensp_commands(config)
if self.ensp_mode
else self._generate_standard_commands(config)
)
return await self._send_commands(ip, commands)
async def _send_commands(self, ip: str, commands: List[str]) -> str:
"""双模式命令发送"""
return (
await self._send_ensp_commands(ip, commands)
)
async def _send_ensp_commands(self, ip: str, commands: List[str]) -> str:
async def _get_or_create_connection(self, ip: str):
"""
通过 Telnet 协议连接 eNSP 设备
从连接池获取连接如果没有则新建 Telnet 连接
"""
try:
logger.info(f"连接设备 {ip}端口23")
if ip in self.connection_pool:
logger.debug(f"复用已有连接: {ip}")
return self.connection_pool[ip]
logger.info(f"建立新连接: {ip}")
reader, writer = await telnetlib3.open_connection(host=ip, port=23)
logger.debug("连接成功,开始登录流程")
try:
if self.username != 'NONE':
if self.username != 'NONE' :
await asyncio.wait_for(reader.readuntil(b"Username:"), timeout=self.timeout)
logger.debug("收到 'Username:' 提示,发送用户名")
writer.write(f"{self.username}\n")
await asyncio.wait_for(reader.readuntil(b"Password:"), timeout=self.timeout)
logger.debug("收到 'Password:' 提示,发送密码")
writer.write(f"{self.password}\n")
await asyncio.sleep(1)
except asyncio.TimeoutError:
writer.close()
raise EnspConnectionException("登录超时,未收到用户名或密码提示")
except Exception as e:
writer.close()
raise EnspConnectionException(f"登录异常: {e}")
self.connection_pool[ip] = (reader, writer)
return reader, writer
async def _send_ensp_commands(self, ip: str, commands: List[str]) -> bool:
"""
通过 Telnet 协议发送命令
"""
try:
reader, writer = await self._get_or_create_connection(ip)
output = ""
for cmd in commands:
if cmd.startswith("!"):
logger.debug(f"跳过特殊命令: {cmd}")
continue
logger.info(f"发送命令: {cmd}")
logger.info(f"[{ip}] 发送命令: {cmd}")
writer.write(f"{cmd}\n")
await writer.drain()
await asyncio.sleep(self.ensp_delay)
command_output = ""
try:
while True:
data = await asyncio.wait_for(reader.read(1024), timeout=1)
if not data:
logger.debug("读取到空数据,结束当前命令读取")
break
command_output += data
logger.debug(f"收到数据: {repr(data)}")
except asyncio.TimeoutError:
logger.debug("命令输出读取超时,继续执行下一条命令")
output += f"\n[命令: {cmd} 输出开始]\n{command_output}\n[命令: {cmd} 输出结束]\n"
logger.info("所有命令执行完毕,关闭连接")
writer.close()
return output
logger.info(f"[{ip}] 所有命令发送完成")
return True
except asyncio.TimeoutError as e:
logger.error(f"连接或读取超时: {e}")
raise EnspConnectionException(f"eNSP连接超时: {e}")
logger.error(f"[{ip}] 连接或读取超时: {e}")
return False
except Exception as e:
logger.error(f"连接或执行异常: {e}", exc_info=True)
raise EnspConnectionException(f"eNSP连接失败: {e}")
@staticmethod
def _generate_ensp_commands(config: SwitchConfig) -> List[str]:
"""生成eNSP命令序列"""
commands = ["system-view"]
if config.type == "vlan":
commands.extend([
f"vlan {config.vlan_id}",
f"description {config.name or ''}"
])
elif config.type == "interface":
commands.extend([
f"interface {config.interface}",
"port link-type access",
f"port default vlan {config.vlan}" if config.vlan else "",
f"ip address {config.ip_address}" if config.ip_address else ""
])
commands.append("return")
return [c for c in commands if c.strip()]
async def _send_ssh_commands(self, ip: str, commands: List[str]) -> str:
"""AsyncSSH执行命令"""
async with self.semaphore:
try:
async with asyncssh.connect(
host=ip,
username=self.username,
password=self.password,
connect_timeout=self.timeout,
**self.ssh_options
) as conn:
results = []
for cmd in commands:
result = await conn.run(cmd, check=True)
results.append(result.stdout)
return "\n".join(results)
except asyncssh.Error as e:
raise SSHConnectionException(f"SSH操作失败: {str(e)}")
except Exception as e:
raise SSHConnectionException(f"连接异常: {str(e)}")
async def execute_raw_commands(self, ip: str, commands: List[str]) -> str:
"""
执行原始CLI命令
"""
return await self._send_commands(ip, commands)
@staticmethod
def _generate_standard_commands(config: SwitchConfig) -> List[str]:
"""生成标准CLI命令"""
commands = []
if config.type == "vlan":
commands.extend([
f"vlan {config.vlan_id}",
f"name {config.name or ''}"
])
elif config.type == "interface":
commands.extend([
f"interface {config.interface}",
f"switchport access vlan {config.vlan}" if config.vlan else "",
f"ip address {config.ip_address}" if config.ip_address else ""
])
return commands
async def _validate_config(self, ip: str, config: SwitchConfig) -> bool:
"""验证配置是否生效"""
current = await self._get_current_config(ip)
if config.type == "vlan":
return f"vlan {config.vlan_id}" in current
elif config.type == "interface" and config.vlan:
return f"switchport access vlan {config.vlan}" in current
return True
async def _get_current_config(self, ip: str) -> str:
"""获取当前配置"""
commands = (
["display current-configuration"]
if self.ensp_mode
else ["show running-config"]
)
try:
return await self._send_commands(ip, commands)
except (EnspConnectionException, SSHConnectionException) as e:
raise SwitchConfigException(f"配置获取失败: {str(e)}")
async def _backup_config(self, ip: str) -> Path:
"""备份配置到文件"""
backup_path = self.backup_dir / f"{ip}_{datetime.now().isoformat()}.cfg"
config = await self._get_current_config(ip)
async with aiofiles.open(backup_path, "w") as f:
await f.write(config)
return backup_path
async def _restore_config(self, ip: str, backup_path: Path) -> bool:
"""从备份恢复配置"""
try:
async with aiofiles.open(backup_path) as f:
config = await f.read()
commands = (
["system-view", config, "return"]
if self.ensp_mode
else [f"configure terminal\n{config}\nend"]
)
await self._send_commands(ip, commands)
return True
except Exception as e:
logging.error(f"恢复失败: {str(e)}")
logger.error(f"[{ip}] 命令发送异常: {e}", exc_info=True)
return False
@retry(
stop=stop_after_attempt(2),
wait=wait_exponential(multiplier=1, min=4, max=10)
)
async def safe_apply(
self,
ip: str,
config: Union[Dict, SwitchConfig]
) -> Dict[str, Union[str, bool, Path]]:
"""安全配置应用(自动回滚)"""
backup_path = await self._backup_config(ip)
try:
result = await self.apply_config(ip, config)
if not await self._validate_config(ip, config):
raise SwitchConfigException("配置验证失败")
return {
"status": "success",
"output": result,
"backup_path": str(backup_path)
}
except (EnspConnectionException, SSHConnectionException, SwitchConfigException) as e:
restore_status = await self._restore_config(ip, backup_path)
return {
"status": "failed",
"error": str(e),
"backup_path": str(backup_path),
"restore_success": restore_status
}
# ----------------------
# 使用示例
# ----------------------
async def demo():
ensp_configurator = SwitchConfigurator(
ensp_mode=True,
ensp_port=2000,
username="admin",
password="admin",
timeout=15
)
ensp_result = await ensp_configurator.safe_apply("127.0.0.1", {
"type": "interface",
"interface": "GigabitEthernet0/0/1",
"vlan": 100,
"ip_address": "192.168.1.2 255.255.255.0"
})
print("eNSP配置结果:", ensp_result)
ssh_configurator = SwitchConfigurator(
username="cisco",
password="cisco123",
timeout=15
)
ssh_result = await ssh_configurator.safe_apply("192.168.1.1", {
"type": "vlan",
"vlan_id": 200,
"name": "Production"
})
print("SSH配置结果:", ssh_result)
if __name__ == "__main__":
asyncio.run(demo())
async def execute_raw_commands(self, ip: str, commands: List[str]) -> bool:
"""
对外接口单台交换机执行命令
"""
async with self.semaphore:
success = await self._send_ensp_commands(ip, commands)
return success

View File

@ -0,0 +1,26 @@
from typing import List, Optional
from pydantic import BaseModel
class BatchConfigRequest(BaseModel):
config: dict
switch_ips: List[str]
username: Optional[str] = None
password: Optional[str] = None
timeout: Optional[int] = None
class ConfigRequest(BaseModel):
config: dict
switch_ip: str
username: Optional[str] = None
password: Optional[str] = None
timeout: Optional[int] = None
vendor: str = "huawei"
class CLICommandRequest(BaseModel):
switch_ip: str
commands: List[str]
username: Optional[str] = None
password: Optional[str] = None
def extract_credentials(self):
return self.username or "NONE", self.password or "NONE"

View File

@ -1,85 +0,0 @@
import re
from typing import Dict, List, Tuple
from ..utils.exceptions import SwitchConfigException
class ConfigValidator:
@staticmethod
def validate_vlan_config(config: Dict) -> Tuple[bool, str]:
"""验证VLAN配置"""
if 'vlan_id' not in config:
return False, "Missing VLAN ID"
vlan_id = config['vlan_id']
if not (1 <= vlan_id <= 4094):
return False, f"Invalid VLAN ID {vlan_id}. Must be 1-4094"
if 'name' in config and len(config['name']) > 32:
return False, "VLAN name too long (max 32 chars)"
return True, "Valid VLAN config"
@staticmethod
def validate_interface_config(config: Dict) -> Tuple[bool, str]:
"""验证接口配置"""
required_fields = ['interface', 'ip_address']
for field in required_fields:
if field not in config:
return False, f"Missing required field: {field}"
# 验证IP地址格式
ip_pattern = r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}/\d{1,2}$'
if not re.match(ip_pattern, config['ip_address']):
return False, "Invalid IP address format"
# 验证接口名称格式
interface_pattern = r'^(GigabitEthernet|FastEthernet|Eth)\d+/\d+/\d+$'
if not re.match(interface_pattern, config['interface']):
return False, "Invalid interface name format"
return True, "Valid interface config"
@staticmethod
def check_security_risks(commands: List[str]) -> List[str]:
"""检查潜在安全风险"""
risky_commands = []
dangerous_patterns = [
r'no\s+aaa', # 禁用认证
r'enable\s+password', # 明文密码
r'service\s+password-encryption', # 弱加密
r'ip\s+http\s+server', # 启用HTTP服务
r'no\s+ip\s+http\s+secure-server' # 禁用HTTPS
]
for cmd in commands:
for pattern in dangerous_patterns:
if re.search(pattern, cmd, re.IGNORECASE):
risky_commands.append(cmd)
break
return risky_commands
@staticmethod
def validate_full_config(config: Dict) -> Tuple[bool, List[str]]:
"""全面验证配置"""
errors = []
if 'type' not in config:
errors.append("Missing configuration type")
return False, errors
if config['type'] == 'vlan':
valid, msg = ConfigValidator.validate_vlan_config(config)
if not valid:
errors.append(msg)
elif config['type'] == 'interface':
valid, msg = ConfigValidator.validate_interface_config(config)
if not valid:
errors.append(msg)
if 'commands' in config:
risks = ConfigValidator.check_security_risks(config['commands'])
if risks:
errors.append(f"Potential security risks detected: {', '.join(risks)}")
return len(errors) == 0, errors

View File

@ -1,129 +0,0 @@
import asyncio
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List
import networkx as nx
from ..models.traffic_models import SwitchTrafficRecord
from src.backend.app.api.database import SessionLocal
from ..utils.exceptions import SwitchConfigException
from ..services.network_scanner import NetworkScanner
class FailoverManager:
def __init__(self, config_backup_dir: str = "config_backups"):
self.scanner = NetworkScanner()
self.backup_dir = Path(config_backup_dir)
self.backup_dir.mkdir(exist_ok=True)
async def detect_failures(self, threshold: float = 0.9) -> List[Dict]:
"""检测可能的网络故障"""
with SessionLocal() as session:
# 获取最近5分钟的所有交换机流量记录
records = session.query(SwitchTrafficRecord).filter(
SwitchTrafficRecord.timestamp >= datetime.now() - timedelta(minutes=5)
).all()
# 分析流量异常
abnormal_devices = []
device_stats = {}
for record in records:
if record.switch_ip not in device_stats:
device_stats[record.switch_ip] = {
'interfaces': {},
'max_in': 0,
'max_out': 0
}
if record.interface not in device_stats[record.switch_ip]['interfaces']:
device_stats[record.switch_ip]['interfaces'][record.interface] = {
'in': [],
'out': []
}
device_stats[record.switch_ip]['interfaces'][record.interface]['in'].append(record.rate_in)
device_stats[record.switch_ip]['interfaces'][record.interface]['out'].append(record.rate_out)
# 更新设备最大流量
if record.rate_in > device_stats[record.switch_ip]['max_in']:
device_stats[record.switch_ip]['max_in'] = record.rate_in
if record.rate_out > device_stats[record.switch_ip]['max_out']:
device_stats[record.switch_ip]['max_out'] = record.rate_out
# 检测异常
for ip, stats in device_stats.items():
for iface, traffic in stats['interfaces'].items():
avg_in = sum(traffic['in']) / len(traffic['in'])
avg_out = sum(traffic['out']) / len(traffic['out'])
# 如果平均流量超过最大流量的90%,认为可能有问题
if avg_in > stats['max_in'] * threshold or \
avg_out > stats['max_out'] * threshold:
abnormal_devices.append({
'ip': ip,
'interface': iface,
'avg_in': avg_in,
'avg_out': avg_out,
'max_in': stats['max_in'],
'max_out': stats['max_out']
})
return abnormal_devices
async def automatic_recovery(self, failed_device_ip: str) -> bool:
"""自动故障恢复"""
try:
# 1. 检查设备是否在线
devices = self.scanner.scan_subnet(failed_device_ip + "/32")
if not devices:
raise SwitchConfigException(f"Device {failed_device_ip} is offline")
# 2. 查找最近的配置备份
backup_files = sorted(self.backup_dir.glob(f"{failed_device_ip}_*.cfg"))
if not backup_files:
raise SwitchConfigException(f"No backup found for {failed_device_ip}")
latest_backup = backup_files[-1]
# 3. 恢复配置
with open(latest_backup) as f:
config_commands = f.read().splitlines()
# 使用SSH执行恢复命令
# (这里需要实现SSH连接和执行命令的逻辑)
return True
except Exception as e:
raise SwitchConfigException(f"Recovery failed: {str(e)}")
async def redundancy_check(self, critical_devices: List[str]) -> Dict:
"""检查关键设备的冗余配置"""
results = {}
topology = self.scanner.get_current_topology()
for device_ip in critical_devices:
# 检查是否有备用路径
try:
paths = list(nx.all_shortest_paths(
topology, source=device_ip, target="core_switch"))
if len(paths) > 1:
results[device_ip] = {
'status': 'redundant',
'path_count': len(paths)
}
else:
results[device_ip] = {
'status': 'single_point_of_failure',
'recommendation': 'Add redundant links'
}
except:
results[device_ip] = {
'status': 'disconnected',
'recommendation': 'Check physical connection'
}
return results

View File

@ -1,220 +0,0 @@
import networkx as nx
import numpy as np
from scipy.optimize import linprog
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
@dataclass
class FlowDemand:
source: str
destination: str
bandwidth: float
priority: float = 1.0
class NetworkOptimizer:
def __init__(self, devices: List[Dict[str, Any]]):
"""基于图论的网络优化模型"""
self.graph = self.build_topology_graph(devices)
self.traffic_matrix: Optional[np.ndarray] = None
self._initialize_capacities()
def _initialize_capacities(self) -> Dict[Tuple[str, str], float]:
"""初始化链路容量字典"""
self.remaining_capacity = {
(u, v): data['bandwidth']
for u, v, data in self.graph.edges(data=True)
}
return self.remaining_capacity
def build_topology_graph(self, devices: List[Dict[str, Any]]) -> nx.Graph:
"""构建带权重的网络拓扑图"""
G = nx.Graph()
for device in devices:
G.add_node(device['ip'],
type=device['type'],
capacity=device.get('capacity', 1000))
# 添加接口作为子节点
for interface in device.get('interfaces', []):
interface_id = f"{device['ip']}_{interface['name']}"
G.add_node(interface_id,
type='interface',
capacity=interface.get('bandwidth', 1000))
G.add_edge(device['ip'], interface_id,
bandwidth=interface.get('bandwidth', 1000),
latency=interface.get('latency', 1))
# 添加设备间连接
self._connect_devices(G, devices)
return G
@staticmethod
def _connect_devices(graph: nx.Graph, devices: List[Dict[str, Any]]):
"""自动连接设备"""
for i in range(len(devices) - 1):
graph.add_edge(
devices[i]['ip'],
devices[i + 1]['ip'],
bandwidth=1000,
latency=5
)
@staticmethod
def _build_flow_conservation(nodes: List[str]) -> Tuple[np.ndarray, np.ndarray]:
"""构建流量守恒约束矩阵"""
num_nodes = len(nodes)
A_eq: List[np.ndarray] = [] # 明确指定列表元素类型
b_eq: List[float] = [] # 明确指定float列表
for i in range(num_nodes):
for j in range(num_nodes):
if i != j:
constraint = np.zeros((num_nodes, num_nodes))
constraint[i][j] = 1
constraint[j][i] = -1
A_eq.append(constraint.flatten())
b_eq.append(0.0) # 明确使用float类型
return np.array(A_eq, dtype=np.float64), np.array(b_eq, dtype=np.float64)
def optimize_bandwidth_allocation(
self,
demands: List[FlowDemand]
) -> Dict[str, Dict[str, float]]:
"""
基于线性规划的带宽分配优化
返回: {源IP: {目标IP: 分配带宽}}
"""
demand_dict = {(d.source, d.destination): float(d.bandwidth) for d in demands} # 确保float类型
return self._optimize_with_highs(demand_dict)
def _optimize_with_highs(self, demands: Dict[Tuple[str, str], float]) -> Dict[str, Dict[str, float]]:
"""使用HiGHS求解器实现"""
nodes = list(self.graph.nodes())
node_index = {node: i for i, node in enumerate(nodes)}
# 构建流量矩阵
self.traffic_matrix = np.zeros((len(nodes), len(nodes)))
for (src, dst), bw in demands.items():
if src in node_index and dst in node_index:
self.traffic_matrix[node_index[src]][node_index[dst]] = bw
# 构建约束
c = np.ones(len(nodes) ** 2) # 最小化总流量
A_ub, b_ub = self._build_capacity_constraints(nodes, node_index)
A_eq, b_eq = self._build_flow_conservation(nodes)
# 求解线性规划
res = linprog(
c=c,
A_ub=A_ub,
b_ub=b_ub,
A_eq=A_eq,
b_eq=b_eq,
bounds=(0, None),
method='highs'
)
if not res.success:
raise ValueError(f"Optimization failed: {res.message}")
return self._format_results(res.x, nodes, node_index)
def _build_capacity_constraints(self, nodes: List[str], node_index: Dict[str, int]) -> Tuple[
np.ndarray, np.ndarray]:
"""构建容量约束矩阵"""
A_ub = []
b_ub = []
for u, v, data in self.graph.edges(data=True):
capacity = float(data.get('bandwidth', 1000)) # 确保float类型
b_ub.append(capacity)
constraint = np.zeros((len(nodes), len(nodes)))
for src, dst in [(u, v), (v, u)]:
if src in node_index and dst in node_index:
i, j = node_index[src], node_index[dst]
constraint[i][j] += 1
A_ub.append(constraint.flatten())
return np.array(A_ub), np.array(b_ub)
@staticmethod
def _format_results(solution: np.ndarray, nodes: List[str], node_index: Dict[str, int]) -> Dict[
str, Dict[str, float]]:
"""格式化优化结果"""
flows = solution.reshape((len(nodes), len(nodes)))
return {
nodes[i]: {
nodes[j]: float(flows[i][j]) # 明确转换为float
for j in range(len(nodes))
if flows[i][j] > 0.001
}
for i in range(len(nodes))
}
def optimize_qos(self, flows: List[FlowDemand]) -> Dict[str, Dict[str, Any]]:
"""
带QoS的流量优化
返回: {
"源-目标": {
"path": 路径列表,
"allocated": 分配带宽,
"priority": 优先级
}
}
"""
sorted_flows = sorted(flows, key=lambda x: x.priority, reverse=True)
results = {}
# 使用实例变量而非局部变量
self._initialize_capacities() # 重置剩余带宽
for flow in sorted_flows:
path = self.find_optimal_path(flow.source, flow.destination, flow.bandwidth)
if not path:
continue
min_bw = min(self.graph[u][v]['bandwidth'] for u, v in zip(path[:-1], path[1:]))
allocated = min(min_bw, flow.bandwidth)
# 更新剩余带宽
for u, v in zip(path[:-1], path[1:]):
self.remaining_capacity[(u, v)] -= allocated
if self.remaining_capacity[(u, v)] <= 0:
self.graph[u][v]['bandwidth'] = 0
results[f"{flow.source}-{flow.destination}"] = {
"path": path,
"allocated": float(allocated), # 确保float类型
"priority": float(flow.priority)
}
return results
def find_optimal_path(self, source: str, target: str,
bandwidth: float = 1.0) -> Optional[List[str]]:
"""
改进的最优路径查找
"""
try:
paths = nx.shortest_simple_paths(
self.graph, source, target,
weight=lambda u, v, d: 1 / max(1, d['bandwidth']) + d['latency'] # 避免除零
)
return next(
(path for path in paths
if self._path_has_sufficient_bandwidth(path, bandwidth)),
None
)
except (nx.NetworkXNoPath, nx.NodeNotFound):
return None
def _path_has_sufficient_bandwidth(self, path: List[str], bw: float) -> bool:
"""检查路径带宽是否满足要求"""
return all(
self.graph[u][v]['bandwidth'] >= bw
for u, v in zip(path[:-1], path[1:])
)

View File

@ -1,104 +0,0 @@
from datetime import datetime, timedelta
import networkx as nx
import matplotlib.pyplot as plt
from io import BytesIO
import base64
from typing import Dict, List
from ..models.traffic_models import SwitchTrafficRecord
from src.backend.app.api.database import SessionLocal
class NetworkVisualizer:
def __init__(self):
self.graph = nx.Graph()
def update_topology(self, devices: List[Dict]):
"""更新网络拓扑图"""
self.graph.clear()
# 添加节点
for device in devices:
self.graph.add_node(
device['ip'],
type='switch',
label=f"Switch\n{device['ip']}"
)
# 添加连接(简化版,实际应根据扫描结果)
if len(devices) > 1:
for i in range(len(devices) - 1):
self.graph.add_edge(
devices[i]['ip'],
devices[i + 1]['ip'],
bandwidth=1000,
label="1Gbps"
)
def generate_topology_image(self) -> str:
"""生成拓扑图并返回base64编码"""
plt.figure(figsize=(10, 8))
pos = nx.spring_layout(self.graph)
# 绘制节点
node_colors = []
for node in self.graph.nodes():
if self.graph.nodes[node]['type'] == 'switch':
node_colors.append('lightblue')
nx.draw_networkx_nodes(
self.graph, pos,
node_size=2000,
node_color=node_colors
)
# 绘制边
nx.draw_networkx_edges(
self.graph, pos,
width=2,
alpha=0.5
)
# 绘制标签
node_labels = nx.get_node_attributes(self.graph, 'label')
nx.draw_networkx_labels(
self.graph, pos,
labels=node_labels,
font_size=8
)
edge_labels = nx.get_edge_attributes(self.graph, 'label')
nx.draw_networkx_edge_labels(
self.graph, pos,
edge_labels=edge_labels,
font_size=8
)
plt.title("Network Topology")
plt.axis('off')
# 转换为base64
buf = BytesIO()
plt.savefig(buf, format='png', bbox_inches='tight')
plt.close()
buf.seek(0)
return base64.b64encode(buf.read()).decode('utf-8')
@staticmethod
def get_traffic_heatmap(minutes: int = 10) -> Dict:
"""获取流量热力图数据"""
with SessionLocal() as session:
records = session.query(SwitchTrafficRecord).filter(
SwitchTrafficRecord.timestamp >= datetime.now() - timedelta(minutes=minutes)
).all()
heatmap_data = {}
for record in records:
if record.switch_ip not in heatmap_data:
heatmap_data[record.switch_ip] = {}
heatmap_data[record.switch_ip][record.interface] = {
'in': record.rate_in,
'out': record.rate_out
}
return heatmap_data

View File

@ -1,120 +0,0 @@
import pandas as pd
import matplotlib.pyplot as plt
from io import BytesIO
import base64
from datetime import datetime, timedelta
from typing import Dict, List
from ..models.traffic_models import TrafficRecord, SwitchTrafficRecord
from src.backend.app.api.database import SessionLocal
class ReportGenerator:
@staticmethod
def generate_traffic_report(ip: str = None, days: int = 7) -> Dict:
"""生成流量分析报告"""
with SessionLocal() as session:
time_threshold = datetime.now() - timedelta(days=days)
if ip:
# 交换机流量报告
query = session.query(SwitchTrafficRecord).filter(
SwitchTrafficRecord.switch_ip == ip,
SwitchTrafficRecord.timestamp >= time_threshold
)
df = pd.read_sql(query.statement, session.bind)
if df.empty:
return {"error": "No data found"}
# 按接口分组分析
interface_stats = df.groupby('interface').agg({
'rate_in': ['mean', 'max', 'min'],
'rate_out': ['mean', 'max', 'min']
}).reset_index()
# 生成趋势图
trend_fig = ReportGenerator._plot_traffic_trend(df, f"Switch {ip} Traffic Trend")
return {
"switch_ip": ip,
"period": f"Last {days} days",
"interface_stats": interface_stats.to_dict('records'),
"trend_chart": trend_fig
}
else:
# 本地网络流量报告
query = session.query(TrafficRecord).filter(
TrafficRecord.timestamp >= time_threshold
)
df = pd.read_sql(query.statement, session.bind)
if df.empty:
return {"error": "No data found"}
# 按接口分组分析
interface_stats = df.groupby('interface').agg({
'bytes_sent': ['sum', 'mean', 'max'],
'bytes_recv': ['sum', 'mean', 'max']
}).reset_index()
# 生成趋势图
trend_fig = ReportGenerator._plot_traffic_trend(df, "Local Network Traffic Trend")
return {
"report_type": "local_network",
"period": f"Last {days} days",
"interface_stats": interface_stats.to_dict('records'),
"trend_chart": trend_fig
}
@staticmethod
def _plot_traffic_trend(df: pd.DataFrame, title: str) -> str:
"""生成流量趋势图"""
plt.figure(figsize=(12, 6))
if 'rate_in' in df.columns: # 交换机数据
df_grouped = df.groupby(pd.Grouper(key='timestamp', freq='1H')).agg({
'rate_in': 'mean',
'rate_out': 'mean'
})
plt.plot(df_grouped.index, df_grouped['rate_in'], label='Inbound Traffic')
plt.plot(df_grouped.index, df_grouped['rate_out'], label='Outbound Traffic')
plt.ylabel("Traffic Rate (bytes/sec)")
else: # 本地网络数据
df_grouped = df.groupby(pd.Grouper(key='timestamp', freq='1H')).agg({
'bytes_sent': 'sum',
'bytes_recv': 'sum'
})
plt.plot(df_grouped.index, df_grouped['bytes_sent'], label='Bytes Sent')
plt.plot(df_grouped.index, df_grouped['bytes_recv'], label='Bytes Received')
plt.ylabel("Bytes")
plt.title(title)
plt.xlabel("Time")
plt.legend()
plt.grid(True)
# 转换为base64
buf = BytesIO()
plt.savefig(buf, format='png', bbox_inches='tight')
plt.close()
buf.seek(0)
return base64.b64encode(buf.read()).decode('utf-8')
@staticmethod
def generate_config_history_report(days: int = 30) -> Dict:
"""生成配置变更历史报告"""
# 需要实现配置历史记录功能
pass
@staticmethod
def generate_security_report() -> Dict:
"""生成安全评估报告"""
# 需要实现安全扫描功能
pass
@staticmethod
def generate_performance_report() -> Dict:
"""生成性能评估报告"""
# 需要实现性能基准测试功能
pass

View File

@ -8,6 +8,7 @@ from typing import Dict, Optional, List
from ..models.traffic_models import TrafficRecord
from src.backend.app.api.database import SessionLocal
from ..utils.logger import logger
class TrafficMonitor:
@ -33,7 +34,7 @@ class TrafficMonitor:
if not self.running:
self.running = True
self.task = asyncio.create_task(self._monitor_loop())
print("流量监控已启动")
logger.info("流量监控已启动")
async def stop_monitoring(self):
"""停止流量监控"""
@ -44,7 +45,7 @@ class TrafficMonitor:
await self.task
except asyncio.CancelledError:
pass
print("流量监控已停止")
logger.info("流量监控已停止")
async def _monitor_loop(self):
"""监控主循环"""

View File

@ -12,29 +12,29 @@ import {
Field,
Input,
Stack,
Portal,
Select,
} from '@chakra-ui/react';
import { motion } from 'framer-motion';
import { FiCheck } from 'react-icons/fi';
import Notification from '@/libs/system/Notification';
import { createListCollection } from '@chakra-ui/react';
const MotionBox = motion(Box);
/**
* 设备配置弹窗
* @param isOpen 是否打开
* @param onClose 关闭弹窗
* @param onSave 保存修改
* @param device 当前设备
* @returns {JSX.Element}
* @constructor
*/
const vendors = ['huawei', 'cisco', 'h3c', 'ruijie', 'zte'];
const DeviceConfigModal = ({ isOpen, onClose, onSave, device }) => {
const [username, setUsername] = useState(device.username || '');
const [password, setPassword] = useState(device.password || '');
const [vendor, setVendor] = useState(device.vendor || '');
const [saved, setSaved] = useState(false);
const vendorCollection = createListCollection({
items: vendors.map((v) => ({ label: v.toUpperCase(), value: v })),
});
const handleSave = () => {
const updatedDevice = { ...device, username, password };
const updatedDevice = { ...device, username, password, vendor };
onSave(updatedDevice);
setSaved(true);
setTimeout(() => {
@ -82,6 +82,40 @@ const DeviceConfigModal = ({ isOpen, onClose, onSave, device }) => {
type={'password'}
/>
</Field.Root>
<Field.Root>
<Field.Label>交换机厂商</Field.Label>
<Select.Root
collection={vendorCollection}
value={vendor ? [vendor] : []}
onValueChange={({ value }) => setVendor(value[0] || '')}
placeholder={'请选择厂商'}
size={'sm'}
colorPalette={'teal'}
>
<Select.HiddenSelect />
<Select.Control>
<Select.Trigger>
<Select.ValueText />
</Select.Trigger>
<Select.IndicatorGroup>
<Select.Indicator />
<Select.ClearTrigger />
</Select.IndicatorGroup>
</Select.Control>
<Portal>
<Select.Positioner style={{ zIndex: 1500 }}>
<Select.Content>
{vendorCollection.items.map((item) => (
<Select.Item key={item.value} item={item}>
{item.label}
</Select.Item>
))}
</Select.Content>
</Select.Positioner>
</Portal>
</Select.Root>
</Field.Root>
</Stack>
</DialogBody>

View File

@ -22,22 +22,16 @@ import ConfigTool from '@/libs/config/ConfigTool';
import { api } from '@/services/api/api';
import Notification from '@/libs/system/Notification';
import Common from '@/libs/common';
import configEffect from '@/libs/script/configPage/configEffect';
const testMode = ConfigTool.load().testMode;
const ConfigPage = () => {
const [devices, setDevices] = useState([]);
const [selectedDevice, setSelectedDevice] = useState('');
const [selectedDeviceConfig, setSelectedDeviceConfig] = useState('');
const [selectedDevices, setSelectedDevices] = useState([]);
const [deviceConfigs, setDeviceConfigs] = useState({});
const [inputText, setInputText] = useState('');
const [parsedConfig, setParsedConfig] = useState('');
const [editableConfig, setEditableConfig] = useState('');
const [applying, setApplying] = useState(false);
const [hasParsed, setHasParsed] = useState(false);
const [isPeizhi, setisPeizhi] = useState(false);
const [isApplying, setIsApplying] = useState(false);
const [applyStatus, setApplyStatus] = useState([]);
const deviceCollection = createListCollection({
items: devices.map((device) => ({
@ -52,18 +46,30 @@ const ConfigPage = () => {
}, []);
const handleParse = async () => {
if (!selectedDevice || !inputText.trim()) {
if (selectedDevices.length === 0 || !inputText.trim()) {
Notification.error({
title: '操作失败',
description: '请选择设备并输入配置指令',
description: '请选择至少一个设备并输入配置指令',
});
return;
}
const selectedConfigs = devices.filter((device) => selectedDevices.includes(device.ip));
const deviceWithoutVendor = selectedConfigs.find((d) => !d.vendor || d.vendor.trim() === '');
if (deviceWithoutVendor) {
Notification.error({
title: '操作失败',
description: `设备 ${deviceWithoutVendor.name} 暂未配置厂商,请先配置厂商`,
});
return;
}
try {
const performParse = async () => {
return await api.parseCommand(inputText);
};
const performParse = async () =>
await api.parseCommand({
command: inputText,
devices: selectedConfigs,
});
const resultWrapper = await Notification.promise({
promise: performParse(),
@ -82,11 +88,15 @@ const ConfigPage = () => {
});
let result = await resultWrapper.unwrap();
if (result?.data) {
setParsedConfig(JSON.stringify(result.data));
setEditableConfig(JSON.stringify(result.data));
if (result?.data?.config) {
const configMap = {};
result.data.config.forEach((item) => {
if (item.device?.ip) {
configMap[item.device.ip] = item;
}
});
setDeviceConfigs(configMap);
setHasParsed(true);
setisPeizhi(true);
}
} catch (error) {
console.error('配置解析异常:', error);
@ -98,62 +108,80 @@ const ConfigPage = () => {
};
const handleApply = async () => {
if (!editableConfig) {
if (!hasParsed) {
Notification.warn({
title: '配置为空',
description: '请先解析或编辑有效配置',
title: '未解析配置',
description: '请先解析配置再应用',
});
return;
}
setApplying(true);
setIsApplying(true);
try {
const applyOperation = async () => {
if (testMode) {
Common.sleep(1000).then(() => ({ success: true }));
} else {
let commands = JSON.parse(editableConfig)?.config?.commands;
console.log(`commands:${JSON.stringify(commands)}`);
const deviceConfig = JSON.parse(selectedDeviceConfig);
console.log(`deviceConfig:${JSON.stringify(deviceConfig)}`);
if (!deviceConfig.password) {
Notification.warn({
title: '所选交换机暂未配置用户名(可选)和密码',
description: '请前往交换机设备处配置username和password',
});
return false;
}
if (deviceConfig.username || deviceConfig.username.toString() !== '') {
commands.push(`!username=${deviceConfig.username.toString()}`);
} else {
commands.push(`!username=NONE`);
}
commands.push(`!password=${deviceConfig.password.toString()}`);
const res = await api.applyConfig(selectedDevice, commands);
if (res) {
await Common.sleep(1000);
Notification.success({
title: '配置完毕',
title: '测试模式成功',
description: '配置已模拟应用',
});
return;
}
const applyPromises = selectedDevices.map(async (ip) => {
const deviceItem = deviceConfigs[ip];
if (!deviceItem) return;
const deviceConfig = deviceItem.config;
if (!deviceItem.device.password) {
Notification.warn({
title: `交换机 ${deviceItem.device.name} 暂未配置密码`,
description: '请前往交换机设备处配置用户名和密码',
});
console.log(JSON.stringify(deviceItem));
return;
}
if (!deviceItem.device.username) {
Notification.warn({
title: `交换机 ${deviceItem.device.name} 暂未配置用户名,将使用NONE作为用户名`,
});
deviceItem.device.username = 'NONE';
}
const commands = [...deviceConfig.commands];
try {
const res = await api.applyConfig(
ip,
commands,
deviceItem.device.username,
deviceItem.device.password
);
Notification.success({
title: `配置完毕 - ${deviceItem.device.name}`,
description: JSON.stringify(res),
});
} else {
} catch (err) {
Notification.error({
title: '配置过程出现错误',
description: '请检查API提示',
title: `配置过程出现错误 - ${deviceItem.device.name}`,
description: err.message || '请检查API提示',
});
}
}
});
await Promise.all(applyPromises);
};
await Notification.promise({
promise: applyOperation,
promise: applyOperation(),
loading: {
title: '配置应用中',
description: '正在推送配置到设备...',
},
success: {
title: '应用成功',
description: '配置已成功生效',
title: '应用成',
description: '所有设备配置已推送',
},
error: {
title: '应用失败',
@ -174,19 +202,16 @@ const ConfigPage = () => {
<Heading fontSize={'xl'} color={'teal.300'}>
交换机配置中心
</Heading>
<Field.Root>
<Field.Label fontWeight={'bold'} mb={1} fontSize="sm">
选择交换机设备
</Field.Label>
<Select.Root
multiple
collection={deviceCollection}
value={selectedDevice ? [selectedDevice] : []}
onValueChange={({ value }) => {
const selectedIp = value[0] ?? '';
setSelectedDevice(selectedIp);
const fullDeviceConfig = devices.find((device) => device.ip === selectedIp);
setSelectedDeviceConfig(JSON.stringify(fullDeviceConfig));
}}
value={selectedDevices}
onValueChange={({ value }) => setSelectedDevices(value)}
placeholder={'请选择交换机设备'}
size={'sm'}
colorPalette={'teal'}
@ -202,7 +227,7 @@ const ConfigPage = () => {
</Select.IndicatorGroup>
</Select.Control>
<Portal>
<Select.Positioner>
<Select.Positioner style={{ zIndex: 1500 }}>
<Select.Content>
{deviceCollection.items.map((item) => (
<Select.Item key={item.value} item={item}>
@ -214,13 +239,14 @@ const ConfigPage = () => {
</Portal>
</Select.Root>
</Field.Root>
<Field.Root>
<Field.Label fontWeight={'bold'} mb={1} fontSize="sm">
配置指令输入
</Field.Label>
<Textarea
rows={4}
placeholder={'例创建VLAN 10并配置IP 192.168.10.1/24并在端口1启用SSH访问"'}
placeholder={'例创建VLAN 10并配置IP 192.168.10.1/24并在端口1启用SSH访问'}
value={inputText}
colorPalette={'teal'}
orientation={'vertical'}
@ -229,30 +255,27 @@ const ConfigPage = () => {
size={'sm'}
/>
</Field.Root>
<Button
colorScheme={'teal'}
variant={'solid'}
size={'sm'}
onClick={handleParse}
isDisabled={!selectedDevice || !inputText.trim()}
isDisabled={selectedDevices.length === 0 || !inputText.trim()}
>
解析配置
</Button>
{isPeizhi && parsedConfig && (
{hasParsed && selectedDevices.length > 0 && (
<FadeInWrapper delay={0.2}>
<VStack spacing={4} align={'stretch'}>
{(() => {
let parsed;
try {
parsed = JSON.parse(editableConfig);
} catch (e) {
return <Text color={'red.300'}>配置 JSON 格式错误无法解析</Text>;
}
const config = parsed.config ? [parsed.config] : parsed;
return config.map((cfg, idx) => (
{selectedDevices.map((ip) => {
const item = deviceConfigs[ip];
if (!item) return null;
const cfg = item.config;
return (
<Box
key={idx}
key={ip}
p={4}
bg={'whiteAlpha.100'}
borderRadius={'xl'}
@ -260,7 +283,7 @@ const ConfigPage = () => {
borderColor={'whiteAlpha.300'}
>
<Text fontSize={'lg'} fontWeight={'bold'} mb={2}>
配置类型: {cfg.type}
设备: {item.device.name} ({ip}) - 配置类型: {cfg.type}
</Text>
{Object.entries(cfg).map(([key, value]) => {
@ -278,9 +301,16 @@ const ConfigPage = () => {
value={value}
onChange={(e) => {
const newVal = e.target.value;
const updated = JSON.parse(editableConfig);
updated.config[key] = newVal;
setEditableConfig(JSON.stringify(updated, null, 2));
setDeviceConfigs((prev) => ({
...prev,
[ip]: {
...prev[ip],
config: {
...prev[ip].config,
[key]: newVal,
},
},
}));
}}
/>
</Field.Root>
@ -299,20 +329,26 @@ const ConfigPage = () => {
value={cmd}
onChange={(e) => {
const newCmd = e.target.value;
const updated = JSON.parse(editableConfig);
updated.config.commands[i] = newCmd;
setEditableConfig(JSON.stringify(updated, null, 2));
setDeviceConfigs((prev) => {
const updated = { ...prev };
updated[ip].config.commands[i] = newCmd;
return updated;
});
}}
/>
</Field.Root>
))}
<HStack mt={4} spacing={3} justify={'flex-end'}>
<Button
variant={'outline'}
colorScheme={'gray'}
size={'sm'}
onClick={() => {
setEditableConfig(parsedConfig);
setDeviceConfigs((prev) => ({
...prev,
[ip]: item,
}));
Notification.success({
title: '成功重置配置!',
description: '现在您可以重新审查生成的配置',
@ -341,63 +377,14 @@ const ConfigPage = () => {
size={'sm'}
onClick={handleApply}
isLoading={applying}
isDisabled={!editableConfig}
isDisabled={!cfg.commands || cfg.commands.length === 0}
>
应用到交换机
</Button>
</HStack>
</Box>
));
})()}
{
<FadeInWrapper delay={0.2}>
<VStack spacing={4} align={'stretch'}>
<Box
p={4}
bg={'whiteAlpha.100'}
borderRadius={'xl'}
border={'1px solid'}
borderColor={'whiteAlpha.300'}
>
<Text fontSize={'lg'} fontWeight={'bold'} mb={2}>
应用配置命令
</Text>
<Box>
{JSON.parse(editableConfig).config?.commands.map((command, index) => (
<HStack key={index} mb={2}>
<Text fontSize={'sm'} flex={1}>
{command}
</Text>
<Spinner
size={'sm'}
color={applyStatus[index] === 'success' ? 'green.500' : 'red.500'}
display={
applyStatus[index] === 'pending' ||
applyStatus[index] === 'in-progress'
? 'inline-block'
: 'none'
}
/>
<Text
color={applyStatus[index] === 'success' ? 'green.500' : 'red.500'}
ml={2}
>
{applyStatus[index] === 'success'
? '成功'
: applyStatus[index] === 'failed'
? '失败'
: applyStatus[index] === 'in-progress'
? '正在应用'
: ''}
</Text>
</HStack>
))}
</Box>
</Box>
</VStack>
</FadeInWrapper>
}
);
})}
</VStack>
</FadeInWrapper>
)}

View File

@ -25,7 +25,7 @@ export const api = {
/**
* 扫描网络
* @param subnet 子网地址
* @param {string} subnet 子网地址
* @returns {Promise<axios.AxiosResponse<any>>}
*/
scan: (subnet) => axios.get(buildUrl('/api/scan_network'), { params: { subnet } }),
@ -38,19 +38,25 @@ export const api = {
/**
* 解析命令
* @param text 文本
* @param {Object} payload
* @param {string} payload.command - 自然语言命令
* @param {Array<Object>} payload.devices - 设备列表
* 每个对象包含 { id: string, ip: string, vendor: string(huawei/cisco/h3c/ruijie/zte) }
* @returns {Promise<axios.AxiosResponse<any>>}
*/
parseCommand: (text) => axios.post(buildUrl('/api/parse_command'), { command: text }),
parseCommand: ({ command, devices }) =>
axios.post(buildUrl('/api/parse_command'), { command, devices }),
/**
* 应用配置
* @param switch_ip 交换机ip
* @param commands 配置,为数组[]
* @param {string} switch_ip 交换机IP
* @param {Array<string>} commands 配置命令数组
* @param username 用户名无时使用NONE
* @param password 密码
* @returns {Promise<axios.AxiosResponse<any>>}
*/
applyConfig: (switch_ip, commands) =>
axios.post(buildUrl('/api/execute_cli_commands'), { switch_ip: switch_ip, commands: commands }),
applyConfig: (switch_ip, commands, username, password) =>
axios.post(buildUrl('/api/execute_cli_commands'), { switch_ip, commands, username, password }),
/**
* 获取网络适配器信息
@ -60,7 +66,7 @@ export const api = {
/**
* 更新基础URL
* @param url
* @param {string} url
*/
updateBaseUrl: (url) => {
const config = ConfigTool.load();
@ -77,7 +83,7 @@ export const getConfig = () => ConfigTool.load();
/**
* 获取基础URL
* @returns {string|string}
* @returns {string}
*/
export const getBaseUrl = () => ConfigTool.load().backendUrl || '';