# Mirror of https://github.com/Jerryplusy/AI-powered-switches.git
# (synced 2025-10-14 09:49:19 +00:00; 220 lines, 7.8 KiB, Python)
import networkx as nx
|
|
import numpy as np
|
|
from scipy.optimize import linprog
|
|
from typing import Dict, List, Optional, Tuple, Any
|
|
from dataclasses import dataclass
|
|
|
|
|
|
@dataclass()
class FlowDemand:
    """A requested traffic flow between two nodes of the topology.

    Fields are positional in this order: ``source``, ``destination``,
    ``bandwidth`` and the optional ``priority`` (default 1.0).
    """

    # Node id where the flow starts (a device IP or an "<ip>_<interface>" id).
    source: str
    # Node id where the flow ends.
    destination: str
    # Requested bandwidth, in the same units as the graph's link 'bandwidth'.
    bandwidth: float
    # Relative importance; NetworkOptimizer.optimize_qos places higher values first.
    priority: float = 1.0
|
|
|
|
|
|
class NetworkOptimizer:
    """Graph-based network optimization model.

    Builds an undirected ``networkx.Graph`` from device dicts and offers two
    allocation strategies:

    * ``optimize_bandwidth_allocation``: a linear program (SciPy / HiGHS) that
      routes each demand on its cheapest path and maximizes the
      priority-weighted allocated bandwidth subject to link capacities.
    * ``optimize_qos``: greedy admission in descending priority order that
      books each flow on a path with enough remaining capacity.

    The graph's link ``'bandwidth'`` attributes are never modified; consumed
    capacity is tracked separately in ``self.remaining_capacity``.
    """

    def __init__(self, devices: List[Dict[str, Any]]):
        """Build the topology from ``devices`` and reset link capacities.

        Args:
            devices: device dicts; see ``build_topology_graph`` for the keys read.
        """
        self.graph = self.build_topology_graph(devices)
        # Dense demand matrix in ``list(self.graph.nodes())`` order; filled by
        # the LP allocator.
        self.traffic_matrix: Optional[np.ndarray] = None
        self._initialize_capacities()

    def _initialize_capacities(self) -> Dict[Tuple[str, str], float]:
        """Reset ``self.remaining_capacity`` to every link's full bandwidth.

        Keys are ``(u, v)`` tuples in the orientation yielded by
        ``self.graph.edges()``; use ``_edge_key`` to look a link up from either
        direction.

        Returns:
            The freshly built ``remaining_capacity`` dict.
        """
        self.remaining_capacity = {
            (u, v): data['bandwidth']
            for u, v, data in self.graph.edges(data=True)
        }
        return self.remaining_capacity

    def _edge_key(self, u: str, v: str) -> Tuple[str, str]:
        """Return the ``remaining_capacity`` key for the undirected link u-v.

        The graph is undirected, so a path can traverse a link opposite to the
        orientation under which it is stored; a blind ``(u, v)`` lookup would
        raise ``KeyError`` in that case.
        """
        return (u, v) if (u, v) in self.remaining_capacity else (v, u)

    def build_topology_graph(self, devices: List[Dict[str, Any]]) -> nx.Graph:
        """Build the weighted, undirected topology graph.

        Each device becomes a node keyed by its ``'ip'`` with attributes
        ``type`` (from ``'type'``) and ``capacity`` (``'capacity'``, default
        1000). Each entry of the device's optional ``'interfaces'`` list becomes
        a node ``"<ip>_<name>"`` linked to the device, with the interface's
        ``'bandwidth'`` (default 1000) and ``'latency'`` (default 1) stored on
        that link. Devices are then linked by ``_connect_devices``.
        """
        G = nx.Graph()

        for device in devices:
            G.add_node(device['ip'],
                       type=device['type'],
                       capacity=device.get('capacity', 1000))

            # Interfaces are attached as child nodes of their device.
            for interface in device.get('interfaces', []):
                interface_id = f"{device['ip']}_{interface['name']}"
                G.add_node(interface_id,
                           type='interface',
                           capacity=interface.get('bandwidth', 1000))
                G.add_edge(device['ip'], interface_id,
                           bandwidth=interface.get('bandwidth', 1000),
                           latency=interface.get('latency', 1))

        # Inter-device links.
        self._connect_devices(G, devices)
        return G

    @staticmethod
    def _connect_devices(graph: nx.Graph, devices: List[Dict[str, Any]]):
        """Link consecutive devices in list order (bandwidth 1000, latency 5).

        The result is a linear chain of devices.
        NOTE(review): assumes list order reflects real adjacency; confirm.
        """
        for current, following in zip(devices, devices[1:]):
            graph.add_edge(
                current['ip'],
                following['ip'],
                bandwidth=1000,
                latency=5
            )

    @staticmethod
    def _link_cost(u: str, v: str, data: Dict[str, Any]) -> float:
        """Routing weight of a link: its latency plus a small inverse-bandwidth term.

        ``max(1, ...)`` keeps the division safe on zero-bandwidth links.
        """
        return 1 / max(1, data['bandwidth']) + data['latency']

    def optimize_bandwidth_allocation(
            self,
            demands: List[FlowDemand]
    ) -> Dict[str, Dict[str, float]]:
        """Allocate bandwidth to demands with a linear program.

        Each demand is routed on its cheapest path (``_link_cost``). The LP then
        chooses how much of each demand to carry so that the sum of
        ``priority * allocated`` is maximal, no demand receives more than it
        requested, and no link carries more than its ``'bandwidth'``.
        Demands whose endpoints are unknown or disconnected get nothing. If
        several demands share a (source, destination) pair, the last one wins.

        Args:
            demands: requested flows.

        Returns:
            ``{source: {destination: allocated_bandwidth}}``. Every graph node
            is a key; inner dicts only list allocations above 0.001.

        Raises:
            ValueError: if the solver does not report success.
        """
        demand_dict = {(d.source, d.destination): float(d.bandwidth) for d in demands}
        priorities = {(d.source, d.destination): float(d.priority) for d in demands}
        return self._optimize_with_highs(demand_dict, priorities)

    def _optimize_with_highs(
            self,
            demands: Dict[Tuple[str, str], float],
            priorities: Optional[Dict[Tuple[str, str], float]] = None,
    ) -> Dict[str, Dict[str, float]]:
        """Solve the path-based allocation LP with SciPy's HiGHS backend.

        The LP has one decision variable per routable demand (its granted
        bandwidth) and one capacity row per link crossed by some route, so it
        grows with the number of demands rather than with the node count squared.

        Args:
            demands: ``{(src, dst): requested_bandwidth}``.
            priorities: optional ``{(src, dst): weight}``; missing pairs weigh 1.0.

        Returns:
            ``{source: {destination: allocated_bandwidth}}`` with every graph
            node as a key.

        Raises:
            ValueError: if HiGHS does not report success.
        """
        nodes = list(self.graph.nodes())
        node_index = {node: i for i, node in enumerate(nodes)}

        # The LP does not read this matrix; it is kept on the instance for
        # callers that inspect it.
        self.traffic_matrix = np.zeros((len(nodes), len(nodes)))
        for (src, dst), bw in demands.items():
            if src in node_index and dst in node_index:
                self.traffic_matrix[node_index[src], node_index[dst]] = bw

        results: Dict[str, Dict[str, float]] = {node: {} for node in nodes}
        routed = self._route_demands(demands)
        if not routed:
            return results

        weights = priorities or {}
        # linprog minimizes, so negate the weights to maximize weighted throughput.
        c = np.array([-weights.get(pair, 1.0) for pair, _, _ in routed], dtype=np.float64)
        # 0 <= x_k <= requested bandwidth (negative requests are clamped to 0).
        bounds = [(0.0, max(0.0, bw)) for _, bw, _ in routed]
        A_ub, b_ub = self._build_capacity_constraints(routed)

        res = linprog(
            c=c,
            A_ub=A_ub,
            b_ub=b_ub,
            bounds=bounds,
            method='highs'
        )

        if not res.success:
            raise ValueError(f"Optimization failed: {res.message}")

        return self._format_results(res.x, routed, results)

    def _route_demands(
            self, demands: Dict[Tuple[str, str], float]
    ) -> List[Tuple[Tuple[str, str], float, List[str]]]:
        """Pair each demand with its cheapest path; unroutable demands are dropped.

        Returns:
            ``[((src, dst), requested_bandwidth, path), ...]``.
        """
        routed = []
        for pair, bw in demands.items():
            src, dst = pair
            try:
                path = nx.shortest_path(self.graph, src, dst, weight=self._link_cost)
            except (nx.NetworkXNoPath, nx.NodeNotFound):
                continue
            routed.append((pair, bw, path))
        return routed

    def _build_capacity_constraints(
            self, routed: List[Tuple[Tuple[str, str], float, List[str]]]
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Build ``A_ub @ x <= b_ub``: for each used link, the sum of the demands
        routed across it must not exceed the link's full ``'bandwidth'``.

        Returns:
            ``(A_ub, b_ub)``, or ``(None, None)`` when no route crosses any link
            (e.g. only source == destination demands).
        """
        link_row: Dict[frozenset, int] = {}
        b_ub: List[float] = []
        entries: List[Tuple[int, int]] = []

        for col, (_, _, path) in enumerate(routed):
            for u, v in zip(path[:-1], path[1:]):
                link = frozenset((u, v))  # undirected link identity
                if link not in link_row:
                    link_row[link] = len(b_ub)
                    capacity = float(self.graph[u][v].get('bandwidth', 1000))
                    b_ub.append(max(0.0, capacity))
                entries.append((link_row[link], col))

        if not b_ub:
            return None, None

        A_ub = np.zeros((len(b_ub), len(routed)))
        for row, col in entries:
            A_ub[row, col] = 1.0
        return A_ub, np.array(b_ub, dtype=np.float64)

    @staticmethod
    def _format_results(
            solution: np.ndarray,
            routed: List[Tuple[Tuple[str, str], float, List[str]]],
            results: Dict[str, Dict[str, float]],
    ) -> Dict[str, Dict[str, float]]:
        """Write allocations above 0.001 into ``results`` as ``{src: {dst: bw}}``."""
        for ((src, dst), _, _), allocated in zip(routed, solution):
            if allocated > 0.001:
                results.setdefault(src, {})[dst] = float(allocated)
        return results

    def optimize_qos(self, flows: List[FlowDemand]) -> Dict[str, Dict[str, Any]]:
        """Greedy QoS admission in descending priority order.

        Remaining link capacity is reset first. Each flow then takes the
        cheapest path (see ``find_optimal_path``) whose links all still have at
        least ``flow.bandwidth`` left, and that bandwidth is booked on every
        link of the path. Flows with no such path are omitted from the result.
        The graph is not modified, so repeated calls start from the same full
        capacities.

        Returns:
            {
                "src-dst": {
                    "path": list of nodes,
                    "allocated": allocated bandwidth,
                    "priority": priority,
                }
            }
        """
        sorted_flows = sorted(flows, key=lambda f: f.priority, reverse=True)
        results: Dict[str, Dict[str, Any]] = {}

        self._initialize_capacities()  # start every call from full capacity

        for flow in sorted_flows:
            path = self.find_optimal_path(flow.source, flow.destination, flow.bandwidth)
            if not path:
                continue

            links = [self._edge_key(u, v) for u, v in zip(path[:-1], path[1:])]
            # A single-node path (source == destination) crosses no links.
            bottleneck = min((self.remaining_capacity[key] for key in links),
                             default=flow.bandwidth)
            allocated = min(bottleneck, flow.bandwidth)

            for key in links:
                self.remaining_capacity[key] -= allocated

            results[f"{flow.source}-{flow.destination}"] = {
                "path": path,
                "allocated": float(allocated),
                "priority": float(flow.priority)
            }

        return results

    def find_optimal_path(self, source: str, target: str,
                          bandwidth: float = 1.0) -> Optional[List[str]]:
        """Return the cheapest simple path with enough remaining capacity.

        Simple paths are enumerated in increasing ``_link_cost`` order, and the
        first one whose every link has at least ``bandwidth`` remaining in
        ``self.remaining_capacity`` is returned.
        NOTE: if no path qualifies, every simple path is enumerated, which can
        be expensive on meshed graphs.

        Returns:
            The node list from ``source`` to ``target``, or ``None`` when a node
            is unknown, the nodes are disconnected, or no path has enough
            capacity.
        """
        try:
            paths = nx.shortest_simple_paths(
                self.graph, source, target,
                weight=self._link_cost
            )
            return next(
                (path for path in paths
                 if self._path_has_sufficient_bandwidth(path, bandwidth)),
                None
            )
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            return None

    def _path_has_sufficient_bandwidth(self, path: List[str], bw: float) -> bool:
        """True if every link on ``path`` has at least ``bw`` remaining capacity."""
        return all(
            self.remaining_capacity[self._edge_key(u, v)] >= bw
            for u, v in zip(path[:-1], path[1:])
        )