mirror of
https://github.com/Jerryplusy/AI-powered-switches.git
synced 2025-10-14 09:49:19 +00:00
再再删掉一些东西
This commit is contained in:
parent
59c8604cda
commit
29f016bdab
@ -1,85 +0,0 @@
|
|||||||
import re
|
|
||||||
from typing import Dict, List, Tuple
|
|
||||||
from ..utils.exceptions import SwitchConfigException
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigValidator:
|
|
||||||
@staticmethod
|
|
||||||
def validate_vlan_config(config: Dict) -> Tuple[bool, str]:
|
|
||||||
"""验证VLAN配置"""
|
|
||||||
if 'vlan_id' not in config:
|
|
||||||
return False, "Missing VLAN ID"
|
|
||||||
|
|
||||||
vlan_id = config['vlan_id']
|
|
||||||
if not (1 <= vlan_id <= 4094):
|
|
||||||
return False, f"Invalid VLAN ID {vlan_id}. Must be 1-4094"
|
|
||||||
|
|
||||||
if 'name' in config and len(config['name']) > 32:
|
|
||||||
return False, "VLAN name too long (max 32 chars)"
|
|
||||||
|
|
||||||
return True, "Valid VLAN config"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def validate_interface_config(config: Dict) -> Tuple[bool, str]:
|
|
||||||
"""验证接口配置"""
|
|
||||||
required_fields = ['interface', 'ip_address']
|
|
||||||
for field in required_fields:
|
|
||||||
if field not in config:
|
|
||||||
return False, f"Missing required field: {field}"
|
|
||||||
|
|
||||||
# 验证IP地址格式
|
|
||||||
ip_pattern = r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}/\d{1,2}$'
|
|
||||||
if not re.match(ip_pattern, config['ip_address']):
|
|
||||||
return False, "Invalid IP address format"
|
|
||||||
|
|
||||||
# 验证接口名称格式
|
|
||||||
interface_pattern = r'^(GigabitEthernet|FastEthernet|Eth)\d+/\d+/\d+$'
|
|
||||||
if not re.match(interface_pattern, config['interface']):
|
|
||||||
return False, "Invalid interface name format"
|
|
||||||
|
|
||||||
return True, "Valid interface config"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def check_security_risks(commands: List[str]) -> List[str]:
|
|
||||||
"""检查潜在安全风险"""
|
|
||||||
risky_commands = []
|
|
||||||
dangerous_patterns = [
|
|
||||||
r'no\s+aaa', # 禁用认证
|
|
||||||
r'enable\s+password', # 明文密码
|
|
||||||
r'service\s+password-encryption', # 弱加密
|
|
||||||
r'ip\s+http\s+server', # 启用HTTP服务
|
|
||||||
r'no\s+ip\s+http\s+secure-server' # 禁用HTTPS
|
|
||||||
]
|
|
||||||
|
|
||||||
for cmd in commands:
|
|
||||||
for pattern in dangerous_patterns:
|
|
||||||
if re.search(pattern, cmd, re.IGNORECASE):
|
|
||||||
risky_commands.append(cmd)
|
|
||||||
break
|
|
||||||
|
|
||||||
return risky_commands
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def validate_full_config(config: Dict) -> Tuple[bool, List[str]]:
|
|
||||||
"""全面验证配置"""
|
|
||||||
errors = []
|
|
||||||
|
|
||||||
if 'type' not in config:
|
|
||||||
errors.append("Missing configuration type")
|
|
||||||
return False, errors
|
|
||||||
|
|
||||||
if config['type'] == 'vlan':
|
|
||||||
valid, msg = ConfigValidator.validate_vlan_config(config)
|
|
||||||
if not valid:
|
|
||||||
errors.append(msg)
|
|
||||||
elif config['type'] == 'interface':
|
|
||||||
valid, msg = ConfigValidator.validate_interface_config(config)
|
|
||||||
if not valid:
|
|
||||||
errors.append(msg)
|
|
||||||
|
|
||||||
if 'commands' in config:
|
|
||||||
risks = ConfigValidator.check_security_risks(config['commands'])
|
|
||||||
if risks:
|
|
||||||
errors.append(f"Potential security risks detected: {', '.join(risks)}")
|
|
||||||
|
|
||||||
return len(errors) == 0, errors
|
|
@ -1,129 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
import networkx as nx
|
|
||||||
|
|
||||||
from ..models.traffic_models import SwitchTrafficRecord
|
|
||||||
from src.backend.app.api.database import SessionLocal
|
|
||||||
from ..utils.exceptions import SwitchConfigException
|
|
||||||
from ..services.network_scanner import NetworkScanner
|
|
||||||
|
|
||||||
|
|
||||||
class FailoverManager:
|
|
||||||
def __init__(self, config_backup_dir: str = "config_backups"):
|
|
||||||
self.scanner = NetworkScanner()
|
|
||||||
self.backup_dir = Path(config_backup_dir)
|
|
||||||
self.backup_dir.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
async def detect_failures(self, threshold: float = 0.9) -> List[Dict]:
|
|
||||||
"""检测可能的网络故障"""
|
|
||||||
with SessionLocal() as session:
|
|
||||||
# 获取最近5分钟的所有交换机流量记录
|
|
||||||
records = session.query(SwitchTrafficRecord).filter(
|
|
||||||
SwitchTrafficRecord.timestamp >= datetime.now() - timedelta(minutes=5)
|
|
||||||
).all()
|
|
||||||
|
|
||||||
# 分析流量异常
|
|
||||||
abnormal_devices = []
|
|
||||||
device_stats = {}
|
|
||||||
|
|
||||||
for record in records:
|
|
||||||
if record.switch_ip not in device_stats:
|
|
||||||
device_stats[record.switch_ip] = {
|
|
||||||
'interfaces': {},
|
|
||||||
'max_in': 0,
|
|
||||||
'max_out': 0
|
|
||||||
}
|
|
||||||
|
|
||||||
if record.interface not in device_stats[record.switch_ip]['interfaces']:
|
|
||||||
device_stats[record.switch_ip]['interfaces'][record.interface] = {
|
|
||||||
'in': [],
|
|
||||||
'out': []
|
|
||||||
}
|
|
||||||
|
|
||||||
device_stats[record.switch_ip]['interfaces'][record.interface]['in'].append(record.rate_in)
|
|
||||||
device_stats[record.switch_ip]['interfaces'][record.interface]['out'].append(record.rate_out)
|
|
||||||
|
|
||||||
# 更新设备最大流量
|
|
||||||
if record.rate_in > device_stats[record.switch_ip]['max_in']:
|
|
||||||
device_stats[record.switch_ip]['max_in'] = record.rate_in
|
|
||||||
if record.rate_out > device_stats[record.switch_ip]['max_out']:
|
|
||||||
device_stats[record.switch_ip]['max_out'] = record.rate_out
|
|
||||||
|
|
||||||
# 检测异常
|
|
||||||
for ip, stats in device_stats.items():
|
|
||||||
for iface, traffic in stats['interfaces'].items():
|
|
||||||
avg_in = sum(traffic['in']) / len(traffic['in'])
|
|
||||||
avg_out = sum(traffic['out']) / len(traffic['out'])
|
|
||||||
|
|
||||||
# 如果平均流量超过最大流量的90%,认为可能有问题
|
|
||||||
if avg_in > stats['max_in'] * threshold or \
|
|
||||||
avg_out > stats['max_out'] * threshold:
|
|
||||||
abnormal_devices.append({
|
|
||||||
'ip': ip,
|
|
||||||
'interface': iface,
|
|
||||||
'avg_in': avg_in,
|
|
||||||
'avg_out': avg_out,
|
|
||||||
'max_in': stats['max_in'],
|
|
||||||
'max_out': stats['max_out']
|
|
||||||
})
|
|
||||||
|
|
||||||
return abnormal_devices
|
|
||||||
|
|
||||||
async def automatic_recovery(self, failed_device_ip: str) -> bool:
|
|
||||||
"""自动故障恢复"""
|
|
||||||
try:
|
|
||||||
# 1. 检查设备是否在线
|
|
||||||
devices = self.scanner.scan_subnet(failed_device_ip + "/32")
|
|
||||||
if not devices:
|
|
||||||
raise SwitchConfigException(f"Device {failed_device_ip} is offline")
|
|
||||||
|
|
||||||
# 2. 查找最近的配置备份
|
|
||||||
backup_files = sorted(self.backup_dir.glob(f"{failed_device_ip}_*.cfg"))
|
|
||||||
if not backup_files:
|
|
||||||
raise SwitchConfigException(f"No backup found for {failed_device_ip}")
|
|
||||||
|
|
||||||
latest_backup = backup_files[-1]
|
|
||||||
|
|
||||||
# 3. 恢复配置
|
|
||||||
with open(latest_backup) as f:
|
|
||||||
config_commands = f.read().splitlines()
|
|
||||||
|
|
||||||
# 使用SSH执行恢复命令
|
|
||||||
# (这里需要实现SSH连接和执行命令的逻辑)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise SwitchConfigException(f"Recovery failed: {str(e)}")
|
|
||||||
|
|
||||||
async def redundancy_check(self, critical_devices: List[str]) -> Dict:
|
|
||||||
"""检查关键设备的冗余配置"""
|
|
||||||
results = {}
|
|
||||||
topology = self.scanner.get_current_topology()
|
|
||||||
|
|
||||||
for device_ip in critical_devices:
|
|
||||||
# 检查是否有备用路径
|
|
||||||
try:
|
|
||||||
paths = list(nx.all_shortest_paths(
|
|
||||||
topology, source=device_ip, target="core_switch"))
|
|
||||||
|
|
||||||
if len(paths) > 1:
|
|
||||||
results[device_ip] = {
|
|
||||||
'status': 'redundant',
|
|
||||||
'path_count': len(paths)
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
results[device_ip] = {
|
|
||||||
'status': 'single_point_of_failure',
|
|
||||||
'recommendation': 'Add redundant links'
|
|
||||||
}
|
|
||||||
except:
|
|
||||||
results[device_ip] = {
|
|
||||||
'status': 'disconnected',
|
|
||||||
'recommendation': 'Check physical connection'
|
|
||||||
}
|
|
||||||
|
|
||||||
return results
|
|
@ -1,220 +0,0 @@
|
|||||||
import networkx as nx
|
|
||||||
import numpy as np
|
|
||||||
from scipy.optimize import linprog
|
|
||||||
from typing import Dict, List, Optional, Tuple, Any
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class FlowDemand:
|
|
||||||
source: str
|
|
||||||
destination: str
|
|
||||||
bandwidth: float
|
|
||||||
priority: float = 1.0
|
|
||||||
|
|
||||||
|
|
||||||
class NetworkOptimizer:
|
|
||||||
def __init__(self, devices: List[Dict[str, Any]]):
|
|
||||||
"""基于图论的网络优化模型"""
|
|
||||||
self.graph = self.build_topology_graph(devices)
|
|
||||||
self.traffic_matrix: Optional[np.ndarray] = None
|
|
||||||
self._initialize_capacities()
|
|
||||||
|
|
||||||
def _initialize_capacities(self) -> Dict[Tuple[str, str], float]:
|
|
||||||
"""初始化链路容量字典"""
|
|
||||||
self.remaining_capacity = {
|
|
||||||
(u, v): data['bandwidth']
|
|
||||||
for u, v, data in self.graph.edges(data=True)
|
|
||||||
}
|
|
||||||
return self.remaining_capacity
|
|
||||||
|
|
||||||
def build_topology_graph(self, devices: List[Dict[str, Any]]) -> nx.Graph:
|
|
||||||
"""构建带权重的网络拓扑图"""
|
|
||||||
G = nx.Graph()
|
|
||||||
|
|
||||||
for device in devices:
|
|
||||||
G.add_node(device['ip'],
|
|
||||||
type=device['type'],
|
|
||||||
capacity=device.get('capacity', 1000))
|
|
||||||
|
|
||||||
# 添加接口作为子节点
|
|
||||||
for interface in device.get('interfaces', []):
|
|
||||||
interface_id = f"{device['ip']}_{interface['name']}"
|
|
||||||
G.add_node(interface_id,
|
|
||||||
type='interface',
|
|
||||||
capacity=interface.get('bandwidth', 1000))
|
|
||||||
G.add_edge(device['ip'], interface_id,
|
|
||||||
bandwidth=interface.get('bandwidth', 1000),
|
|
||||||
latency=interface.get('latency', 1))
|
|
||||||
|
|
||||||
# 添加设备间连接
|
|
||||||
self._connect_devices(G, devices)
|
|
||||||
return G
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _connect_devices(graph: nx.Graph, devices: List[Dict[str, Any]]):
|
|
||||||
"""自动连接设备"""
|
|
||||||
for i in range(len(devices) - 1):
|
|
||||||
graph.add_edge(
|
|
||||||
devices[i]['ip'],
|
|
||||||
devices[i + 1]['ip'],
|
|
||||||
bandwidth=1000,
|
|
||||||
latency=5
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _build_flow_conservation(nodes: List[str]) -> Tuple[np.ndarray, np.ndarray]:
|
|
||||||
"""构建流量守恒约束矩阵"""
|
|
||||||
num_nodes = len(nodes)
|
|
||||||
A_eq: List[np.ndarray] = [] # 明确指定列表元素类型
|
|
||||||
b_eq: List[float] = [] # 明确指定float列表
|
|
||||||
|
|
||||||
for i in range(num_nodes):
|
|
||||||
for j in range(num_nodes):
|
|
||||||
if i != j:
|
|
||||||
constraint = np.zeros((num_nodes, num_nodes))
|
|
||||||
constraint[i][j] = 1
|
|
||||||
constraint[j][i] = -1
|
|
||||||
A_eq.append(constraint.flatten())
|
|
||||||
b_eq.append(0.0) # 明确使用float类型
|
|
||||||
|
|
||||||
return np.array(A_eq, dtype=np.float64), np.array(b_eq, dtype=np.float64)
|
|
||||||
def optimize_bandwidth_allocation(
|
|
||||||
self,
|
|
||||||
demands: List[FlowDemand]
|
|
||||||
) -> Dict[str, Dict[str, float]]:
|
|
||||||
"""
|
|
||||||
基于线性规划的带宽分配优化
|
|
||||||
返回: {源IP: {目标IP: 分配带宽}}
|
|
||||||
"""
|
|
||||||
demand_dict = {(d.source, d.destination): float(d.bandwidth) for d in demands} # 确保float类型
|
|
||||||
return self._optimize_with_highs(demand_dict)
|
|
||||||
|
|
||||||
def _optimize_with_highs(self, demands: Dict[Tuple[str, str], float]) -> Dict[str, Dict[str, float]]:
|
|
||||||
"""使用HiGHS求解器实现"""
|
|
||||||
nodes = list(self.graph.nodes())
|
|
||||||
node_index = {node: i for i, node in enumerate(nodes)}
|
|
||||||
|
|
||||||
# 构建流量矩阵
|
|
||||||
self.traffic_matrix = np.zeros((len(nodes), len(nodes)))
|
|
||||||
for (src, dst), bw in demands.items():
|
|
||||||
if src in node_index and dst in node_index:
|
|
||||||
self.traffic_matrix[node_index[src]][node_index[dst]] = bw
|
|
||||||
|
|
||||||
# 构建约束
|
|
||||||
c = np.ones(len(nodes) ** 2) # 最小化总流量
|
|
||||||
A_ub, b_ub = self._build_capacity_constraints(nodes, node_index)
|
|
||||||
A_eq, b_eq = self._build_flow_conservation(nodes)
|
|
||||||
|
|
||||||
# 求解线性规划
|
|
||||||
res = linprog(
|
|
||||||
c=c,
|
|
||||||
A_ub=A_ub,
|
|
||||||
b_ub=b_ub,
|
|
||||||
A_eq=A_eq,
|
|
||||||
b_eq=b_eq,
|
|
||||||
bounds=(0, None),
|
|
||||||
method='highs'
|
|
||||||
)
|
|
||||||
|
|
||||||
if not res.success:
|
|
||||||
raise ValueError(f"Optimization failed: {res.message}")
|
|
||||||
|
|
||||||
return self._format_results(res.x, nodes, node_index)
|
|
||||||
|
|
||||||
def _build_capacity_constraints(self, nodes: List[str], node_index: Dict[str, int]) -> Tuple[
|
|
||||||
np.ndarray, np.ndarray]:
|
|
||||||
"""构建容量约束矩阵"""
|
|
||||||
A_ub = []
|
|
||||||
b_ub = []
|
|
||||||
|
|
||||||
for u, v, data in self.graph.edges(data=True):
|
|
||||||
capacity = float(data.get('bandwidth', 1000)) # 确保float类型
|
|
||||||
b_ub.append(capacity)
|
|
||||||
|
|
||||||
constraint = np.zeros((len(nodes), len(nodes)))
|
|
||||||
for src, dst in [(u, v), (v, u)]:
|
|
||||||
if src in node_index and dst in node_index:
|
|
||||||
i, j = node_index[src], node_index[dst]
|
|
||||||
constraint[i][j] += 1
|
|
||||||
A_ub.append(constraint.flatten())
|
|
||||||
|
|
||||||
return np.array(A_ub), np.array(b_ub)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _format_results(solution: np.ndarray, nodes: List[str], node_index: Dict[str, int]) -> Dict[
|
|
||||||
str, Dict[str, float]]:
|
|
||||||
"""格式化优化结果"""
|
|
||||||
flows = solution.reshape((len(nodes), len(nodes)))
|
|
||||||
return {
|
|
||||||
nodes[i]: {
|
|
||||||
nodes[j]: float(flows[i][j]) # 明确转换为float
|
|
||||||
for j in range(len(nodes))
|
|
||||||
if flows[i][j] > 0.001
|
|
||||||
}
|
|
||||||
for i in range(len(nodes))
|
|
||||||
}
|
|
||||||
|
|
||||||
def optimize_qos(self, flows: List[FlowDemand]) -> Dict[str, Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
带QoS的流量优化
|
|
||||||
返回: {
|
|
||||||
"源-目标": {
|
|
||||||
"path": 路径列表,
|
|
||||||
"allocated": 分配带宽,
|
|
||||||
"priority": 优先级
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
sorted_flows = sorted(flows, key=lambda x: x.priority, reverse=True)
|
|
||||||
results = {}
|
|
||||||
|
|
||||||
# 使用实例变量而非局部变量
|
|
||||||
self._initialize_capacities() # 重置剩余带宽
|
|
||||||
|
|
||||||
for flow in sorted_flows:
|
|
||||||
path = self.find_optimal_path(flow.source, flow.destination, flow.bandwidth)
|
|
||||||
if not path:
|
|
||||||
continue
|
|
||||||
|
|
||||||
min_bw = min(self.graph[u][v]['bandwidth'] for u, v in zip(path[:-1], path[1:]))
|
|
||||||
allocated = min(min_bw, flow.bandwidth)
|
|
||||||
|
|
||||||
# 更新剩余带宽
|
|
||||||
for u, v in zip(path[:-1], path[1:]):
|
|
||||||
self.remaining_capacity[(u, v)] -= allocated
|
|
||||||
if self.remaining_capacity[(u, v)] <= 0:
|
|
||||||
self.graph[u][v]['bandwidth'] = 0
|
|
||||||
|
|
||||||
results[f"{flow.source}-{flow.destination}"] = {
|
|
||||||
"path": path,
|
|
||||||
"allocated": float(allocated), # 确保float类型
|
|
||||||
"priority": float(flow.priority)
|
|
||||||
}
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
def find_optimal_path(self, source: str, target: str,
|
|
||||||
bandwidth: float = 1.0) -> Optional[List[str]]:
|
|
||||||
"""
|
|
||||||
改进的最优路径查找
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
paths = nx.shortest_simple_paths(
|
|
||||||
self.graph, source, target,
|
|
||||||
weight=lambda u, v, d: 1 / max(1, d['bandwidth']) + d['latency'] # 避免除零
|
|
||||||
)
|
|
||||||
return next(
|
|
||||||
(path for path in paths
|
|
||||||
if self._path_has_sufficient_bandwidth(path, bandwidth)),
|
|
||||||
None
|
|
||||||
)
|
|
||||||
except (nx.NetworkXNoPath, nx.NodeNotFound):
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _path_has_sufficient_bandwidth(self, path: List[str], bw: float) -> bool:
|
|
||||||
"""检查路径带宽是否满足要求"""
|
|
||||||
return all(
|
|
||||||
self.graph[u][v]['bandwidth'] >= bw
|
|
||||||
for u, v in zip(path[:-1], path[1:])
|
|
||||||
)
|
|
@ -1,104 +0,0 @@
|
|||||||
from datetime import datetime, timedelta
|
|
||||||
|
|
||||||
import networkx as nx
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
from io import BytesIO
|
|
||||||
import base64
|
|
||||||
from typing import Dict, List
|
|
||||||
from ..models.traffic_models import SwitchTrafficRecord
|
|
||||||
from src.backend.app.api.database import SessionLocal
|
|
||||||
|
|
||||||
|
|
||||||
class NetworkVisualizer:
|
|
||||||
def __init__(self):
|
|
||||||
self.graph = nx.Graph()
|
|
||||||
|
|
||||||
def update_topology(self, devices: List[Dict]):
|
|
||||||
"""更新网络拓扑图"""
|
|
||||||
self.graph.clear()
|
|
||||||
|
|
||||||
# 添加节点
|
|
||||||
for device in devices:
|
|
||||||
self.graph.add_node(
|
|
||||||
device['ip'],
|
|
||||||
type='switch',
|
|
||||||
label=f"Switch\n{device['ip']}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 添加连接(简化版,实际应根据扫描结果)
|
|
||||||
if len(devices) > 1:
|
|
||||||
for i in range(len(devices) - 1):
|
|
||||||
self.graph.add_edge(
|
|
||||||
devices[i]['ip'],
|
|
||||||
devices[i + 1]['ip'],
|
|
||||||
bandwidth=1000,
|
|
||||||
label="1Gbps"
|
|
||||||
)
|
|
||||||
|
|
||||||
def generate_topology_image(self) -> str:
|
|
||||||
"""生成拓扑图并返回base64编码"""
|
|
||||||
plt.figure(figsize=(10, 8))
|
|
||||||
pos = nx.spring_layout(self.graph)
|
|
||||||
|
|
||||||
# 绘制节点
|
|
||||||
node_colors = []
|
|
||||||
for node in self.graph.nodes():
|
|
||||||
if self.graph.nodes[node]['type'] == 'switch':
|
|
||||||
node_colors.append('lightblue')
|
|
||||||
|
|
||||||
nx.draw_networkx_nodes(
|
|
||||||
self.graph, pos,
|
|
||||||
node_size=2000,
|
|
||||||
node_color=node_colors
|
|
||||||
)
|
|
||||||
|
|
||||||
# 绘制边
|
|
||||||
nx.draw_networkx_edges(
|
|
||||||
self.graph, pos,
|
|
||||||
width=2,
|
|
||||||
alpha=0.5
|
|
||||||
)
|
|
||||||
|
|
||||||
# 绘制标签
|
|
||||||
node_labels = nx.get_node_attributes(self.graph, 'label')
|
|
||||||
nx.draw_networkx_labels(
|
|
||||||
self.graph, pos,
|
|
||||||
labels=node_labels,
|
|
||||||
font_size=8
|
|
||||||
)
|
|
||||||
|
|
||||||
edge_labels = nx.get_edge_attributes(self.graph, 'label')
|
|
||||||
nx.draw_networkx_edge_labels(
|
|
||||||
self.graph, pos,
|
|
||||||
edge_labels=edge_labels,
|
|
||||||
font_size=8
|
|
||||||
)
|
|
||||||
|
|
||||||
plt.title("Network Topology")
|
|
||||||
plt.axis('off')
|
|
||||||
|
|
||||||
# 转换为base64
|
|
||||||
buf = BytesIO()
|
|
||||||
plt.savefig(buf, format='png', bbox_inches='tight')
|
|
||||||
plt.close()
|
|
||||||
buf.seek(0)
|
|
||||||
return base64.b64encode(buf.read()).decode('utf-8')
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_traffic_heatmap(minutes: int = 10) -> Dict:
|
|
||||||
"""获取流量热力图数据"""
|
|
||||||
with SessionLocal() as session:
|
|
||||||
records = session.query(SwitchTrafficRecord).filter(
|
|
||||||
SwitchTrafficRecord.timestamp >= datetime.now() - timedelta(minutes=minutes)
|
|
||||||
).all()
|
|
||||||
|
|
||||||
heatmap_data = {}
|
|
||||||
for record in records:
|
|
||||||
if record.switch_ip not in heatmap_data:
|
|
||||||
heatmap_data[record.switch_ip] = {}
|
|
||||||
heatmap_data[record.switch_ip][record.interface] = {
|
|
||||||
'in': record.rate_in,
|
|
||||||
'out': record.rate_out
|
|
||||||
}
|
|
||||||
|
|
||||||
return heatmap_data
|
|
@ -1,120 +0,0 @@
|
|||||||
import pandas as pd
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
from io import BytesIO
|
|
||||||
import base64
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import Dict, List
|
|
||||||
from ..models.traffic_models import TrafficRecord, SwitchTrafficRecord
|
|
||||||
from src.backend.app.api.database import SessionLocal
|
|
||||||
|
|
||||||
|
|
||||||
class ReportGenerator:
|
|
||||||
@staticmethod
|
|
||||||
def generate_traffic_report(ip: str = None, days: int = 7) -> Dict:
|
|
||||||
"""生成流量分析报告"""
|
|
||||||
with SessionLocal() as session:
|
|
||||||
time_threshold = datetime.now() - timedelta(days=days)
|
|
||||||
|
|
||||||
if ip:
|
|
||||||
# 交换机流量报告
|
|
||||||
query = session.query(SwitchTrafficRecord).filter(
|
|
||||||
SwitchTrafficRecord.switch_ip == ip,
|
|
||||||
SwitchTrafficRecord.timestamp >= time_threshold
|
|
||||||
)
|
|
||||||
df = pd.read_sql(query.statement, session.bind)
|
|
||||||
|
|
||||||
if df.empty:
|
|
||||||
return {"error": "No data found"}
|
|
||||||
|
|
||||||
# 按接口分组分析
|
|
||||||
interface_stats = df.groupby('interface').agg({
|
|
||||||
'rate_in': ['mean', 'max', 'min'],
|
|
||||||
'rate_out': ['mean', 'max', 'min']
|
|
||||||
}).reset_index()
|
|
||||||
|
|
||||||
# 生成趋势图
|
|
||||||
trend_fig = ReportGenerator._plot_traffic_trend(df, f"Switch {ip} Traffic Trend")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"switch_ip": ip,
|
|
||||||
"period": f"Last {days} days",
|
|
||||||
"interface_stats": interface_stats.to_dict('records'),
|
|
||||||
"trend_chart": trend_fig
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
# 本地网络流量报告
|
|
||||||
query = session.query(TrafficRecord).filter(
|
|
||||||
TrafficRecord.timestamp >= time_threshold
|
|
||||||
)
|
|
||||||
df = pd.read_sql(query.statement, session.bind)
|
|
||||||
|
|
||||||
if df.empty:
|
|
||||||
return {"error": "No data found"}
|
|
||||||
|
|
||||||
# 按接口分组分析
|
|
||||||
interface_stats = df.groupby('interface').agg({
|
|
||||||
'bytes_sent': ['sum', 'mean', 'max'],
|
|
||||||
'bytes_recv': ['sum', 'mean', 'max']
|
|
||||||
}).reset_index()
|
|
||||||
|
|
||||||
# 生成趋势图
|
|
||||||
trend_fig = ReportGenerator._plot_traffic_trend(df, "Local Network Traffic Trend")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"report_type": "local_network",
|
|
||||||
"period": f"Last {days} days",
|
|
||||||
"interface_stats": interface_stats.to_dict('records'),
|
|
||||||
"trend_chart": trend_fig
|
|
||||||
}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _plot_traffic_trend(df: pd.DataFrame, title: str) -> str:
|
|
||||||
"""生成流量趋势图"""
|
|
||||||
plt.figure(figsize=(12, 6))
|
|
||||||
|
|
||||||
if 'rate_in' in df.columns: # 交换机数据
|
|
||||||
df_grouped = df.groupby(pd.Grouper(key='timestamp', freq='1H')).agg({
|
|
||||||
'rate_in': 'mean',
|
|
||||||
'rate_out': 'mean'
|
|
||||||
})
|
|
||||||
plt.plot(df_grouped.index, df_grouped['rate_in'], label='Inbound Traffic')
|
|
||||||
plt.plot(df_grouped.index, df_grouped['rate_out'], label='Outbound Traffic')
|
|
||||||
plt.ylabel("Traffic Rate (bytes/sec)")
|
|
||||||
else: # 本地网络数据
|
|
||||||
df_grouped = df.groupby(pd.Grouper(key='timestamp', freq='1H')).agg({
|
|
||||||
'bytes_sent': 'sum',
|
|
||||||
'bytes_recv': 'sum'
|
|
||||||
})
|
|
||||||
plt.plot(df_grouped.index, df_grouped['bytes_sent'], label='Bytes Sent')
|
|
||||||
plt.plot(df_grouped.index, df_grouped['bytes_recv'], label='Bytes Received')
|
|
||||||
plt.ylabel("Bytes")
|
|
||||||
|
|
||||||
plt.title(title)
|
|
||||||
plt.xlabel("Time")
|
|
||||||
plt.legend()
|
|
||||||
plt.grid(True)
|
|
||||||
|
|
||||||
# 转换为base64
|
|
||||||
buf = BytesIO()
|
|
||||||
plt.savefig(buf, format='png', bbox_inches='tight')
|
|
||||||
plt.close()
|
|
||||||
buf.seek(0)
|
|
||||||
return base64.b64encode(buf.read()).decode('utf-8')
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def generate_config_history_report(days: int = 30) -> Dict:
|
|
||||||
"""生成配置变更历史报告"""
|
|
||||||
# 需要实现配置历史记录功能
|
|
||||||
pass
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def generate_security_report() -> Dict:
|
|
||||||
"""生成安全评估报告"""
|
|
||||||
# 需要实现安全扫描功能
|
|
||||||
pass
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def generate_performance_report() -> Dict:
|
|
||||||
"""生成性能评估报告"""
|
|
||||||
# 需要实现性能基准测试功能
|
|
||||||
pass
|
|
Loading…
x
Reference in New Issue
Block a user