# Snapshot metadata copied from the repository web viewer:
# 2025-07-30 00:31:08 +08:00
#
# 533 lines
# 18 KiB
# Python

import asyncio
import base64
import html
import io
import ipaddress
import socket
from datetime import datetime, timedelta
from typing import List, Optional

import matplotlib.pyplot as plt
import psutil
from fastapi import (APIRouter, HTTPException, Response, WebSocket, WebSocketDisconnect)
from fastapi.responses import HTMLResponse
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from ..services.switch_traffic_monitor import get_switch_monitor
from ..utils import logger
from ...app.services.ai_services import AIService
from ...app.api.network_config import SwitchConfigurator
from ...config import settings
from ..services.network_scanner import NetworkScanner
from ...app.services.traffic_monitor import traffic_monitor
from ...app.models.traffic_models import TrafficRecord, SwitchTrafficRecord
from src.backend.app.api.database import SessionLocal
from ..services.network_visualizer import NetworkVisualizer
from ..services.config_validator import ConfigValidator
from ..services.report_generator import ReportGenerator
# Router holding every HTTP and WebSocket endpoint in this module (empty prefix).
router = APIRouter(tags=["API"], prefix="")
# Shared scanner instance used by /scan_network and /list_devices.
scanner = NetworkScanner()
@router.get("/", include_in_schema=False)
async def root():
    """Landing payload: welcome message, documentation URLs and main endpoints."""
    return {
        "message": "欢迎使用AI交换机配置系统",
        "docs": f"{settings.API_PREFIX}/docs",
        "redoc": f"{settings.API_PREFIX}/redoc",
        "endpoints": [
            "/parse_command",
            "/apply_config",
            "/scan_network",
            "/list_devices",
            # The comma here was missing, so Python's implicit string concatenation
            # merged this entry with the next into "/batch_apply_config/traffic/...".
            "/batch_apply_config",
            "/traffic/switch/current",
            "/traffic/switch/history",
        ],
    }
@router.get("/favicon.ico", include_in_schema=False)
async def favicon():
    """Answer favicon requests with an empty 204 No Content response."""
    empty_reply = Response(status_code=204)
    return empty_reply
class BatchConfigRequest(BaseModel):
    """Request body for ``/batch_apply_config``: one config pushed to many switches.

    ``username``, ``password`` and ``timeout`` are optional. ``None`` is passed
    straight through to ``SwitchConfigurator``.
    """

    config: dict
    switch_ips: List[str]
    # Optional[...] so an explicit JSON ``null`` is accepted: pydantic v2 rejects
    # null for a bare ``str``/``int`` annotation even when the default is None.
    username: Optional[str] = None
    password: Optional[str] = None
    timeout: Optional[int] = None
@router.post("/batch_apply_config")
async def batch_apply_config(request: BatchConfigRequest):
    """Apply one configuration to several switches, one after another.

    Failures are isolated per switch: the exception text becomes that IP's
    entry in ``results`` instead of aborting the remaining switches.
    """
    outcome = {}
    for target_ip in request.switch_ips:
        try:
            outcome[target_ip] = await SwitchConfigurator(
                username=request.username,
                password=request.password,
                timeout=request.timeout,
            ).apply_config(target_ip, request.config)
        except Exception as err:
            outcome[target_ip] = str(err)
    return {"results": outcome}
@router.get("/test")
async def test_endpoint():
    """Trivial endpoint that returns a fixed greeting (handy as a smoke test)."""
    greeting = {"message": "Hello World"}
    return greeting
@router.get("/scan_network", summary="扫描网络中的交换机")
async def scan_network(subnet: str = "192.168.1.0/24"):
    """Scan ``subnet`` for switches and report what was found.

    NOTE(review): ``scanner.scan_subnet`` is called without ``await``, so the
    whole scan runs synchronously on the event loop.

    Raises:
        HTTPException: 500 wrapping any scanner failure.
    """
    try:
        found = scanner.scan_subnet(subnet)
        payload = {"success": True, "devices": found, "count": len(found)}
    except Exception as e:
        raise HTTPException(500, f"扫描失败: {str(e)}")
    return payload
@router.get("/list_devices", summary="列出已发现的交换机")
async def list_devices():
    """Return the scanner's cached device list (``scanner.load_cached_devices()``)."""
    cached = scanner.load_cached_devices()
    return {"devices": cached}
class CommandRequest(BaseModel):
    """Request body for ``/parse_command``: the natural-language command to parse."""

    command: str
class ConfigRequest(BaseModel):
    """Request body for ``/apply_config``: a config for a single switch.

    ``username``, ``password`` and ``timeout`` are optional. ``None`` is passed
    straight through to ``SwitchConfigurator``.
    """

    config: dict
    switch_ip: str
    # Optional[...] so an explicit JSON ``null`` is accepted: pydantic v2 rejects
    # null for a bare ``str``/``int`` annotation even when the default is None.
    username: Optional[str] = None
    password: Optional[str] = None
    timeout: Optional[int] = None
@router.post("/parse_command", response_model=dict)
async def parse_command(request: CommandRequest):
    """Turn a Chinese command into a JSON configuration using the AI service.

    Raises:
        HTTPException: 400 if creating the service or parsing fails.
    """
    try:
        parser = AIService(settings.SILICONFLOW_API_KEY, settings.SILICONFLOW_API_URL)
        parsed = await parser.parse_command(request.command)
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Failed to parse command: {str(e)}",
        )
    return {"success": True, "config": parsed}
@router.post("/apply_config", response_model=dict)
async def apply_config(request: ConfigRequest):
    """Push ``request.config`` to one switch via ``SwitchConfigurator.safe_apply``.

    Raises:
        HTTPException: 500 wrapping any failure.
    """
    try:
        outcome = await SwitchConfigurator(
            username=request.username,
            password=request.password,
            timeout=request.timeout,
        ).safe_apply(request.switch_ip, request.config)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to apply config: {str(e)}",
        )
    return {"success": True, "result": outcome}
class CLICommandRequest(BaseModel):
    """Request body for ``/execute_cli_commands``.

    ``commands`` may contain credential pseudo-commands of the form
    ``!username=<user>`` and ``!password=<pass>``. ``extract_credentials``
    reads them, and ``get_clean_commands`` returns the list without them.
    """

    switch_ip: str
    commands: List[str]
    is_ensp: bool = False

    def extract_credentials(self) -> tuple:
        """Return ``(username, password)`` taken from the credential pseudo-commands.

        Either value is ``None`` when absent. If a prefix appears more than once,
        the last occurrence wins. Everything after the first ``=`` is kept, so a
        password such as ``a=b`` survives intact (``split("=")[1]`` used to cut it
        to ``a``).
        """
        username = None
        password = None
        for cmd in self.commands:
            if cmd.startswith("!username="):
                username = cmd.partition("=")[2]
            elif cmd.startswith("!password="):
                password = cmd.partition("=")[2]
        return username, password

    def get_clean_commands(self) -> List[str]:
        """Return ``commands`` without the credential pseudo-commands, order preserved."""
        return [
            cmd for cmd in self.commands
            if not cmd.startswith(("!username=", "!password="))
        ]
@router.post("/execute_cli_commands", response_model=dict)
async def execute_cli_commands(request: CLICommandRequest):
    """Run CLI commands generated by the frontend on one switch.

    Credentials may be embedded in ``commands``::

        !username=admin
        !password=cisco123

    They are used to build the ``SwitchConfigurator`` and are removed from the
    list actually sent to the device.

    Raises:
        HTTPException: 500 wrapping any failure.
    """
    try:
        username, password = request.extract_credentials()
        clean_commands = request.get_clean_commands()
        configurator = SwitchConfigurator(
            username=username,
            password=password,
            timeout=settings.SWITCH_TIMEOUT,
            ensp_mode=request.is_ensp,
        )
        # Send only the real CLI commands. request.commands still contains the
        # credential pseudo-commands, including the plaintext password.
        result = await configurator.execute_raw_commands(
            ip=request.switch_ip,
            commands=clean_commands,
        )
        return {
            "success": True,
            "output": result,
            "mode": "eNSP" if request.is_ensp else "SSH",
        }
    except Exception as e:
        raise HTTPException(500, detail=str(e))
@router.get("/traffic/interfaces", summary="获取所有网络接口")
async def get_network_interfaces():
    """List the interfaces known to the local traffic monitor."""
    names = traffic_monitor.get_interfaces()
    return {"interfaces": names}
@router.get("/traffic/current", summary="获取当前流量数据")
async def get_current_traffic(interface: str = None):
    """Return the local monitor's current traffic snapshot, optionally for one interface."""
    snapshot = traffic_monitor.get_current_traffic(interface)
    return snapshot
@router.get("/traffic/history", summary="获取流量历史数据")
async def get_traffic_history(interface: str = None, limit: int = 100):
    """Return the newest ``limit`` samples of local traffic history.

    The result holds parallel lists ``sent``, ``recv`` and ``time``, with
    timestamps as ISO-8601 strings.

    A non-positive ``limit`` yields empty lists. Slicing with ``[-0:]`` would
    otherwise return the entire history, and a negative value would drop the
    oldest items instead of limiting.
    """
    history = traffic_monitor.get_traffic_history(interface)
    if limit <= 0:
        return {"sent": [], "recv": [], "time": []}
    return {
        "sent": history["sent"][-limit:],
        "recv": history["recv"][-limit:],
        # Slice first so only the returned timestamps are formatted. list() keeps
        # this working for any iterable, as the original comprehension did.
        "time": [t.isoformat() for t in list(history["time"])[-limit:]],
    }
@router.get("/traffic/records", summary="获取流量记录")
async def get_traffic_records(interface: str = None, limit: int = 100):
    """Return up to ``limit`` stored TrafficRecord rows, newest first, as dicts.

    When ``interface`` is given, only that interface's rows are returned.
    """
    with SessionLocal() as session:
        base = session.query(TrafficRecord)
        scoped = base.filter(TrafficRecord.interface == interface) if interface else base
        newest = scoped.order_by(TrafficRecord.timestamp.desc()).limit(limit).all()
        return [row.to_dict() for row in newest]
@router.websocket("/ws/traffic")
async def websocket_traffic(websocket: WebSocket):
    """Stream the local traffic snapshot to the client roughly once per second.

    Runs until the client disconnects.
    """
    await websocket.accept()
    try:
        while True:
            traffic_data = traffic_monitor.get_current_traffic()
            await websocket.send_json(traffic_data)
            await asyncio.sleep(1)
    except WebSocketDisconnect:
        # Log through the module logger, as the switch-traffic socket does,
        # instead of printing to stdout.
        logger.info("客户端断开连接")
# A second GET "/" handler named ``root`` used to be defined here. It duplicated
# the landing route registered near the top of this module. Starlette dispatches
# to the first matching route, so this copy was never served; it only re-bound the
# module-level name ``root``.
@router.get("/traffic/switch/interfaces", summary="获取交换机的网络接口")
async def get_switch_interfaces(switch_ip: str):
    """List the interface names (keys of ``interface_oids``) monitored on a switch.

    Raises:
        HTTPException: 500 wrapping any failure, which is also logged.
    """
    try:
        oid_table = get_switch_monitor(switch_ip).interface_oids
        names = [*oid_table.keys()]
    except Exception as e:
        logger.error(f"获取交换机接口失败: {str(e)}")
        raise HTTPException(500, f"获取接口失败: {str(e)}")
    return {"switch_ip": switch_ip, "interfaces": names}
@router.get("/traffic/switch/current", summary="获取交换机的当前流量数据")
async def get_switch_current_traffic(switch_ip: str, interface: str = None):
    """Return the latest stored traffic sample for one interface, or for all of them.

    Without ``interface``, every key of the switch monitor's ``interface_oids`` is
    queried and the results are grouped under ``traffic``.

    Raises:
        HTTPException: 500 on failure. HTTPExceptions raised by
            get_interface_current_traffic are propagated unchanged.
    """
    try:
        monitor = get_switch_monitor(switch_ip)
        if not interface:
            traffic_data = {}
            for iface in monitor.interface_oids:
                traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface)
            return {
                "switch_ip": switch_ip,
                "traffic": traffic_data,
            }
        return await get_interface_current_traffic(switch_ip, interface)
    except HTTPException:
        # The helper has already logged and formatted this error. Re-wrapping it
        # would produce a nested "500: ..." detail and a duplicate log line.
        raise
    except Exception as e:
        logger.error(f"获取交换机流量失败: {str(e)}")
        raise HTTPException(500, f"获取流量失败: {str(e)}")
async def get_interface_current_traffic(switch_ip: str, interface: str) -> dict:
    """Return the newest stored SwitchTrafficRecord for one switch interface.

    If no record exists, zeroed rates and byte counters are returned.

    Raises:
        HTTPException: 500 wrapping any failure, which is also logged.
    """
    try:
        with SessionLocal() as session:
            latest = (
                session.query(SwitchTrafficRecord)
                .filter(
                    SwitchTrafficRecord.switch_ip == switch_ip,
                    SwitchTrafficRecord.interface == interface,
                )
                .order_by(SwitchTrafficRecord.timestamp.desc())
                .first()
            )
            payload = {"switch_ip": switch_ip, "interface": interface}
            if not latest:
                payload.update(rate_in=0.0, rate_out=0.0, bytes_in=0, bytes_out=0)
            else:
                payload.update(
                    rate_in=latest.rate_in,
                    rate_out=latest.rate_out,
                    bytes_in=latest.bytes_in,
                    bytes_out=latest.bytes_out,
                )
            return payload
    except Exception as e:
        logger.error(f"获取接口流量失败: {str(e)}")
        raise HTTPException(500, f"获取接口流量失败: {str(e)}")
@router.get("/traffic/switch/history", summary="获取交换机的流量历史数据")
async def get_switch_traffic_history(switch_ip: str, interface: str = None, minutes: int = 10):
    """Return traffic history for a switch.

    Without ``interface``, the switch monitor's own ``get_traffic_history()`` is
    returned. With ``interface``, records from the last ``minutes`` minutes are
    read from the database, oldest first, as parallel lists: ``in``/``out``
    (rate_in/rate_out) and ``time`` (ISO-8601 strings).

    NOTE(review): the cutoff uses naive local ``datetime.now()``. This must match
    how record timestamps are stored; confirm.

    Raises:
        HTTPException: 500 wrapping any failure, which is also logged.
    """
    try:
        monitor = get_switch_monitor(switch_ip)
        if interface:
            with SessionLocal() as session:
                since = datetime.now() - timedelta(minutes=minutes)
                rows = (
                    session.query(SwitchTrafficRecord)
                    .filter(
                        SwitchTrafficRecord.switch_ip == switch_ip,
                        SwitchTrafficRecord.interface == interface,
                        SwitchTrafficRecord.timestamp >= since,
                    )
                    .order_by(SwitchTrafficRecord.timestamp.asc())
                    .all()
                )
                series = {
                    "in": [row.rate_in for row in rows],
                    "out": [row.rate_out for row in rows],
                    "time": [row.timestamp.isoformat() for row in rows],
                }
                return {"switch_ip": switch_ip, "interface": interface, "history": series}
        return {"switch_ip": switch_ip, "history": monitor.get_traffic_history()}
    except Exception as e:
        logger.error(f"获取历史流量失败: {str(e)}")
        raise HTTPException(500, f"获取历史流量失败: {str(e)}")
@router.websocket("/ws/traffic/switch")
async def websocket_switch_traffic(websocket: WebSocket, switch_ip: str, interface: str = None):
    """Stream a switch's latest stored traffic roughly once per second.

    With ``interface``, one interface's sample is sent. Otherwise every key of
    the monitor's ``interface_oids`` is sent, grouped under ``traffic``. On an
    unexpected error the socket is closed with code 1011 (internal error).
    """
    await websocket.accept()
    try:
        monitor = get_switch_monitor(switch_ip)
        while True:
            if interface:
                payload = await get_interface_current_traffic(switch_ip, interface)
            else:
                traffic_data = {}
                for iface in monitor.interface_oids:
                    traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface)
                payload = {
                    "switch_ip": switch_ip,
                    "traffic": traffic_data,
                }
            await websocket.send_json(payload)
            await asyncio.sleep(1)
    except WebSocketDisconnect:
        logger.info(f"客户端断开交换机流量连接: {switch_ip}")
    except Exception as e:
        logger.error(f"交换机流量WebSocket错误: {str(e)}")
        # RFC 6455 caps the close reason at 123 bytes of UTF-8. The (often
        # Chinese) exception text easily exceeds that and would itself make
        # close() fail. Truncate on a character boundary.
        reason = str(e).encode("utf-8")[:123].decode("utf-8", "ignore")
        try:
            await websocket.close(code=1011, reason=reason)
        except Exception as close_err:
            # The connection may already be gone (e.g. the failure came from
            # send_json), in which case close() raises too.
            logger.error(f"关闭WebSocket失败: {str(close_err)}")
@router.get("/traffic/switch/plot", response_class=HTMLResponse, summary="交换机流量可视化")
async def plot_switch_traffic(switch_ip: str, interface: str, minutes: int = 10):
    """Render an HTML page embedding a PNG chart of one switch interface's traffic.

    Args:
        switch_ip: Switch to plot. Echoed (escaped) into the page.
        interface: Interface to plot. Echoed (escaped) into the page.
        minutes: Look-back window forwarded to get_switch_traffic_history.

    Returns:
        The HTML page as a string on success. On failure, an ``HTMLResponse``
        with status 500 containing the escaped error text.
    """
    try:
        history = await get_switch_traffic_history(switch_ip, interface, minutes)
        history_data = history["history"]
        time_points = [datetime.fromisoformat(t) for t in history_data["time"]]
        in_rates = history_data["in"]
        out_rates = history_data["out"]

        fig = plt.figure(figsize=(12, 6))
        try:
            plt.plot(time_points, in_rates, label="流入流量 (B/s)")
            plt.plot(time_points, out_rates, label="流出流量 (B/s)")
            plt.title(f"交换机 {switch_ip} 接口 {interface} 流量监控 - 最近 {minutes} 分钟")
            plt.xlabel("时间")
            plt.ylabel("流量 (字节/秒)")
            plt.legend()
            plt.grid(True)
            plt.xticks(rotation=45)
            plt.tight_layout()
            buf = io.BytesIO()
            plt.savefig(buf, format="png")
        finally:
            # Close even if plotting fails. pyplot keeps every figure it creates
            # alive until it is explicitly closed.
            plt.close(fig)
        image_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")

        # switch_ip and interface are raw query parameters. Escape them before
        # embedding them in HTML to prevent reflected XSS.
        safe_ip = html.escape(switch_ip)
        safe_interface = html.escape(interface)
        return f"""
        <html>
        <head>
            <title>交换机流量监控</title>
            <style>
                body {{ font-family: Arial, sans-serif; margin: 20px; }}
                .container {{ max-width: 900px; margin: 0 auto; }}
            </style>
        </head>
        <body>
            <div class="container">
                <h1>交换机 {safe_ip} 接口 {safe_interface} 流量监控</h1>
                <img src="data:image/png;base64,{image_base64}" alt="流量图表">
                <p>更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
            </div>
        </body>
        </html>
        """
    except Exception as e:
        logger.error(f"生成流量图表失败: {str(e)}")
        # The error text can echo request input, so escape it as well.
        return HTMLResponse(content=f"<h1>错误</h1><p>{html.escape(str(e))}</p>", status_code=500)
@router.get("/network_adapters", summary="获取网络适配器网段")
async def get_network_adapters():
    """List the IPv4 networks of the host's network adapters.

    Each entry gives the adapter name, its network in CIDR form, the address
    and the subnet mask. On failure, a dict with an ``error`` message is
    returned (status 200) instead of raising.
    """
    try:
        networks = []
        for interface, addrs in psutil.net_if_addrs().items():
            for addr in addrs:
                if addr.family != socket.AF_INET:
                    continue
                if not addr.netmask:
                    # psutil may report netmask=None. Such an entry cannot form
                    # a network and would otherwise make IPv4Network raise,
                    # failing the whole listing.
                    continue
                ip = addr.address
                netmask = addr.netmask
                network = ipaddress.IPv4Network(f"{ip}/{netmask}", strict=False)
                networks.append({
                    "adapter": interface,
                    "network": str(network),
                    "ip": ip,
                    "subnet_mask": netmask,
                })
        return {"networks": networks}
    except Exception as e:
        return {"error": f"获取网络适配器信息失败: {str(e)}"}
# Module-level instances shared by the topology, heatmap and report endpoints
# in this module. Both are constructed once, at import time.
visualizer = NetworkVisualizer()
report_gen = ReportGenerator()
@router.get("/topology/visualize", response_class=HTMLResponse)
async def visualize_topology():
    """Render the network topology image inside a minimal HTML page.

    Reuses the ``list_devices`` endpoint to get the device list, feeds it to the
    shared visualizer, and embeds the returned image data as a base64 PNG.

    Raises:
        HTTPException: 500 wrapping any failure.
    """
    try:
        listing = await list_devices()
        visualizer.update_topology(listing["devices"])
        png_b64 = visualizer.generate_topology_image()
        page = (
            "\n"
            "<html>\n"
            "<head><title>Network Topology</title></head>\n"
            "<body>\n"
            "<h1>Network Topology</h1>\n"
            f'<img src="data:image/png;base64,{png_b64}" alt="Network Topology">\n'
            "</body>\n"
            "</html>\n"
        )
        return page
    except Exception as e:
        raise HTTPException(500, detail=str(e))
@router.post("/config/validate")
async def validate_config(config: dict):
    """Check a configuration with ConfigValidator.

    Reports whether the config is valid, the validation errors, and whether
    ``config["commands"]`` (default empty) triggers any security-risk findings.
    """
    is_valid, errors = ConfigValidator.validate_full_config(config)
    risks = ConfigValidator.check_security_risks(config.get("commands", []))
    return {"valid": is_valid, "errors": errors, "has_security_risks": len(risks) > 0}
@router.get("/reports/traffic/{ip}")
async def get_traffic_report(ip: str, days: int = 1):
    """Return the report generator's traffic report for ``ip`` over ``days`` days.

    Raises:
        HTTPException: 500 wrapping any failure, including JSON serialisation
            of the report.
    """
    try:
        payload = report_gen.generate_traffic_report(ip, days)
        response = JSONResponse(content=payload)
    except Exception as e:
        raise HTTPException(500, detail=str(e))
    return response
@router.get("/reports/traffic")
async def get_local_traffic_report(days: int = 1):
    """Return the report generator's local-network traffic report over ``days`` days.

    Raises:
        HTTPException: 500 wrapping any failure, including JSON serialisation
            of the report.
    """
    try:
        payload = report_gen.generate_traffic_report(days=days)
        response = JSONResponse(content=payload)
    except Exception as e:
        raise HTTPException(500, detail=str(e))
    return response
@router.get("/topology/traffic_heatmap")
async def get_traffic_heatmap(minutes: int = 10):
    """Return the visualizer's traffic heatmap data for the last ``minutes`` minutes.

    Raises:
        HTTPException: 500 wrapping any failure.
    """
    try:
        grid = visualizer.get_traffic_heatmap(minutes)
    except Exception as e:
        raise HTTPException(500, detail=str(e))
    return {"heatmap": grid}