mirror of
https://github.com/Jerryplusy/AI-powered-switches.git
synced 2025-10-14 09:49:19 +00:00
Compare commits
No commits in common. "29f016bdab96a77548718dffe58c4e5884334a89" and "0b6b9624a64907eb160232e9df72f7aaf9872f14" have entirely different histories.
29f016bdab
...
0b6b9624a6
@ -1,13 +1,30 @@
|
||||
import socket
|
||||
from fastapi import (APIRouter, HTTPException, Response)
|
||||
from datetime import datetime, timedelta
|
||||
from fastapi import (APIRouter, HTTPException, Response, WebSocket, WebSocketDisconnect)
|
||||
from typing import List
|
||||
from pydantic import BaseModel
|
||||
import asyncio
|
||||
from fastapi.responses import HTMLResponse
|
||||
import matplotlib.pyplot as plt
|
||||
import io
|
||||
import base64
|
||||
import psutil
|
||||
import ipaddress
|
||||
|
||||
from ..services.switch_traffic_monitor import get_switch_monitor
|
||||
from ..utils import logger
|
||||
from ...app.services.ai_services import AIService
|
||||
from ...app.api.network_config import SwitchConfigurator
|
||||
from ...config import settings
|
||||
from ..services.network_scanner import NetworkScanner
|
||||
from ...app.services.traffic_monitor import traffic_monitor
|
||||
from ...app.models.traffic_models import TrafficRecord, SwitchTrafficRecord
|
||||
from src.backend.app.api.database import SessionLocal
|
||||
from ..services.network_visualizer import NetworkVisualizer
|
||||
from ..services.config_validator import ConfigValidator
|
||||
from ..services.report_generator import ReportGenerator
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
|
||||
|
||||
|
||||
@ -42,6 +59,22 @@ class BatchConfigRequest(BaseModel):
|
||||
password: str = None
|
||||
timeout: int = None
|
||||
|
||||
@router.post("/batch_apply_config")
|
||||
async def batch_apply_config(request: BatchConfigRequest):
|
||||
results = {}
|
||||
for ip in request.switch_ips:
|
||||
try:
|
||||
configurator = SwitchConfigurator(
|
||||
username=request.username,
|
||||
password=request.password,
|
||||
timeout=request.timeout )
|
||||
results[ip] = await configurator.apply_config(ip, request.config)
|
||||
except Exception as e:
|
||||
results[ip] = str(e)
|
||||
return {"results": results}
|
||||
|
||||
|
||||
|
||||
@router.get("/test")
|
||||
async def test_endpoint():
|
||||
return {"message": "Hello World"}
|
||||
@ -163,6 +196,47 @@ async def execute_cli_commands(request: CLICommandRequest):
|
||||
except Exception as e:
|
||||
raise HTTPException(500, detail=str(e))
|
||||
|
||||
@router.get("/traffic/interfaces", summary="获取所有网络接口")
|
||||
async def get_network_interfaces():
|
||||
return {
|
||||
"interfaces": traffic_monitor.get_interfaces()
|
||||
}
|
||||
|
||||
@router.get("/traffic/current", summary="获取当前流量数据")
|
||||
async def get_current_traffic(interface: str = None):
|
||||
return traffic_monitor.get_current_traffic(interface)
|
||||
|
||||
@router.get("/traffic/history", summary="获取流量历史数据")
|
||||
async def get_traffic_history(interface: str = None, limit: int = 100):
|
||||
history = traffic_monitor.get_traffic_history(interface)
|
||||
return {
|
||||
"sent": history["sent"][-limit:],
|
||||
"recv": history["recv"][-limit:],
|
||||
"time": [t.isoformat() for t in history["time"]][-limit:]
|
||||
}
|
||||
|
||||
@router.get("/traffic/records", summary="获取流量记录")
|
||||
async def get_traffic_records(interface: str = None, limit: int = 100):
|
||||
with SessionLocal() as session:
|
||||
query = session.query(TrafficRecord)
|
||||
if interface:
|
||||
query = query.filter(TrafficRecord.interface == interface)
|
||||
records = query.order_by(TrafficRecord.timestamp.desc()).limit(limit).all()
|
||||
return [record.to_dict() for record in records]
|
||||
|
||||
@router.websocket("/ws/traffic")
|
||||
async def websocket_traffic(websocket: WebSocket):
|
||||
"""实时流量WebSocket"""
|
||||
await websocket.accept()
|
||||
try:
|
||||
while True:
|
||||
traffic_data = traffic_monitor.get_current_traffic()
|
||||
await websocket.send_json(traffic_data)
|
||||
await asyncio.sleep(1)
|
||||
except WebSocketDisconnect:
|
||||
print("客户端断开连接")
|
||||
|
||||
|
||||
@router.get("/", include_in_schema=False)
|
||||
async def root():
|
||||
return {
|
||||
@ -179,6 +253,195 @@ async def root():
|
||||
"/traffic/switch/history"
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@router.get("/traffic/switch/interfaces", summary="获取交换机的网络接口")
|
||||
async def get_switch_interfaces(switch_ip: str):
|
||||
"""获取指定交换机的所有接口"""
|
||||
try:
|
||||
monitor = get_switch_monitor(switch_ip)
|
||||
interfaces = list(monitor.interface_oids.keys())
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"interfaces": interfaces
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取交换机接口失败: {str(e)}")
|
||||
raise HTTPException(500, f"获取接口失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/traffic/switch/current", summary="获取交换机的当前流量数据")
|
||||
async def get_switch_current_traffic(switch_ip: str, interface: str = None):
|
||||
"""获取交换机的当前流量数据"""
|
||||
try:
|
||||
monitor = get_switch_monitor(switch_ip)
|
||||
|
||||
|
||||
if not interface:
|
||||
traffic_data = {}
|
||||
for iface in monitor.interface_oids:
|
||||
traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface)
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"traffic": traffic_data
|
||||
}
|
||||
|
||||
return await get_interface_current_traffic(switch_ip, interface)
|
||||
except Exception as e:
|
||||
logger.error(f"获取交换机流量失败: {str(e)}")
|
||||
raise HTTPException(500, f"获取流量失败: {str(e)}")
|
||||
|
||||
|
||||
async def get_interface_current_traffic(switch_ip: str, interface: str) -> dict:
|
||||
"""获取指定交换机接口的当前流量数据"""
|
||||
try:
|
||||
with SessionLocal() as session:
|
||||
|
||||
record = session.query(SwitchTrafficRecord).filter(
|
||||
SwitchTrafficRecord.switch_ip == switch_ip,
|
||||
SwitchTrafficRecord.interface == interface
|
||||
).order_by(SwitchTrafficRecord.timestamp.desc()).first()
|
||||
|
||||
if not record:
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"interface": interface,
|
||||
"rate_in": 0.0,
|
||||
"rate_out": 0.0,
|
||||
"bytes_in": 0,
|
||||
"bytes_out": 0
|
||||
}
|
||||
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"interface": interface,
|
||||
"rate_in": record.rate_in,
|
||||
"rate_out": record.rate_out,
|
||||
"bytes_in": record.bytes_in,
|
||||
"bytes_out": record.bytes_out
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取接口流量失败: {str(e)}")
|
||||
raise HTTPException(500, f"获取接口流量失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/traffic/switch/history", summary="获取交换机的流量历史数据")
|
||||
async def get_switch_traffic_history(switch_ip: str, interface: str = None, minutes: int = 10):
|
||||
"""获取交换机的流量历史数据"""
|
||||
try:
|
||||
monitor = get_switch_monitor(switch_ip)
|
||||
|
||||
if not interface:
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"history": monitor.get_traffic_history()
|
||||
}
|
||||
|
||||
with SessionLocal() as session:
|
||||
time_threshold = datetime.now() - timedelta(minutes=minutes)
|
||||
|
||||
records = session.query(SwitchTrafficRecord).filter(
|
||||
SwitchTrafficRecord.switch_ip == switch_ip,
|
||||
SwitchTrafficRecord.interface == interface,
|
||||
SwitchTrafficRecord.timestamp >= time_threshold
|
||||
).order_by(SwitchTrafficRecord.timestamp.asc()).all()
|
||||
|
||||
history_data = {
|
||||
"in": [record.rate_in for record in records],
|
||||
"out": [record.rate_out for record in records],
|
||||
"time": [record.timestamp.isoformat() for record in records]
|
||||
}
|
||||
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"interface": interface,
|
||||
"history": history_data
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取历史流量失败: {str(e)}")
|
||||
raise HTTPException(500, f"获取历史流量失败: {str(e)}")
|
||||
|
||||
|
||||
@router.websocket("/ws/traffic/switch")
|
||||
async def websocket_switch_traffic(websocket: WebSocket, switch_ip: str, interface: str = None):
|
||||
"""交换机实时流量WebSocket"""
|
||||
await websocket.accept()
|
||||
try:
|
||||
monitor = get_switch_monitor(switch_ip)
|
||||
|
||||
while True:
|
||||
if interface:
|
||||
traffic_data = await get_interface_current_traffic(switch_ip, interface)
|
||||
await websocket.send_json(traffic_data)
|
||||
else:
|
||||
traffic_data = {}
|
||||
for iface in monitor.interface_oids:
|
||||
traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface)
|
||||
|
||||
await websocket.send_json({
|
||||
"switch_ip": switch_ip,
|
||||
"traffic": traffic_data
|
||||
})
|
||||
|
||||
await asyncio.sleep(1)
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"客户端断开交换机流量连接: {switch_ip}")
|
||||
except Exception as e:
|
||||
logger.error(f"交换机流量WebSocket错误: {str(e)}")
|
||||
await websocket.close(code=1011, reason=str(e))
|
||||
|
||||
@router.get("/traffic/switch/plot", response_class=HTMLResponse, summary="交换机流量可视化")
|
||||
async def plot_switch_traffic(switch_ip: str, interface: str, minutes: int = 10):
|
||||
"""生成交换机流量图表"""
|
||||
try:
|
||||
history = await get_switch_traffic_history(switch_ip, interface, minutes)
|
||||
history_data = history["history"]
|
||||
|
||||
time_points = [datetime.fromisoformat(t) for t in history_data["time"]]
|
||||
in_rates = history_data["in"]
|
||||
out_rates = history_data["out"]
|
||||
|
||||
plt.figure(figsize=(12, 6))
|
||||
plt.plot(time_points, in_rates, label="流入流量 (B/s)")
|
||||
plt.plot(time_points, out_rates, label="流出流量 (B/s)")
|
||||
plt.title(f"交换机 {switch_ip} 接口 {interface} 流量监控 - 最近 {minutes} 分钟")
|
||||
plt.xlabel("时间")
|
||||
plt.ylabel("流量 (字节/秒)")
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
plt.xticks(rotation=45)
|
||||
plt.tight_layout()
|
||||
buf = io.BytesIO()
|
||||
plt.savefig(buf, format="png")
|
||||
buf.seek(0)
|
||||
image_base64 = base64.b64encode(buf.read()).decode("utf-8")
|
||||
plt.close()
|
||||
|
||||
return f"""
|
||||
<html>
|
||||
<head>
|
||||
<title>交换机流量监控</title>
|
||||
<style>
|
||||
body {{ font-family: Arial, sans-serif; margin: 20px; }}
|
||||
.container {{ max-width: 900px; margin: 0 auto; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>交换机 {switch_ip} 接口 {interface} 流量监控</h1>
|
||||
<img src="data:image/png;base64,{image_base64}" alt="流量图表">
|
||||
<p>更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
except Exception as e:
|
||||
logger.error(f"生成流量图表失败: {str(e)}")
|
||||
return HTMLResponse(content=f"<h1>错误</h1><p>{str(e)}</p>", status_code=500)
|
||||
|
||||
|
||||
@router.get("/network_adapters", summary="获取网络适配器网段")
|
||||
async def get_network_adapters():
|
||||
try:
|
||||
@ -202,4 +465,69 @@ async def get_network_adapters():
|
||||
return {"networks": networks}
|
||||
|
||||
except Exception as e:
|
||||
return {"error": f"获取网络适配器信息失败: {str(e)}"}
|
||||
return {"error": f"获取网络适配器信息失败: {str(e)}"}
|
||||
|
||||
|
||||
visualizer = NetworkVisualizer()
|
||||
report_gen = ReportGenerator()
|
||||
|
||||
@router.get("/topology/visualize", response_class=HTMLResponse)
|
||||
async def visualize_topology():
|
||||
"""获取网络拓扑可视化图"""
|
||||
try:
|
||||
devices = await list_devices() # 复用现有的设备列表接口
|
||||
visualizer.update_topology(devices["devices"])
|
||||
image_data = visualizer.generate_topology_image()
|
||||
|
||||
return f"""
|
||||
<html>
|
||||
<head><title>Network Topology</title></head>
|
||||
<body>
|
||||
<h1>Network Topology</h1>
|
||||
<img src="data:image/png;base64,{image_data}" alt="Network Topology">
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
except Exception as e:
|
||||
raise HTTPException(500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/config/validate")
|
||||
async def validate_config(config: dict):
|
||||
"""验证配置有效性"""
|
||||
is_valid, errors = ConfigValidator.validate_full_config(config)
|
||||
return {
|
||||
"valid": is_valid,
|
||||
"errors": errors,
|
||||
"has_security_risks": len(ConfigValidator.check_security_risks(config.get("commands", []))) > 0
|
||||
}
|
||||
|
||||
|
||||
@router.get("/reports/traffic/{ip}")
|
||||
async def get_traffic_report(ip: str, days: int = 1):
|
||||
"""获取流量分析报告"""
|
||||
try:
|
||||
report = report_gen.generate_traffic_report(ip, days)
|
||||
return JSONResponse(content=report)
|
||||
except Exception as e:
|
||||
raise HTTPException(500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/reports/traffic")
|
||||
async def get_local_traffic_report(days: int = 1):
|
||||
"""获取本地网络流量报告"""
|
||||
try:
|
||||
report = report_gen.generate_traffic_report(days=days)
|
||||
return JSONResponse(content=report)
|
||||
except Exception as e:
|
||||
raise HTTPException(500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/topology/traffic_heatmap")
|
||||
async def get_traffic_heatmap(minutes: int = 10):
|
||||
"""获取流量热力图数据"""
|
||||
try:
|
||||
heatmap = visualizer.get_traffic_heatmap(minutes)
|
||||
return {"heatmap": heatmap}
|
||||
except Exception as e:
|
||||
raise HTTPException(500, detail=str(e))
|
@ -42,7 +42,7 @@ class AIService:
|
||||
]
|
||||
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
response = await self.client.chat.completions.create(
|
||||
model="deepseek-ai/DeepSeek-V3",
|
||||
messages=messages,
|
||||
temperature=0.3,
|
||||
|
85
src/backend/app/services/config_validator.py
Normal file
85
src/backend/app/services/config_validator.py
Normal file
@ -0,0 +1,85 @@
|
||||
import re
|
||||
from typing import Dict, List, Tuple
|
||||
from ..utils.exceptions import SwitchConfigException
|
||||
|
||||
|
||||
class ConfigValidator:
|
||||
@staticmethod
|
||||
def validate_vlan_config(config: Dict) -> Tuple[bool, str]:
|
||||
"""验证VLAN配置"""
|
||||
if 'vlan_id' not in config:
|
||||
return False, "Missing VLAN ID"
|
||||
|
||||
vlan_id = config['vlan_id']
|
||||
if not (1 <= vlan_id <= 4094):
|
||||
return False, f"Invalid VLAN ID {vlan_id}. Must be 1-4094"
|
||||
|
||||
if 'name' in config and len(config['name']) > 32:
|
||||
return False, "VLAN name too long (max 32 chars)"
|
||||
|
||||
return True, "Valid VLAN config"
|
||||
|
||||
@staticmethod
|
||||
def validate_interface_config(config: Dict) -> Tuple[bool, str]:
|
||||
"""验证接口配置"""
|
||||
required_fields = ['interface', 'ip_address']
|
||||
for field in required_fields:
|
||||
if field not in config:
|
||||
return False, f"Missing required field: {field}"
|
||||
|
||||
# 验证IP地址格式
|
||||
ip_pattern = r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}/\d{1,2}$'
|
||||
if not re.match(ip_pattern, config['ip_address']):
|
||||
return False, "Invalid IP address format"
|
||||
|
||||
# 验证接口名称格式
|
||||
interface_pattern = r'^(GigabitEthernet|FastEthernet|Eth)\d+/\d+/\d+$'
|
||||
if not re.match(interface_pattern, config['interface']):
|
||||
return False, "Invalid interface name format"
|
||||
|
||||
return True, "Valid interface config"
|
||||
|
||||
@staticmethod
|
||||
def check_security_risks(commands: List[str]) -> List[str]:
|
||||
"""检查潜在安全风险"""
|
||||
risky_commands = []
|
||||
dangerous_patterns = [
|
||||
r'no\s+aaa', # 禁用认证
|
||||
r'enable\s+password', # 明文密码
|
||||
r'service\s+password-encryption', # 弱加密
|
||||
r'ip\s+http\s+server', # 启用HTTP服务
|
||||
r'no\s+ip\s+http\s+secure-server' # 禁用HTTPS
|
||||
]
|
||||
|
||||
for cmd in commands:
|
||||
for pattern in dangerous_patterns:
|
||||
if re.search(pattern, cmd, re.IGNORECASE):
|
||||
risky_commands.append(cmd)
|
||||
break
|
||||
|
||||
return risky_commands
|
||||
|
||||
@staticmethod
|
||||
def validate_full_config(config: Dict) -> Tuple[bool, List[str]]:
|
||||
"""全面验证配置"""
|
||||
errors = []
|
||||
|
||||
if 'type' not in config:
|
||||
errors.append("Missing configuration type")
|
||||
return False, errors
|
||||
|
||||
if config['type'] == 'vlan':
|
||||
valid, msg = ConfigValidator.validate_vlan_config(config)
|
||||
if not valid:
|
||||
errors.append(msg)
|
||||
elif config['type'] == 'interface':
|
||||
valid, msg = ConfigValidator.validate_interface_config(config)
|
||||
if not valid:
|
||||
errors.append(msg)
|
||||
|
||||
if 'commands' in config:
|
||||
risks = ConfigValidator.check_security_risks(config['commands'])
|
||||
if risks:
|
||||
errors.append(f"Potential security risks detected: {', '.join(risks)}")
|
||||
|
||||
return len(errors) == 0, errors
|
129
src/backend/app/services/failover_manager.py
Normal file
129
src/backend/app/services/failover_manager.py
Normal file
@ -0,0 +1,129 @@
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from ..models.traffic_models import SwitchTrafficRecord
|
||||
from src.backend.app.api.database import SessionLocal
|
||||
from ..utils.exceptions import SwitchConfigException
|
||||
from ..services.network_scanner import NetworkScanner
|
||||
|
||||
|
||||
class FailoverManager:
|
||||
def __init__(self, config_backup_dir: str = "config_backups"):
|
||||
self.scanner = NetworkScanner()
|
||||
self.backup_dir = Path(config_backup_dir)
|
||||
self.backup_dir.mkdir(exist_ok=True)
|
||||
|
||||
async def detect_failures(self, threshold: float = 0.9) -> List[Dict]:
|
||||
"""检测可能的网络故障"""
|
||||
with SessionLocal() as session:
|
||||
# 获取最近5分钟的所有交换机流量记录
|
||||
records = session.query(SwitchTrafficRecord).filter(
|
||||
SwitchTrafficRecord.timestamp >= datetime.now() - timedelta(minutes=5)
|
||||
).all()
|
||||
|
||||
# 分析流量异常
|
||||
abnormal_devices = []
|
||||
device_stats = {}
|
||||
|
||||
for record in records:
|
||||
if record.switch_ip not in device_stats:
|
||||
device_stats[record.switch_ip] = {
|
||||
'interfaces': {},
|
||||
'max_in': 0,
|
||||
'max_out': 0
|
||||
}
|
||||
|
||||
if record.interface not in device_stats[record.switch_ip]['interfaces']:
|
||||
device_stats[record.switch_ip]['interfaces'][record.interface] = {
|
||||
'in': [],
|
||||
'out': []
|
||||
}
|
||||
|
||||
device_stats[record.switch_ip]['interfaces'][record.interface]['in'].append(record.rate_in)
|
||||
device_stats[record.switch_ip]['interfaces'][record.interface]['out'].append(record.rate_out)
|
||||
|
||||
# 更新设备最大流量
|
||||
if record.rate_in > device_stats[record.switch_ip]['max_in']:
|
||||
device_stats[record.switch_ip]['max_in'] = record.rate_in
|
||||
if record.rate_out > device_stats[record.switch_ip]['max_out']:
|
||||
device_stats[record.switch_ip]['max_out'] = record.rate_out
|
||||
|
||||
# 检测异常
|
||||
for ip, stats in device_stats.items():
|
||||
for iface, traffic in stats['interfaces'].items():
|
||||
avg_in = sum(traffic['in']) / len(traffic['in'])
|
||||
avg_out = sum(traffic['out']) / len(traffic['out'])
|
||||
|
||||
# 如果平均流量超过最大流量的90%,认为可能有问题
|
||||
if avg_in > stats['max_in'] * threshold or \
|
||||
avg_out > stats['max_out'] * threshold:
|
||||
abnormal_devices.append({
|
||||
'ip': ip,
|
||||
'interface': iface,
|
||||
'avg_in': avg_in,
|
||||
'avg_out': avg_out,
|
||||
'max_in': stats['max_in'],
|
||||
'max_out': stats['max_out']
|
||||
})
|
||||
|
||||
return abnormal_devices
|
||||
|
||||
async def automatic_recovery(self, failed_device_ip: str) -> bool:
|
||||
"""自动故障恢复"""
|
||||
try:
|
||||
# 1. 检查设备是否在线
|
||||
devices = self.scanner.scan_subnet(failed_device_ip + "/32")
|
||||
if not devices:
|
||||
raise SwitchConfigException(f"Device {failed_device_ip} is offline")
|
||||
|
||||
# 2. 查找最近的配置备份
|
||||
backup_files = sorted(self.backup_dir.glob(f"{failed_device_ip}_*.cfg"))
|
||||
if not backup_files:
|
||||
raise SwitchConfigException(f"No backup found for {failed_device_ip}")
|
||||
|
||||
latest_backup = backup_files[-1]
|
||||
|
||||
# 3. 恢复配置
|
||||
with open(latest_backup) as f:
|
||||
config_commands = f.read().splitlines()
|
||||
|
||||
# 使用SSH执行恢复命令
|
||||
# (这里需要实现SSH连接和执行命令的逻辑)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
raise SwitchConfigException(f"Recovery failed: {str(e)}")
|
||||
|
||||
async def redundancy_check(self, critical_devices: List[str]) -> Dict:
|
||||
"""检查关键设备的冗余配置"""
|
||||
results = {}
|
||||
topology = self.scanner.get_current_topology()
|
||||
|
||||
for device_ip in critical_devices:
|
||||
# 检查是否有备用路径
|
||||
try:
|
||||
paths = list(nx.all_shortest_paths(
|
||||
topology, source=device_ip, target="core_switch"))
|
||||
|
||||
if len(paths) > 1:
|
||||
results[device_ip] = {
|
||||
'status': 'redundant',
|
||||
'path_count': len(paths)
|
||||
}
|
||||
else:
|
||||
results[device_ip] = {
|
||||
'status': 'single_point_of_failure',
|
||||
'recommendation': 'Add redundant links'
|
||||
}
|
||||
except:
|
||||
results[device_ip] = {
|
||||
'status': 'disconnected',
|
||||
'recommendation': 'Check physical connection'
|
||||
}
|
||||
|
||||
return results
|
220
src/backend/app/services/network_optimizer.py
Normal file
220
src/backend/app/services/network_optimizer.py
Normal file
@ -0,0 +1,220 @@
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
from scipy.optimize import linprog
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlowDemand:
|
||||
source: str
|
||||
destination: str
|
||||
bandwidth: float
|
||||
priority: float = 1.0
|
||||
|
||||
|
||||
class NetworkOptimizer:
|
||||
def __init__(self, devices: List[Dict[str, Any]]):
|
||||
"""基于图论的网络优化模型"""
|
||||
self.graph = self.build_topology_graph(devices)
|
||||
self.traffic_matrix: Optional[np.ndarray] = None
|
||||
self._initialize_capacities()
|
||||
|
||||
def _initialize_capacities(self) -> Dict[Tuple[str, str], float]:
|
||||
"""初始化链路容量字典"""
|
||||
self.remaining_capacity = {
|
||||
(u, v): data['bandwidth']
|
||||
for u, v, data in self.graph.edges(data=True)
|
||||
}
|
||||
return self.remaining_capacity
|
||||
|
||||
def build_topology_graph(self, devices: List[Dict[str, Any]]) -> nx.Graph:
|
||||
"""构建带权重的网络拓扑图"""
|
||||
G = nx.Graph()
|
||||
|
||||
for device in devices:
|
||||
G.add_node(device['ip'],
|
||||
type=device['type'],
|
||||
capacity=device.get('capacity', 1000))
|
||||
|
||||
# 添加接口作为子节点
|
||||
for interface in device.get('interfaces', []):
|
||||
interface_id = f"{device['ip']}_{interface['name']}"
|
||||
G.add_node(interface_id,
|
||||
type='interface',
|
||||
capacity=interface.get('bandwidth', 1000))
|
||||
G.add_edge(device['ip'], interface_id,
|
||||
bandwidth=interface.get('bandwidth', 1000),
|
||||
latency=interface.get('latency', 1))
|
||||
|
||||
# 添加设备间连接
|
||||
self._connect_devices(G, devices)
|
||||
return G
|
||||
|
||||
@staticmethod
|
||||
def _connect_devices(graph: nx.Graph, devices: List[Dict[str, Any]]):
|
||||
"""自动连接设备"""
|
||||
for i in range(len(devices) - 1):
|
||||
graph.add_edge(
|
||||
devices[i]['ip'],
|
||||
devices[i + 1]['ip'],
|
||||
bandwidth=1000,
|
||||
latency=5
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_flow_conservation(nodes: List[str]) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""构建流量守恒约束矩阵"""
|
||||
num_nodes = len(nodes)
|
||||
A_eq: List[np.ndarray] = [] # 明确指定列表元素类型
|
||||
b_eq: List[float] = [] # 明确指定float列表
|
||||
|
||||
for i in range(num_nodes):
|
||||
for j in range(num_nodes):
|
||||
if i != j:
|
||||
constraint = np.zeros((num_nodes, num_nodes))
|
||||
constraint[i][j] = 1
|
||||
constraint[j][i] = -1
|
||||
A_eq.append(constraint.flatten())
|
||||
b_eq.append(0.0) # 明确使用float类型
|
||||
|
||||
return np.array(A_eq, dtype=np.float64), np.array(b_eq, dtype=np.float64)
|
||||
def optimize_bandwidth_allocation(
|
||||
self,
|
||||
demands: List[FlowDemand]
|
||||
) -> Dict[str, Dict[str, float]]:
|
||||
"""
|
||||
基于线性规划的带宽分配优化
|
||||
返回: {源IP: {目标IP: 分配带宽}}
|
||||
"""
|
||||
demand_dict = {(d.source, d.destination): float(d.bandwidth) for d in demands} # 确保float类型
|
||||
return self._optimize_with_highs(demand_dict)
|
||||
|
||||
def _optimize_with_highs(self, demands: Dict[Tuple[str, str], float]) -> Dict[str, Dict[str, float]]:
|
||||
"""使用HiGHS求解器实现"""
|
||||
nodes = list(self.graph.nodes())
|
||||
node_index = {node: i for i, node in enumerate(nodes)}
|
||||
|
||||
# 构建流量矩阵
|
||||
self.traffic_matrix = np.zeros((len(nodes), len(nodes)))
|
||||
for (src, dst), bw in demands.items():
|
||||
if src in node_index and dst in node_index:
|
||||
self.traffic_matrix[node_index[src]][node_index[dst]] = bw
|
||||
|
||||
# 构建约束
|
||||
c = np.ones(len(nodes) ** 2) # 最小化总流量
|
||||
A_ub, b_ub = self._build_capacity_constraints(nodes, node_index)
|
||||
A_eq, b_eq = self._build_flow_conservation(nodes)
|
||||
|
||||
# 求解线性规划
|
||||
res = linprog(
|
||||
c=c,
|
||||
A_ub=A_ub,
|
||||
b_ub=b_ub,
|
||||
A_eq=A_eq,
|
||||
b_eq=b_eq,
|
||||
bounds=(0, None),
|
||||
method='highs'
|
||||
)
|
||||
|
||||
if not res.success:
|
||||
raise ValueError(f"Optimization failed: {res.message}")
|
||||
|
||||
return self._format_results(res.x, nodes, node_index)
|
||||
|
||||
def _build_capacity_constraints(self, nodes: List[str], node_index: Dict[str, int]) -> Tuple[
|
||||
np.ndarray, np.ndarray]:
|
||||
"""构建容量约束矩阵"""
|
||||
A_ub = []
|
||||
b_ub = []
|
||||
|
||||
for u, v, data in self.graph.edges(data=True):
|
||||
capacity = float(data.get('bandwidth', 1000)) # 确保float类型
|
||||
b_ub.append(capacity)
|
||||
|
||||
constraint = np.zeros((len(nodes), len(nodes)))
|
||||
for src, dst in [(u, v), (v, u)]:
|
||||
if src in node_index and dst in node_index:
|
||||
i, j = node_index[src], node_index[dst]
|
||||
constraint[i][j] += 1
|
||||
A_ub.append(constraint.flatten())
|
||||
|
||||
return np.array(A_ub), np.array(b_ub)
|
||||
|
||||
@staticmethod
|
||||
def _format_results(solution: np.ndarray, nodes: List[str], node_index: Dict[str, int]) -> Dict[
|
||||
str, Dict[str, float]]:
|
||||
"""格式化优化结果"""
|
||||
flows = solution.reshape((len(nodes), len(nodes)))
|
||||
return {
|
||||
nodes[i]: {
|
||||
nodes[j]: float(flows[i][j]) # 明确转换为float
|
||||
for j in range(len(nodes))
|
||||
if flows[i][j] > 0.001
|
||||
}
|
||||
for i in range(len(nodes))
|
||||
}
|
||||
|
||||
def optimize_qos(self, flows: List[FlowDemand]) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
带QoS的流量优化
|
||||
返回: {
|
||||
"源-目标": {
|
||||
"path": 路径列表,
|
||||
"allocated": 分配带宽,
|
||||
"priority": 优先级
|
||||
}
|
||||
}
|
||||
"""
|
||||
sorted_flows = sorted(flows, key=lambda x: x.priority, reverse=True)
|
||||
results = {}
|
||||
|
||||
# 使用实例变量而非局部变量
|
||||
self._initialize_capacities() # 重置剩余带宽
|
||||
|
||||
for flow in sorted_flows:
|
||||
path = self.find_optimal_path(flow.source, flow.destination, flow.bandwidth)
|
||||
if not path:
|
||||
continue
|
||||
|
||||
min_bw = min(self.graph[u][v]['bandwidth'] for u, v in zip(path[:-1], path[1:]))
|
||||
allocated = min(min_bw, flow.bandwidth)
|
||||
|
||||
# 更新剩余带宽
|
||||
for u, v in zip(path[:-1], path[1:]):
|
||||
self.remaining_capacity[(u, v)] -= allocated
|
||||
if self.remaining_capacity[(u, v)] <= 0:
|
||||
self.graph[u][v]['bandwidth'] = 0
|
||||
|
||||
results[f"{flow.source}-{flow.destination}"] = {
|
||||
"path": path,
|
||||
"allocated": float(allocated), # 确保float类型
|
||||
"priority": float(flow.priority)
|
||||
}
|
||||
|
||||
return results
|
||||
|
||||
def find_optimal_path(self, source: str, target: str,
|
||||
bandwidth: float = 1.0) -> Optional[List[str]]:
|
||||
"""
|
||||
改进的最优路径查找
|
||||
"""
|
||||
try:
|
||||
paths = nx.shortest_simple_paths(
|
||||
self.graph, source, target,
|
||||
weight=lambda u, v, d: 1 / max(1, d['bandwidth']) + d['latency'] # 避免除零
|
||||
)
|
||||
return next(
|
||||
(path for path in paths
|
||||
if self._path_has_sufficient_bandwidth(path, bandwidth)),
|
||||
None
|
||||
)
|
||||
except (nx.NetworkXNoPath, nx.NodeNotFound):
|
||||
return None
|
||||
|
||||
def _path_has_sufficient_bandwidth(self, path: List[str], bw: float) -> bool:
|
||||
"""检查路径带宽是否满足要求"""
|
||||
return all(
|
||||
self.graph[u][v]['bandwidth'] >= bw
|
||||
for u, v in zip(path[:-1], path[1:])
|
||||
)
|
@ -15,7 +15,7 @@ class NetworkScanner:
|
||||
|
||||
devices = []
|
||||
try:
|
||||
await self.nm.scan(hosts=subnet, arguments=f'-p {",".join(map(str, ports))}')
|
||||
self.nm.scan(hosts=subnet, arguments=f'-p {",".join(map(str, ports))}')
|
||||
for host in self.nm.all_hosts():
|
||||
ip = host
|
||||
mac = self.nm[host]['addresses'].get('mac', 'N/A')
|
||||
|
104
src/backend/app/services/network_visualizer.py
Normal file
104
src/backend/app/services/network_visualizer.py
Normal file
@ -0,0 +1,104 @@
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import networkx as nx
|
||||
import matplotlib.pyplot as plt
|
||||
from io import BytesIO
|
||||
import base64
|
||||
from typing import Dict, List
|
||||
from ..models.traffic_models import SwitchTrafficRecord
|
||||
from src.backend.app.api.database import SessionLocal
|
||||
|
||||
|
||||
class NetworkVisualizer:
|
||||
def __init__(self):
|
||||
self.graph = nx.Graph()
|
||||
|
||||
def update_topology(self, devices: List[Dict]):
|
||||
"""更新网络拓扑图"""
|
||||
self.graph.clear()
|
||||
|
||||
# 添加节点
|
||||
for device in devices:
|
||||
self.graph.add_node(
|
||||
device['ip'],
|
||||
type='switch',
|
||||
label=f"Switch\n{device['ip']}"
|
||||
)
|
||||
|
||||
# 添加连接(简化版,实际应根据扫描结果)
|
||||
if len(devices) > 1:
|
||||
for i in range(len(devices) - 1):
|
||||
self.graph.add_edge(
|
||||
devices[i]['ip'],
|
||||
devices[i + 1]['ip'],
|
||||
bandwidth=1000,
|
||||
label="1Gbps"
|
||||
)
|
||||
|
||||
def generate_topology_image(self) -> str:
|
||||
"""生成拓扑图并返回base64编码"""
|
||||
plt.figure(figsize=(10, 8))
|
||||
pos = nx.spring_layout(self.graph)
|
||||
|
||||
# 绘制节点
|
||||
node_colors = []
|
||||
for node in self.graph.nodes():
|
||||
if self.graph.nodes[node]['type'] == 'switch':
|
||||
node_colors.append('lightblue')
|
||||
|
||||
nx.draw_networkx_nodes(
|
||||
self.graph, pos,
|
||||
node_size=2000,
|
||||
node_color=node_colors
|
||||
)
|
||||
|
||||
# 绘制边
|
||||
nx.draw_networkx_edges(
|
||||
self.graph, pos,
|
||||
width=2,
|
||||
alpha=0.5
|
||||
)
|
||||
|
||||
# 绘制标签
|
||||
node_labels = nx.get_node_attributes(self.graph, 'label')
|
||||
nx.draw_networkx_labels(
|
||||
self.graph, pos,
|
||||
labels=node_labels,
|
||||
font_size=8
|
||||
)
|
||||
|
||||
edge_labels = nx.get_edge_attributes(self.graph, 'label')
|
||||
nx.draw_networkx_edge_labels(
|
||||
self.graph, pos,
|
||||
edge_labels=edge_labels,
|
||||
font_size=8
|
||||
)
|
||||
|
||||
plt.title("Network Topology")
|
||||
plt.axis('off')
|
||||
|
||||
# 转换为base64
|
||||
buf = BytesIO()
|
||||
plt.savefig(buf, format='png', bbox_inches='tight')
|
||||
plt.close()
|
||||
buf.seek(0)
|
||||
return base64.b64encode(buf.read()).decode('utf-8')
|
||||
|
||||
@staticmethod
|
||||
def get_traffic_heatmap(minutes: int = 10) -> Dict:
|
||||
"""获取流量热力图数据"""
|
||||
with SessionLocal() as session:
|
||||
records = session.query(SwitchTrafficRecord).filter(
|
||||
SwitchTrafficRecord.timestamp >= datetime.now() - timedelta(minutes=minutes)
|
||||
).all()
|
||||
|
||||
heatmap_data = {}
|
||||
for record in records:
|
||||
if record.switch_ip not in heatmap_data:
|
||||
heatmap_data[record.switch_ip] = {}
|
||||
heatmap_data[record.switch_ip][record.interface] = {
|
||||
'in': record.rate_in,
|
||||
'out': record.rate_out
|
||||
}
|
||||
|
||||
return heatmap_data
|
120
src/backend/app/services/report_generator.py
Normal file
120
src/backend/app/services/report_generator.py
Normal file
@ -0,0 +1,120 @@
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
from io import BytesIO
|
||||
import base64
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List
|
||||
from ..models.traffic_models import TrafficRecord, SwitchTrafficRecord
|
||||
from src.backend.app.api.database import SessionLocal
|
||||
|
||||
|
||||
class ReportGenerator:
|
||||
@staticmethod
|
||||
def generate_traffic_report(ip: str = None, days: int = 7) -> Dict:
|
||||
"""生成流量分析报告"""
|
||||
with SessionLocal() as session:
|
||||
time_threshold = datetime.now() - timedelta(days=days)
|
||||
|
||||
if ip:
|
||||
# 交换机流量报告
|
||||
query = session.query(SwitchTrafficRecord).filter(
|
||||
SwitchTrafficRecord.switch_ip == ip,
|
||||
SwitchTrafficRecord.timestamp >= time_threshold
|
||||
)
|
||||
df = pd.read_sql(query.statement, session.bind)
|
||||
|
||||
if df.empty:
|
||||
return {"error": "No data found"}
|
||||
|
||||
# 按接口分组分析
|
||||
interface_stats = df.groupby('interface').agg({
|
||||
'rate_in': ['mean', 'max', 'min'],
|
||||
'rate_out': ['mean', 'max', 'min']
|
||||
}).reset_index()
|
||||
|
||||
# 生成趋势图
|
||||
trend_fig = ReportGenerator._plot_traffic_trend(df, f"Switch {ip} Traffic Trend")
|
||||
|
||||
return {
|
||||
"switch_ip": ip,
|
||||
"period": f"Last {days} days",
|
||||
"interface_stats": interface_stats.to_dict('records'),
|
||||
"trend_chart": trend_fig
|
||||
}
|
||||
else:
|
||||
# 本地网络流量报告
|
||||
query = session.query(TrafficRecord).filter(
|
||||
TrafficRecord.timestamp >= time_threshold
|
||||
)
|
||||
df = pd.read_sql(query.statement, session.bind)
|
||||
|
||||
if df.empty:
|
||||
return {"error": "No data found"}
|
||||
|
||||
# 按接口分组分析
|
||||
interface_stats = df.groupby('interface').agg({
|
||||
'bytes_sent': ['sum', 'mean', 'max'],
|
||||
'bytes_recv': ['sum', 'mean', 'max']
|
||||
}).reset_index()
|
||||
|
||||
# 生成趋势图
|
||||
trend_fig = ReportGenerator._plot_traffic_trend(df, "Local Network Traffic Trend")
|
||||
|
||||
return {
|
||||
"report_type": "local_network",
|
||||
"period": f"Last {days} days",
|
||||
"interface_stats": interface_stats.to_dict('records'),
|
||||
"trend_chart": trend_fig
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _plot_traffic_trend(df: pd.DataFrame, title: str) -> str:
|
||||
"""生成流量趋势图"""
|
||||
plt.figure(figsize=(12, 6))
|
||||
|
||||
if 'rate_in' in df.columns: # 交换机数据
|
||||
df_grouped = df.groupby(pd.Grouper(key='timestamp', freq='1H')).agg({
|
||||
'rate_in': 'mean',
|
||||
'rate_out': 'mean'
|
||||
})
|
||||
plt.plot(df_grouped.index, df_grouped['rate_in'], label='Inbound Traffic')
|
||||
plt.plot(df_grouped.index, df_grouped['rate_out'], label='Outbound Traffic')
|
||||
plt.ylabel("Traffic Rate (bytes/sec)")
|
||||
else: # 本地网络数据
|
||||
df_grouped = df.groupby(pd.Grouper(key='timestamp', freq='1H')).agg({
|
||||
'bytes_sent': 'sum',
|
||||
'bytes_recv': 'sum'
|
||||
})
|
||||
plt.plot(df_grouped.index, df_grouped['bytes_sent'], label='Bytes Sent')
|
||||
plt.plot(df_grouped.index, df_grouped['bytes_recv'], label='Bytes Received')
|
||||
plt.ylabel("Bytes")
|
||||
|
||||
plt.title(title)
|
||||
plt.xlabel("Time")
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
|
||||
# 转换为base64
|
||||
buf = BytesIO()
|
||||
plt.savefig(buf, format='png', bbox_inches='tight')
|
||||
plt.close()
|
||||
buf.seek(0)
|
||||
return base64.b64encode(buf.read()).decode('utf-8')
|
||||
|
||||
@staticmethod
|
||||
def generate_config_history_report(days: int = 30) -> Dict:
|
||||
"""生成配置变更历史报告"""
|
||||
# 需要实现配置历史记录功能
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def generate_security_report() -> Dict:
|
||||
"""生成安全评估报告"""
|
||||
# 需要实现安全扫描功能
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def generate_performance_report() -> Dict:
|
||||
"""生成性能评估报告"""
|
||||
# 需要实现性能基准测试功能
|
||||
pass
|
@ -8,7 +8,6 @@ from typing import Dict, Optional, List
|
||||
|
||||
from ..models.traffic_models import TrafficRecord
|
||||
from src.backend.app.api.database import SessionLocal
|
||||
from ..utils.logger import logger
|
||||
|
||||
|
||||
class TrafficMonitor:
|
||||
@ -34,7 +33,7 @@ class TrafficMonitor:
|
||||
if not self.running:
|
||||
self.running = True
|
||||
self.task = asyncio.create_task(self._monitor_loop())
|
||||
logger.info("流量监控已启动")
|
||||
print("流量监控已启动")
|
||||
|
||||
async def stop_monitoring(self):
|
||||
"""停止流量监控"""
|
||||
@ -45,7 +44,7 @@ class TrafficMonitor:
|
||||
await self.task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("流量监控已停止")
|
||||
print("流量监控已停止")
|
||||
|
||||
async def _monitor_loop(self):
|
||||
"""监控主循环"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user