# Mirror of https://github.com/Jerryplusy/AI-powered-switches.git
# Synced 2025-10-14 01:39:18 +00:00 · 464 lines · 16 KiB · Python
import asyncio
import base64
import html
import io
import ipaddress
import socket
from datetime import datetime, timedelta
from typing import List, Optional

import matplotlib.pyplot as plt
import psutil
from fastapi import (APIRouter, HTTPException, Response, WebSocket, WebSocketDisconnect)
from fastapi.responses import HTMLResponse
from pydantic import BaseModel

# Project imports are kept in their original order to avoid altering
# import-time side effects between modules.
from ..services.switch_traffic_monitor import get_switch_monitor
from ..utils import logger
from ...app.services.ai_services import AIService
from ...app.api.network_config import SwitchConfigurator
from ...config import settings
from ..services.network_scanner import NetworkScanner
from ...app.services.traffic_monitor import traffic_monitor
from ...app.models.traffic_models import TrafficRecord, SwitchTrafficRecord
from src.backend.app.api.database import SessionLocal
|
|
|
|
|
|
|
|
|
|
# Every endpoint in this module is registered on this one un-prefixed router.
router = APIRouter(tags=["API"], prefix="")

# Shared scanner instance used by /scan_network and /list_devices.
scanner = NetworkScanner()
|
|
|
|
@router.get("/", include_in_schema=False)
async def root():
    # Landing payload: a greeting, links to the interactive API docs, and a
    # list of the main endpoints. Hidden from the OpenAPI schema.
    return {
        "message": "欢迎使用AI交换机配置系统",
        "docs": f"{settings.API_PREFIX}/docs",
        "redoc": f"{settings.API_PREFIX}/redoc",
        # Keep a comma after every entry: adjacent string literals without one
        # are silently concatenated by Python into a single bogus path.
        "endpoints": [
            "/parse_command",
            "/apply_config",
            "/scan_network",
            "/list_devices",
            "/batch_apply_config",
            "/traffic/switch/current",
            "/traffic/switch/history",
        ],
    }
|
|
|
|
@router.get("/favicon.ico", include_in_schema=False)
async def favicon():
    # Browsers fetch /favicon.ico on their own; answer with an empty
    # 204 No Content response.
    empty_reply = Response(status_code=204)
    return empty_reply
|
|
|
|
class BatchConfigRequest(BaseModel):
    # Request body for /batch_apply_config: one config pushed to many switches.
    config: dict
    switch_ips: List[str]
    # Optional connection overrides, forwarded to SwitchConfigurator as-is.
    # Declared Optional[...] so that an explicit JSON null validates, not only
    # an omitted field (a bare `str = None` rejects null under pydantic v2).
    username: Optional[str] = None
    password: Optional[str] = None
    timeout: Optional[int] = None
|
|
|
|
@router.post("/batch_apply_config")
async def batch_apply_config(request: BatchConfigRequest):
    # Apply the same config to each listed switch in turn. A failure on one
    # switch is stored as that switch's error text, and the loop carries on
    # with the remaining switches.
    outcome_by_ip = {}
    for switch_ip in request.switch_ips:
        try:
            switch_configurator = SwitchConfigurator(
                username=request.username,
                password=request.password,
                timeout=request.timeout,
            )
            outcome_by_ip[switch_ip] = await switch_configurator.apply_config(
                switch_ip, request.config
            )
        except Exception as exc:
            outcome_by_ip[switch_ip] = str(exc)
    return {"results": outcome_by_ip}
|
|
|
|
|
|
|
|
@router.get("/test")
async def test_endpoint():
    # Trivial smoke-test endpoint with a fixed payload.
    greeting = "Hello World"
    return {"message": greeting}
|
|
|
|
@router.get("/scan_network", summary="扫描网络中的交换机")
async def scan_network(subnet: str = "192.168.1.0/24"):
    # Scan `subnet` with the shared scanner and return the devices found.
    # Any failure is reported as HTTP 500 with the underlying error text.
    try:
        # scan_subnet is synchronous (its return value is used directly), so
        # calling it inline would block the event loop for the entire scan and
        # stall every other request and WebSocket stream. Run it in the
        # default thread pool instead.
        # NOTE(review): concurrent requests can now scan in parallel threads;
        # confirm NetworkScanner tolerates that.
        loop = asyncio.get_running_loop()
        devices = await loop.run_in_executor(None, scanner.scan_subnet, subnet)
        return {
            "success": True,
            "devices": devices,
            "count": len(devices),
        }
    except Exception as e:
        raise HTTPException(500, f"扫描失败: {str(e)}") from e
|
|
|
|
@router.get("/list_devices", summary="列出已发现的交换机")
async def list_devices():
    # Serve the scanner's cached device list; this endpoint never calls
    # scan_subnet itself.
    cached = scanner.load_cached_devices()
    return {"devices": cached}
|
|
|
|
class CommandRequest(BaseModel):
    # Request body for /parse_command.
    command: str  # natural-language (Chinese) command for the AI parser
|
|
|
|
class ConfigRequest(BaseModel):
    # Request body for /apply_config: one config applied to a single switch.
    config: dict
    switch_ip: str
    # Optional connection overrides, forwarded to SwitchConfigurator as-is.
    # Declared Optional[...] so that an explicit JSON null validates, not only
    # an omitted field (a bare `str = None` rejects null under pydantic v2).
    username: Optional[str] = None
    password: Optional[str] = None
    timeout: Optional[int] = None
|
|
|
|
@router.post("/parse_command", response_model=dict)
async def parse_command(request: CommandRequest):
    """
    解析中文命令并返回JSON配置
    """
    # Parse a Chinese natural-language command into a JSON config through
    # AIService. Any failure becomes HTTP 400. The docstring above is left
    # verbatim because FastAPI publishes it as the endpoint description.
    try:
        service = AIService(settings.SILICONFLOW_API_KEY, settings.SILICONFLOW_API_URL)
        parsed_config = await service.parse_command(request.command)
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Failed to parse command: {str(e)}",
        )
    return {"success": True, "config": parsed_config}
|
|
|
|
@router.post("/apply_config", response_model=dict)
async def apply_config(request: ConfigRequest):
    """
    应用配置到交换机
    """
    # Apply `request.config` to `request.switch_ip` via
    # SwitchConfigurator.safe_apply. Any failure becomes HTTP 500. The
    # docstring above is left verbatim because FastAPI publishes it.
    try:
        switch_configurator = SwitchConfigurator(
            username=request.username,
            password=request.password,
            timeout=request.timeout,
        )
        outcome = await switch_configurator.safe_apply(request.switch_ip, request.config)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to apply config: {str(e)}",
        )
    return {"success": True, "result": outcome}
|
|
|
|
|
|
class CLICommandRequest(BaseModel):
    # Request body for /execute_cli_commands.
    #
    # `commands` may embed credentials as pseudo-commands of the form
    # "!username=<name>" and "!password=<secret>". extract_credentials() reads
    # them and get_clean_commands() returns the list without them.
    switch_ip: str
    commands: List[str]
    is_ensp: bool = False

    def extract_credentials(self) -> tuple:
        """Return ``(username, password)`` embedded in ``self.commands``.

        The value is everything after the first ``=`` of the matching line,
        so a secret that itself contains ``=`` is kept intact (a plain
        ``split("=")[1]`` would truncate it). A missing entry yields ``None``.
        If a prefix appears more than once, the last occurrence wins.
        """
        username = None
        password = None
        for cmd in self.commands:
            if cmd.startswith("!username="):
                username = cmd.split("=", 1)[1]
            elif cmd.startswith("!password="):
                password = cmd.split("=", 1)[1]
        return username, password

    def get_clean_commands(self) -> List[str]:
        """Return ``self.commands`` without any credential pseudo-commands."""
        return [
            cmd for cmd in self.commands
            if not cmd.startswith(("!username=", "!password="))
        ]
|
|
|
|
@router.post("/execute_cli_commands", response_model=dict)
async def execute_cli_commands(request: CLICommandRequest):
    """
    执行前端生成的CLI命令
    支持在commands中嵌入凭据:
    !username=admin
    !password=cisco123
    """
    # Run frontend-generated CLI commands on `request.switch_ip`, in eNSP
    # mode or over SSH depending on `request.is_ensp`. Credentials may be
    # embedded in `commands` as "!username=..." / "!password=..." lines.
    # The docstring above is left verbatim because FastAPI publishes it.
    # Any failure becomes HTTP 500 with the error text as detail.
    try:
        username, password = request.extract_credentials()
        clean_commands = request.get_clean_commands()

        configurator = SwitchConfigurator(
            username=username,
            password=password,
            timeout=settings.SWITCH_TIMEOUT,
            ensp_mode=request.is_ensp,
        )

        # Send only the filtered commands. Sending `request.commands` would
        # push the credential pseudo-commands, including the plaintext
        # password, to the device as if they were CLI input.
        result = await configurator.execute_raw_commands(
            ip=request.switch_ip,
            commands=clean_commands,
        )
        return {
            "success": True,
            "output": result,
            "mode": "eNSP" if request.is_ensp else "SSH",
        }
    except Exception as e:
        raise HTTPException(500, detail=str(e)) from e
|
|
|
|
@router.get("/traffic/interfaces", summary="获取所有网络接口")
async def get_network_interfaces():
    # Expose whatever traffic_monitor.get_interfaces() reports.
    known_interfaces = traffic_monitor.get_interfaces()
    return {"interfaces": known_interfaces}
|
|
|
|
@router.get("/traffic/current", summary="获取当前流量数据")
async def get_current_traffic(interface: str = None):
    # Pass straight through to the traffic monitor. `interface` is forwarded
    # unchanged and is None when the query parameter is omitted.
    snapshot = traffic_monitor.get_current_traffic(interface)
    return snapshot
|
|
|
|
def _tail(values, count):
    """Return the last ``count`` items of ``values`` as a new list.

    ``seq[-count:]`` misbehaves at the edges: ``seq[-0:]`` is the whole
    sequence, and a negative count drops items from the front. Here a
    non-positive ``count`` yields an empty list. ``values`` only needs to be
    iterable, so a deque works as well as a list.
    """
    if count <= 0:
        return []
    return list(values)[-count:]


@router.get("/traffic/history", summary="获取流量历史数据")
async def get_traffic_history(interface: str = None, limit: int = 100):
    # Return at most `limit` of the newest sent/recv/time samples reported by
    # traffic_monitor.get_traffic_history(interface). Timestamps are rendered
    # as ISO-8601 strings.
    history = traffic_monitor.get_traffic_history(interface)
    return {
        "sent": _tail(history["sent"], limit),
        "recv": _tail(history["recv"], limit),
        # Slice before formatting so only the returned timestamps are converted.
        "time": [t.isoformat() for t in _tail(history["time"], limit)],
    }
|
|
|
|
@router.get("/traffic/records", summary="获取流量记录")
async def get_traffic_records(interface: str = None, limit: int = 100):
    # Up to `limit` TrafficRecord rows, newest first, optionally restricted
    # to one interface. Each row is serialised with its to_dict().
    with SessionLocal() as session:
        records_query = session.query(TrafficRecord)
        if interface:
            records_query = records_query.filter(TrafficRecord.interface == interface)
        newest_first = records_query.order_by(TrafficRecord.timestamp.desc())
        rows = newest_first.limit(limit).all()
        return [row.to_dict() for row in rows]
|
|
|
|
@router.websocket("/ws/traffic")
async def websocket_traffic(websocket: WebSocket):
    """Push ``traffic_monitor.get_current_traffic()`` to the client every second.

    Runs until the client disconnects. The disconnect is recorded through the
    module logger, as the switch-traffic WebSocket handler does, rather than
    printed to stdout.
    """
    await websocket.accept()
    try:
        while True:
            traffic_data = traffic_monitor.get_current_traffic()
            await websocket.send_json(traffic_data)
            await asyncio.sleep(1)
    except WebSocketDisconnect:
        logger.info("客户端断开连接")
|
|
|
|
|
|
@router.get("/", include_in_schema=False)
async def root():
    # NOTE(review): this is a second GET "/" handler on the same router.
    # Starlette tries routes in registration order, so the earlier "/"
    # registration in this module presumably answers every request and this
    # one is unreachable. Consider deleting one of the two.
    endpoint_paths = [
        "/parse_command",
        "/apply_config",
        "/scan_network",
        "/list_devices",
        "/batch_apply_config",
        "/traffic/switch/current",
        "/traffic/switch/history",
    ]
    prefix = settings.API_PREFIX
    return {
        "message": "欢迎使用AI交换机配置系统",
        "docs": f"{prefix}/docs",
        "redoc": f"{prefix}/redoc",
        "endpoints": endpoint_paths,
    }
|
|
|
|
|
|
|
|
|
|
@router.get("/traffic/switch/interfaces", summary="获取交换机的网络接口")
async def get_switch_interfaces(switch_ip: str):
    """获取指定交换机的所有接口"""
    # List the interface names (keys of interface_oids) known to the switch
    # monitor for `switch_ip`. Failures are logged and become HTTP 500. The
    # docstring above is left verbatim because FastAPI publishes it.
    try:
        switch_monitor = get_switch_monitor(switch_ip)
        iface_names = list(switch_monitor.interface_oids.keys())
        return {"switch_ip": switch_ip, "interfaces": iface_names}
    except Exception as exc:
        logger.error(f"获取交换机接口失败: {str(exc)}")
        raise HTTPException(500, f"获取接口失败: {str(exc)}")
|
|
|
|
|
|
@router.get("/traffic/switch/current", summary="获取交换机的当前流量数据")
async def get_switch_current_traffic(switch_ip: str, interface: str = None):
    """获取交换机的当前流量数据"""
    # Latest stored traffic sample for one interface of `switch_ip`. Without
    # `interface`, returns {"switch_ip": ..., "traffic": {iface: sample}} for
    # every interface in the switch monitor's interface_oids. The docstring
    # above is left verbatim because FastAPI publishes it.
    try:
        monitor = get_switch_monitor(switch_ip)

        if not interface:
            traffic_data = {}
            for iface in monitor.interface_oids:
                traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface)
            return {
                "switch_ip": switch_ip,
                "traffic": traffic_data,
            }

        return await get_interface_current_traffic(switch_ip, interface)
    except HTTPException:
        # get_interface_current_traffic already logged its failure and built
        # a specific 500. Re-raise it unchanged; otherwise the generic handler
        # below would wrap it into "获取流量失败: 500: ...".
        raise
    except Exception as e:
        logger.error(f"获取交换机流量失败: {str(e)}")
        raise HTTPException(500, f"获取流量失败: {str(e)}") from e
|
|
|
|
|
|
async def get_interface_current_traffic(switch_ip: str, interface: str) -> dict:
    """Return the newest stored traffic sample for one switch interface.

    Looks up the most recent ``SwitchTrafficRecord`` (by ``timestamp``) for
    ``switch_ip`` and ``interface``. When no record exists, the rates and
    byte counters are reported as zero.

    Raises:
        HTTPException: status 500 if the lookup fails for any reason; the
            error is logged first.
    """
    try:
        with SessionLocal() as session:
            latest = (
                session.query(SwitchTrafficRecord)
                .filter(
                    SwitchTrafficRecord.switch_ip == switch_ip,
                    SwitchTrafficRecord.interface == interface,
                )
                .order_by(SwitchTrafficRecord.timestamp.desc())
                .first()
            )

            if not latest:
                rate_in, rate_out, bytes_in, bytes_out = 0.0, 0.0, 0, 0
            else:
                rate_in, rate_out = latest.rate_in, latest.rate_out
                bytes_in, bytes_out = latest.bytes_in, latest.bytes_out

            return {
                "switch_ip": switch_ip,
                "interface": interface,
                "rate_in": rate_in,
                "rate_out": rate_out,
                "bytes_in": bytes_in,
                "bytes_out": bytes_out,
            }
    except Exception as exc:
        logger.error(f"获取接口流量失败: {str(exc)}")
        raise HTTPException(500, f"获取接口流量失败: {str(exc)}")
|
|
|
|
|
|
@router.get("/traffic/switch/history", summary="获取交换机的流量历史数据")
async def get_switch_traffic_history(switch_ip: str, interface: str = None, minutes: int = 10):
    """获取交换机的流量历史数据"""
    # Traffic history for `switch_ip`. Without `interface`, defer to the switch
    # monitor's own get_traffic_history(). With one, return that interface's
    # stored in/out rates from the last `minutes` minutes, oldest first, with
    # ISO-8601 timestamps. Failures are logged and become HTTP 500. The
    # docstring above is left verbatim because FastAPI publishes it.
    # NOTE(review): the cutoff uses naive local datetime.now(); confirm stored
    # timestamps use the same clock/timezone.
    try:
        switch_monitor = get_switch_monitor(switch_ip)

        if not interface:
            return {
                "switch_ip": switch_ip,
                "history": switch_monitor.get_traffic_history(),
            }

        with SessionLocal() as session:
            cutoff = datetime.now() - timedelta(minutes=minutes)
            rows = (
                session.query(SwitchTrafficRecord)
                .filter(
                    SwitchTrafficRecord.switch_ip == switch_ip,
                    SwitchTrafficRecord.interface == interface,
                    SwitchTrafficRecord.timestamp >= cutoff,
                )
                .order_by(SwitchTrafficRecord.timestamp.asc())
                .all()
            )

            rates_in = [row.rate_in for row in rows]
            rates_out = [row.rate_out for row in rows]
            stamps = [row.timestamp.isoformat() for row in rows]

            return {
                "switch_ip": switch_ip,
                "interface": interface,
                "history": {"in": rates_in, "out": rates_out, "time": stamps},
            }
    except Exception as exc:
        logger.error(f"获取历史流量失败: {str(exc)}")
        raise HTTPException(500, f"获取历史流量失败: {str(exc)}")
|
|
|
|
|
|
def _close_reason(text, max_bytes=123):
    """Trim ``text`` so it fits in a WebSocket close frame's reason field.

    RFC 6455 limits a close frame's payload to 125 bytes. Two of those bytes
    hold the status code, which leaves 123 bytes of UTF-8 for the reason.
    Cutting on a byte boundary can split a multi-byte character (most Chinese
    characters take 3 bytes), so any partial trailing character is dropped.
    """
    return text.encode("utf-8")[:max_bytes].decode("utf-8", "ignore")


@router.websocket("/ws/traffic/switch")
async def websocket_switch_traffic(websocket: WebSocket, switch_ip: str, interface: str = None):
    """Stream a switch's current traffic to the client about once per second.

    With ``interface`` set, each message is that interface's latest stored
    sample. Otherwise each message is ``{"switch_ip": ..., "traffic": {...}}``
    with one sample per interface in the switch monitor's ``interface_oids``.
    Client disconnects are logged at info level. Any other error is logged
    and the socket is closed with code 1011 and a length-safe reason.
    """
    await websocket.accept()
    try:
        monitor = get_switch_monitor(switch_ip)

        while True:
            if interface:
                traffic_data = await get_interface_current_traffic(switch_ip, interface)
                await websocket.send_json(traffic_data)
            else:
                traffic_data = {}
                for iface in monitor.interface_oids:
                    traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface)

                await websocket.send_json({
                    "switch_ip": switch_ip,
                    "traffic": traffic_data,
                })

            await asyncio.sleep(1)
    except WebSocketDisconnect:
        logger.info(f"客户端断开交换机流量连接: {switch_ip}")
    except Exception as e:
        logger.error(f"交换机流量WebSocket错误: {str(e)}")
        try:
            # An over-long reason would make close() itself fail, so trim it.
            await websocket.close(code=1011, reason=_close_reason(str(e)))
        except Exception as close_err:
            # The connection may already be gone (e.g. the error came from a
            # failed send), so there is no peer left to notify. Log and move on.
            logger.info(f"关闭交换机流量WebSocket失败: {str(close_err)}")
|
|
|
|
@router.get("/traffic/switch/plot", response_class=HTMLResponse, summary="交换机流量可视化")
async def plot_switch_traffic(switch_ip: str, interface: str, minutes: int = 10):
    """生成交换机流量图表"""
    # Render in/out traffic for one switch interface over the last `minutes`
    # minutes as a PNG line chart, embedded as base64 in a small HTML page.
    # On any failure, return an HTML error page with status 500 instead.
    # The docstring above is left verbatim because FastAPI publishes it.
    try:
        history = await get_switch_traffic_history(switch_ip, interface, minutes)
        history_data = history["history"]

        time_points = [datetime.fromisoformat(t) for t in history_data["time"]]
        in_rates = history_data["in"]
        out_rates = history_data["out"]

        # Draw on an explicit figure and always close it, so an error while
        # drawing or saving cannot leave the figure open in pyplot's registry.
        fig, ax = plt.subplots(figsize=(12, 6))
        try:
            ax.plot(time_points, in_rates, label="流入流量 (B/s)")
            ax.plot(time_points, out_rates, label="流出流量 (B/s)")
            ax.set_title(f"交换机 {switch_ip} 接口 {interface} 流量监控 - 最近 {minutes} 分钟")
            ax.set_xlabel("时间")
            ax.set_ylabel("流量 (字节/秒)")
            ax.legend()
            ax.grid(True)
            ax.tick_params(axis="x", labelrotation=45)
            fig.tight_layout()
            buf = io.BytesIO()
            fig.savefig(buf, format="png")
        finally:
            plt.close(fig)

        image_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")

        # switch_ip and interface come straight from the query string. Escape
        # them before putting them into HTML to prevent reflected XSS.
        safe_ip = html.escape(switch_ip)
        safe_interface = html.escape(interface)

        return f"""
<html>
<head>
<title>交换机流量监控</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 20px; }}
.container {{ max-width: 900px; margin: 0 auto; }}
</style>
</head>
<body>
<div class="container">
<h1>交换机 {safe_ip} 接口 {safe_interface} 流量监控</h1>
<img src="data:image/png;base64,{image_base64}" alt="流量图表">
<p>更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
</div>
</body>
</html>
"""
    except Exception as e:
        logger.error(f"生成流量图表失败: {str(e)}")
        # Exception text can echo request input (e.g. the interface name),
        # so escape it too.
        return HTMLResponse(content=f"<h1>错误</h1><p>{html.escape(str(e))}</p>", status_code=500)
|
|
|
|
|
|
@router.get("/network_adapters", summary="获取网络适配器网段")
async def get_network_adapters():
    # List every local adapter's IPv4 address together with the network it
    # belongs to (derived from the address and its netmask).
    # Unlike most endpoints here, errors are reported in the response body
    # under "error" rather than through HTTPException.
    try:
        networks = []
        for interface, addrs in psutil.net_if_addrs().items():
            for addr in addrs:
                if addr.family != socket.AF_INET:
                    continue

                ip = addr.address
                netmask = addr.netmask
                # psutil documents netmask as possibly None. Skip such entries
                # instead of letting one adapter fail the whole listing.
                if not netmask:
                    continue
                try:
                    network = ipaddress.IPv4Network(f"{ip}/{netmask}", strict=False)
                except ValueError:
                    # Unparseable address/netmask pair on this adapter; skip it.
                    continue

                networks.append({
                    "adapter": interface,
                    "network": str(network),
                    "ip": ip,
                    "subnet_mask": netmask,
                })

        return {"networks": networks}
    except Exception as e:
        return {"error": f"获取网络适配器信息失败: {str(e)}"}