mirror of
https://github.com/Jerryplusy/AI-powered-switches.git
synced 2025-10-14 01:39:18 +00:00
删掉一些东西
This commit is contained in:
parent
0b6b9624a6
commit
d77a0dacad
@ -1,29 +1,17 @@
|
||||
import socket
|
||||
from datetime import datetime, timedelta
|
||||
from fastapi import (APIRouter, HTTPException, Response, WebSocket, WebSocketDisconnect)
|
||||
from fastapi import (APIRouter, HTTPException, Response)
|
||||
from typing import List
|
||||
from pydantic import BaseModel
|
||||
import asyncio
|
||||
from fastapi.responses import HTMLResponse
|
||||
import matplotlib.pyplot as plt
|
||||
import io
|
||||
import base64
|
||||
import psutil
|
||||
import ipaddress
|
||||
|
||||
from ..services.switch_traffic_monitor import get_switch_monitor
|
||||
from ..utils import logger
|
||||
from ...app.services.ai_services import AIService
|
||||
from ...app.api.network_config import SwitchConfigurator
|
||||
from ...config import settings
|
||||
from ..services.network_scanner import NetworkScanner
|
||||
from ...app.services.traffic_monitor import traffic_monitor
|
||||
from ...app.models.traffic_models import TrafficRecord, SwitchTrafficRecord
|
||||
from src.backend.app.api.database import SessionLocal
|
||||
from ..services.network_visualizer import NetworkVisualizer
|
||||
from ..services.config_validator import ConfigValidator
|
||||
from ..services.report_generator import ReportGenerator
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
|
||||
|
||||
@ -59,22 +47,6 @@ class BatchConfigRequest(BaseModel):
|
||||
password: str = None
|
||||
timeout: int = None
|
||||
|
||||
@router.post("/batch_apply_config")
|
||||
async def batch_apply_config(request: BatchConfigRequest):
|
||||
results = {}
|
||||
for ip in request.switch_ips:
|
||||
try:
|
||||
configurator = SwitchConfigurator(
|
||||
username=request.username,
|
||||
password=request.password,
|
||||
timeout=request.timeout )
|
||||
results[ip] = await configurator.apply_config(ip, request.config)
|
||||
except Exception as e:
|
||||
results[ip] = str(e)
|
||||
return {"results": results}
|
||||
|
||||
|
||||
|
||||
@router.get("/test")
|
||||
async def test_endpoint():
|
||||
return {"message": "Hello World"}
|
||||
@ -196,47 +168,6 @@ async def execute_cli_commands(request: CLICommandRequest):
|
||||
except Exception as e:
|
||||
raise HTTPException(500, detail=str(e))
|
||||
|
||||
@router.get("/traffic/interfaces", summary="获取所有网络接口")
|
||||
async def get_network_interfaces():
|
||||
return {
|
||||
"interfaces": traffic_monitor.get_interfaces()
|
||||
}
|
||||
|
||||
@router.get("/traffic/current", summary="获取当前流量数据")
|
||||
async def get_current_traffic(interface: str = None):
|
||||
return traffic_monitor.get_current_traffic(interface)
|
||||
|
||||
@router.get("/traffic/history", summary="获取流量历史数据")
|
||||
async def get_traffic_history(interface: str = None, limit: int = 100):
|
||||
history = traffic_monitor.get_traffic_history(interface)
|
||||
return {
|
||||
"sent": history["sent"][-limit:],
|
||||
"recv": history["recv"][-limit:],
|
||||
"time": [t.isoformat() for t in history["time"]][-limit:]
|
||||
}
|
||||
|
||||
@router.get("/traffic/records", summary="获取流量记录")
|
||||
async def get_traffic_records(interface: str = None, limit: int = 100):
|
||||
with SessionLocal() as session:
|
||||
query = session.query(TrafficRecord)
|
||||
if interface:
|
||||
query = query.filter(TrafficRecord.interface == interface)
|
||||
records = query.order_by(TrafficRecord.timestamp.desc()).limit(limit).all()
|
||||
return [record.to_dict() for record in records]
|
||||
|
||||
@router.websocket("/ws/traffic")
|
||||
async def websocket_traffic(websocket: WebSocket):
|
||||
"""实时流量WebSocket"""
|
||||
await websocket.accept()
|
||||
try:
|
||||
while True:
|
||||
traffic_data = traffic_monitor.get_current_traffic()
|
||||
await websocket.send_json(traffic_data)
|
||||
await asyncio.sleep(1)
|
||||
except WebSocketDisconnect:
|
||||
print("客户端断开连接")
|
||||
|
||||
|
||||
@router.get("/", include_in_schema=False)
|
||||
async def root():
|
||||
return {
|
||||
@ -253,195 +184,6 @@ async def root():
|
||||
"/traffic/switch/history"
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@router.get("/traffic/switch/interfaces", summary="获取交换机的网络接口")
|
||||
async def get_switch_interfaces(switch_ip: str):
|
||||
"""获取指定交换机的所有接口"""
|
||||
try:
|
||||
monitor = get_switch_monitor(switch_ip)
|
||||
interfaces = list(monitor.interface_oids.keys())
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"interfaces": interfaces
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取交换机接口失败: {str(e)}")
|
||||
raise HTTPException(500, f"获取接口失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/traffic/switch/current", summary="获取交换机的当前流量数据")
|
||||
async def get_switch_current_traffic(switch_ip: str, interface: str = None):
|
||||
"""获取交换机的当前流量数据"""
|
||||
try:
|
||||
monitor = get_switch_monitor(switch_ip)
|
||||
|
||||
|
||||
if not interface:
|
||||
traffic_data = {}
|
||||
for iface in monitor.interface_oids:
|
||||
traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface)
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"traffic": traffic_data
|
||||
}
|
||||
|
||||
return await get_interface_current_traffic(switch_ip, interface)
|
||||
except Exception as e:
|
||||
logger.error(f"获取交换机流量失败: {str(e)}")
|
||||
raise HTTPException(500, f"获取流量失败: {str(e)}")
|
||||
|
||||
|
||||
async def get_interface_current_traffic(switch_ip: str, interface: str) -> dict:
|
||||
"""获取指定交换机接口的当前流量数据"""
|
||||
try:
|
||||
with SessionLocal() as session:
|
||||
|
||||
record = session.query(SwitchTrafficRecord).filter(
|
||||
SwitchTrafficRecord.switch_ip == switch_ip,
|
||||
SwitchTrafficRecord.interface == interface
|
||||
).order_by(SwitchTrafficRecord.timestamp.desc()).first()
|
||||
|
||||
if not record:
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"interface": interface,
|
||||
"rate_in": 0.0,
|
||||
"rate_out": 0.0,
|
||||
"bytes_in": 0,
|
||||
"bytes_out": 0
|
||||
}
|
||||
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"interface": interface,
|
||||
"rate_in": record.rate_in,
|
||||
"rate_out": record.rate_out,
|
||||
"bytes_in": record.bytes_in,
|
||||
"bytes_out": record.bytes_out
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取接口流量失败: {str(e)}")
|
||||
raise HTTPException(500, f"获取接口流量失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/traffic/switch/history", summary="获取交换机的流量历史数据")
|
||||
async def get_switch_traffic_history(switch_ip: str, interface: str = None, minutes: int = 10):
|
||||
"""获取交换机的流量历史数据"""
|
||||
try:
|
||||
monitor = get_switch_monitor(switch_ip)
|
||||
|
||||
if not interface:
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"history": monitor.get_traffic_history()
|
||||
}
|
||||
|
||||
with SessionLocal() as session:
|
||||
time_threshold = datetime.now() - timedelta(minutes=minutes)
|
||||
|
||||
records = session.query(SwitchTrafficRecord).filter(
|
||||
SwitchTrafficRecord.switch_ip == switch_ip,
|
||||
SwitchTrafficRecord.interface == interface,
|
||||
SwitchTrafficRecord.timestamp >= time_threshold
|
||||
).order_by(SwitchTrafficRecord.timestamp.asc()).all()
|
||||
|
||||
history_data = {
|
||||
"in": [record.rate_in for record in records],
|
||||
"out": [record.rate_out for record in records],
|
||||
"time": [record.timestamp.isoformat() for record in records]
|
||||
}
|
||||
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"interface": interface,
|
||||
"history": history_data
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取历史流量失败: {str(e)}")
|
||||
raise HTTPException(500, f"获取历史流量失败: {str(e)}")
|
||||
|
||||
|
||||
@router.websocket("/ws/traffic/switch")
|
||||
async def websocket_switch_traffic(websocket: WebSocket, switch_ip: str, interface: str = None):
|
||||
"""交换机实时流量WebSocket"""
|
||||
await websocket.accept()
|
||||
try:
|
||||
monitor = get_switch_monitor(switch_ip)
|
||||
|
||||
while True:
|
||||
if interface:
|
||||
traffic_data = await get_interface_current_traffic(switch_ip, interface)
|
||||
await websocket.send_json(traffic_data)
|
||||
else:
|
||||
traffic_data = {}
|
||||
for iface in monitor.interface_oids:
|
||||
traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface)
|
||||
|
||||
await websocket.send_json({
|
||||
"switch_ip": switch_ip,
|
||||
"traffic": traffic_data
|
||||
})
|
||||
|
||||
await asyncio.sleep(1)
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"客户端断开交换机流量连接: {switch_ip}")
|
||||
except Exception as e:
|
||||
logger.error(f"交换机流量WebSocket错误: {str(e)}")
|
||||
await websocket.close(code=1011, reason=str(e))
|
||||
|
||||
@router.get("/traffic/switch/plot", response_class=HTMLResponse, summary="交换机流量可视化")
|
||||
async def plot_switch_traffic(switch_ip: str, interface: str, minutes: int = 10):
|
||||
"""生成交换机流量图表"""
|
||||
try:
|
||||
history = await get_switch_traffic_history(switch_ip, interface, minutes)
|
||||
history_data = history["history"]
|
||||
|
||||
time_points = [datetime.fromisoformat(t) for t in history_data["time"]]
|
||||
in_rates = history_data["in"]
|
||||
out_rates = history_data["out"]
|
||||
|
||||
plt.figure(figsize=(12, 6))
|
||||
plt.plot(time_points, in_rates, label="流入流量 (B/s)")
|
||||
plt.plot(time_points, out_rates, label="流出流量 (B/s)")
|
||||
plt.title(f"交换机 {switch_ip} 接口 {interface} 流量监控 - 最近 {minutes} 分钟")
|
||||
plt.xlabel("时间")
|
||||
plt.ylabel("流量 (字节/秒)")
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
plt.xticks(rotation=45)
|
||||
plt.tight_layout()
|
||||
buf = io.BytesIO()
|
||||
plt.savefig(buf, format="png")
|
||||
buf.seek(0)
|
||||
image_base64 = base64.b64encode(buf.read()).decode("utf-8")
|
||||
plt.close()
|
||||
|
||||
return f"""
|
||||
<html>
|
||||
<head>
|
||||
<title>交换机流量监控</title>
|
||||
<style>
|
||||
body {{ font-family: Arial, sans-serif; margin: 20px; }}
|
||||
.container {{ max-width: 900px; margin: 0 auto; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>交换机 {switch_ip} 接口 {interface} 流量监控</h1>
|
||||
<img src="data:image/png;base64,{image_base64}" alt="流量图表">
|
||||
<p>更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
except Exception as e:
|
||||
logger.error(f"生成流量图表失败: {str(e)}")
|
||||
return HTMLResponse(content=f"<h1>错误</h1><p>{str(e)}</p>", status_code=500)
|
||||
|
||||
|
||||
@router.get("/network_adapters", summary="获取网络适配器网段")
|
||||
async def get_network_adapters():
|
||||
try:
|
||||
@ -475,7 +217,7 @@ report_gen = ReportGenerator()
|
||||
async def visualize_topology():
|
||||
"""获取网络拓扑可视化图"""
|
||||
try:
|
||||
devices = await list_devices() # 复用现有的设备列表接口
|
||||
devices = await list_devices()
|
||||
visualizer.update_topology(devices["devices"])
|
||||
image_data = visualizer.generate_topology_image()
|
||||
|
||||
@ -500,34 +242,4 @@ async def validate_config(config: dict):
|
||||
"valid": is_valid,
|
||||
"errors": errors,
|
||||
"has_security_risks": len(ConfigValidator.check_security_risks(config.get("commands", []))) > 0
|
||||
}
|
||||
|
||||
|
||||
@router.get("/reports/traffic/{ip}")
|
||||
async def get_traffic_report(ip: str, days: int = 1):
|
||||
"""获取流量分析报告"""
|
||||
try:
|
||||
report = report_gen.generate_traffic_report(ip, days)
|
||||
return JSONResponse(content=report)
|
||||
except Exception as e:
|
||||
raise HTTPException(500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/reports/traffic")
|
||||
async def get_local_traffic_report(days: int = 1):
|
||||
"""获取本地网络流量报告"""
|
||||
try:
|
||||
report = report_gen.generate_traffic_report(days=days)
|
||||
return JSONResponse(content=report)
|
||||
except Exception as e:
|
||||
raise HTTPException(500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/topology/traffic_heatmap")
|
||||
async def get_traffic_heatmap(minutes: int = 10):
|
||||
"""获取流量热力图数据"""
|
||||
try:
|
||||
heatmap = visualizer.get_traffic_heatmap(minutes)
|
||||
return {"heatmap": heatmap}
|
||||
except Exception as e:
|
||||
raise HTTPException(500, detail=str(e))
|
||||
}
|
@ -42,7 +42,7 @@ class AIService:
|
||||
]
|
||||
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
response = self.client.chat.completions.create(
|
||||
model="deepseek-ai/DeepSeek-V3",
|
||||
messages=messages,
|
||||
temperature=0.3,
|
||||
|
@ -15,7 +15,7 @@ class NetworkScanner:
|
||||
|
||||
devices = []
|
||||
try:
|
||||
self.nm.scan(hosts=subnet, arguments=f'-p {",".join(map(str, ports))}')
|
||||
await self.nm.scan(hosts=subnet, arguments=f'-p {",".join(map(str, ports))}')
|
||||
for host in self.nm.all_hosts():
|
||||
ip = host
|
||||
mac = self.nm[host]['addresses'].get('mac', 'N/A')
|
||||
|
@ -8,6 +8,7 @@ from typing import Dict, Optional, List
|
||||
|
||||
from ..models.traffic_models import TrafficRecord
|
||||
from src.backend.app.api.database import SessionLocal
|
||||
from ..utils.logger import logger
|
||||
|
||||
|
||||
class TrafficMonitor:
|
||||
@ -33,7 +34,7 @@ class TrafficMonitor:
|
||||
if not self.running:
|
||||
self.running = True
|
||||
self.task = asyncio.create_task(self._monitor_loop())
|
||||
print("流量监控已启动")
|
||||
logger.info("流量监控已启动")
|
||||
|
||||
async def stop_monitoring(self):
|
||||
"""停止流量监控"""
|
||||
@ -44,7 +45,7 @@ class TrafficMonitor:
|
||||
await self.task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
print("流量监控已停止")
|
||||
logger.info("流量监控已停止")
|
||||
|
||||
async def _monitor_loop(self):
|
||||
"""监控主循环"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user