mirror of
https://github.com/Jerryplusy/AI-powered-switches.git
synced 2025-10-14 09:49:19 +00:00
Compare commits
No commits in common. "08a6ac56b81803fecaf40e2019d3e7b5003f4055" and "d11decae6ab1def03d755207d2f553a088fbf048" have entirely different histories.
08a6ac56b8
...
d11decae6a
3
.idea/AI-powered-switches.iml
generated
3
.idea/AI-powered-switches.iml
generated
@ -2,7 +2,7 @@
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="FacetManager">
|
||||
<facet type="Python" name="Python facet">
|
||||
<configuration sdkName="Python 3.12 (AI-powered-switches)" />
|
||||
<configuration sdkName="Python 3.13" />
|
||||
</facet>
|
||||
</component>
|
||||
<component name="NewModuleRootManager">
|
||||
@ -11,6 +11,5 @@
|
||||
</content>
|
||||
<orderEntry type="jdk" jdkName="Python 3.13" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
<orderEntry type="library" name="Python 3.12 (AI-powered-switches) interpreter library" level="application" />
|
||||
</component>
|
||||
</module>
|
@ -4,8 +4,13 @@ from starlette.middleware import Middleware # 新增导入
|
||||
from src.backend.app.api.endpoints import router
|
||||
from src.backend.app.utils.logger import setup_logging
|
||||
from src.backend.config import settings
|
||||
from .services.switch_traffic_monitor import get_switch_monitor
|
||||
from .services.traffic_monitor import traffic_monitor
|
||||
from src.backend.app.api.database import init_db
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
init_db()
|
||||
traffic_monitor.start_monitoring()
|
||||
setup_logging()
|
||||
|
||||
app = FastAPI(
|
||||
|
16
src/backend/app/api/database.py
Normal file
16
src/backend/app/api/database.py
Normal file
@ -0,0 +1,16 @@
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
SQLALCHEMY_DATABASE_URL = "sqlite:///./traffic_monitor.db"
|
||||
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
|
||||
)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
def init_db():
|
||||
"""初始化数据库"""
|
||||
Base.metadata.create_all(bind=engine)
|
@ -1,20 +1,26 @@
|
||||
import socket
|
||||
from fastapi import (APIRouter, HTTPException, Response)
|
||||
from typing import List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from fastapi import (APIRouter, HTTPException, Response, WebSocket, WebSocketDisconnect)
|
||||
from typing import List
|
||||
from pydantic import BaseModel
|
||||
import asyncio
|
||||
from fastapi.responses import HTMLResponse, JSONResponse
|
||||
import matplotlib.pyplot as plt
|
||||
import io
|
||||
import base64
|
||||
import psutil
|
||||
import ipaddress
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from ..models.requests import CLICommandRequest, ConfigRequest, TrafficQueryRequest
|
||||
from ..models.requests import CLICommandRequest, ConfigRequest
|
||||
from ..services.switch_traffic_monitor import get_switch_monitor
|
||||
from ..utils import logger
|
||||
from ...app.services.ai_services import AIService
|
||||
from ...app.api.network_config import SwitchConfigurator
|
||||
from ...config import settings
|
||||
from ..services.network_scanner import NetworkScanner
|
||||
from ..services.traffic_monitor import traffic_monitor
|
||||
from ..utils.logger import logger
|
||||
from ...app.services.traffic_monitor import traffic_monitor
|
||||
from ...app.models.traffic_models import TrafficRecord, SwitchTrafficRecord
|
||||
from src.backend.app.api.database import SessionLocal
|
||||
|
||||
router = APIRouter(prefix="", tags=["API"])
|
||||
scanner = NetworkScanner()
|
||||
@ -32,8 +38,8 @@ async def root():
|
||||
"/scan_network",
|
||||
"/list_devices",
|
||||
"/batch_apply_config",
|
||||
"/traffic/realtime",
|
||||
"/traffic/cache/clear"
|
||||
"/traffic/switch/current",
|
||||
"/traffic/switch/history"
|
||||
]
|
||||
}
|
||||
|
||||
@ -145,6 +151,234 @@ async def execute_cli_commands(request: CLICommandRequest):
|
||||
except Exception as e:
|
||||
raise HTTPException(500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/traffic/interfaces", summary="获取所有网络接口")
|
||||
async def get_network_interfaces():
|
||||
return {
|
||||
"interfaces": await asyncio.to_thread(traffic_monitor.get_interfaces)
|
||||
}
|
||||
|
||||
|
||||
@router.get("/traffic/current", summary="获取当前流量数据")
|
||||
async def get_current_traffic(interface: str = None):
|
||||
return await asyncio.to_thread(traffic_monitor.get_current_traffic, interface)
|
||||
|
||||
|
||||
@router.get("/traffic/history", summary="获取流量历史数据")
|
||||
async def get_traffic_history(interface: str = None, limit: int = 100):
|
||||
history = await asyncio.to_thread(traffic_monitor.get_traffic_history, interface)
|
||||
return {
|
||||
"sent": history["sent"][-limit:],
|
||||
"recv": history["recv"][-limit:],
|
||||
"time": [t.isoformat() for t in history["time"]][-limit:]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/traffic/records", summary="获取流量记录")
|
||||
async def get_traffic_records(interface: str = None, limit: int = 100):
|
||||
def sync_get_records():
|
||||
with SessionLocal() as session:
|
||||
query = session.query(TrafficRecord)
|
||||
if interface:
|
||||
query = query.filter(TrafficRecord.interface == interface)
|
||||
records = query.order_by(TrafficRecord.timestamp.desc()).limit(limit).all()
|
||||
return [record.to_dict() for record in records]
|
||||
|
||||
return await asyncio.to_thread(sync_get_records)
|
||||
|
||||
|
||||
@router.websocket("/ws/traffic")
|
||||
async def websocket_traffic(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
try:
|
||||
while True:
|
||||
traffic_data = await asyncio.to_thread(traffic_monitor.get_current_traffic)
|
||||
await websocket.send_json(traffic_data)
|
||||
await asyncio.sleep(1)
|
||||
except WebSocketDisconnect:
|
||||
print("客户端断开连接")
|
||||
|
||||
|
||||
@router.get("/traffic/switch/interfaces", summary="获取交换机的网络接口")
|
||||
async def get_switch_interfaces(switch_ip: str):
|
||||
try:
|
||||
monitor = get_switch_monitor(switch_ip)
|
||||
interfaces = list(monitor.interface_oids.keys())
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"interfaces": interfaces
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取交换机接口失败: {str(e)}")
|
||||
raise HTTPException(500, f"获取接口失败: {str(e)}")
|
||||
|
||||
|
||||
async def get_interface_current_traffic(switch_ip: str, interface: str) -> dict:
|
||||
"""获取指定交换机接口的当前流量数据"""
|
||||
try:
|
||||
def sync_get_record():
|
||||
with SessionLocal() as session:
|
||||
record = session.query(SwitchTrafficRecord).filter(
|
||||
SwitchTrafficRecord.switch_ip == switch_ip,
|
||||
SwitchTrafficRecord.interface == interface
|
||||
).order_by(SwitchTrafficRecord.timestamp.desc()).first()
|
||||
|
||||
if not record:
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"interface": interface,
|
||||
"rate_in": 0.0,
|
||||
"rate_out": 0.0,
|
||||
"bytes_in": 0,
|
||||
"bytes_out": 0
|
||||
}
|
||||
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"interface": interface,
|
||||
"rate_in": record.rate_in,
|
||||
"rate_out": record.rate_out,
|
||||
"bytes_in": record.bytes_in,
|
||||
"bytes_out": record.bytes_out
|
||||
}
|
||||
|
||||
return await asyncio.to_thread(sync_get_record)
|
||||
except Exception as e:
|
||||
logger.error(f"获取接口流量失败: {str(e)}")
|
||||
raise HTTPException(500, f"获取接口流量失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/traffic/switch/current", summary="获取交换机的当前流量数据")
|
||||
async def get_switch_current_traffic(switch_ip: str, interface: str = None):
|
||||
try:
|
||||
monitor = get_switch_monitor(switch_ip)
|
||||
if not interface:
|
||||
traffic_data = {}
|
||||
for iface in monitor.interface_oids:
|
||||
traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface)
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"traffic": traffic_data
|
||||
}
|
||||
return await get_interface_current_traffic(switch_ip, interface)
|
||||
except Exception as e:
|
||||
logger.error(f"获取交换机流量失败: {str(e)}")
|
||||
raise HTTPException(500, f"获取流量失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/traffic/switch/history", summary="获取交换机的流量历史数据")
|
||||
async def get_switch_traffic_history(switch_ip: str, interface: str = None, minutes: int = 10):
|
||||
try:
|
||||
monitor = get_switch_monitor(switch_ip)
|
||||
if not interface:
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"history": await asyncio.to_thread(monitor.get_traffic_history)
|
||||
}
|
||||
|
||||
def sync_get_history():
|
||||
with SessionLocal() as session:
|
||||
time_threshold = datetime.now() - timedelta(minutes=minutes)
|
||||
records = session.query(SwitchTrafficRecord).filter(
|
||||
SwitchTrafficRecord.switch_ip == switch_ip,
|
||||
SwitchTrafficRecord.interface == interface,
|
||||
SwitchTrafficRecord.timestamp >= time_threshold
|
||||
).order_by(SwitchTrafficRecord.timestamp.asc()).all()
|
||||
|
||||
history_data = {
|
||||
"in": [record.rate_in for record in records],
|
||||
"out": [record.rate_out for record in records],
|
||||
"time": [record.timestamp.isoformat() for record in records]
|
||||
}
|
||||
return history_data
|
||||
|
||||
history_data = await asyncio.to_thread(sync_get_history)
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"interface": interface,
|
||||
"history": history_data
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取历史流量失败: {str(e)}")
|
||||
raise HTTPException(500, f"获取历史流量失败: {str(e)}")
|
||||
|
||||
|
||||
@router.websocket("/ws/traffic/switch")
|
||||
async def websocket_switch_traffic(websocket: WebSocket, switch_ip: str, interface: str = None):
|
||||
await websocket.accept()
|
||||
try:
|
||||
monitor = get_switch_monitor(switch_ip)
|
||||
while True:
|
||||
if interface:
|
||||
traffic_data = await get_interface_current_traffic(switch_ip, interface)
|
||||
await websocket.send_json(traffic_data)
|
||||
else:
|
||||
traffic_data = {}
|
||||
for iface in monitor.interface_oids:
|
||||
traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface)
|
||||
await websocket.send_json({
|
||||
"switch_ip": switch_ip,
|
||||
"traffic": traffic_data
|
||||
})
|
||||
await asyncio.sleep(1)
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"客户端断开交换机流量连接: {switch_ip}")
|
||||
except Exception as e:
|
||||
logger.error(f"交换机流量WebSocket错误: {str(e)}")
|
||||
await websocket.close(code=1011, reason=str(e))
|
||||
|
||||
|
||||
@router.get("/traffic/switch/plot", response_class=HTMLResponse, summary="交换机流量可视化")
|
||||
async def plot_switch_traffic(switch_ip: str, interface: str, minutes: int = 10):
|
||||
try:
|
||||
history = await get_switch_traffic_history(switch_ip, interface, minutes)
|
||||
history_data = history["history"]
|
||||
time_points = [datetime.fromisoformat(t) for t in history_data["time"]]
|
||||
in_rates = history_data["in"]
|
||||
out_rates = history_data["out"]
|
||||
|
||||
def generate_plot():
|
||||
plt.figure(figsize=(12, 6))
|
||||
plt.plot(time_points, in_rates, label="流入流量 (B/s)")
|
||||
plt.plot(time_points, out_rates, label="流出流量 (B/s)")
|
||||
plt.title(f"交换机 {switch_ip} 接口 {interface} 流量监控 - 最近 {minutes} 分钟")
|
||||
plt.xlabel("时间")
|
||||
plt.ylabel("流量 (字节/秒)")
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
plt.xticks(rotation=45)
|
||||
plt.tight_layout()
|
||||
buf = io.BytesIO()
|
||||
plt.savefig(buf, format="png")
|
||||
buf.seek(0)
|
||||
image_base64 = base64.b64encode(buf.read()).decode("utf-8")
|
||||
plt.close()
|
||||
return image_base64
|
||||
|
||||
image_base64 = await asyncio.to_thread(generate_plot)
|
||||
return f"""
|
||||
<html>
|
||||
<head>
|
||||
<title>交换机流量监控</title>
|
||||
<style>
|
||||
body {{ font-family: Arial, sans-serif; margin: 20px; }}
|
||||
.container {{ max-width: 900px; margin: 0 auto; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>交换机 {switch_ip} 接口 {interface} 流量监控</h1>
|
||||
<img src="data:image/png;base64,{image_base64}" alt="流量图表">
|
||||
<p>更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
except Exception as e:
|
||||
logger.error(f"生成流量图表失败: {str(e)}")
|
||||
return HTMLResponse(content=f"<h1>错误</h1><p>{str(e)}</p>", status_code=500)
|
||||
|
||||
|
||||
@router.get("/network_adapters", summary="获取网络适配器网段")
|
||||
async def get_network_adapters():
|
||||
try:
|
||||
@ -169,131 +403,3 @@ async def get_network_adapters():
|
||||
return {"networks": networks}
|
||||
except Exception as e:
|
||||
return {"error": f"获取网络适配器信息失败: {str(e)}"}
|
||||
|
||||
|
||||
@router.post("/traffic/realtime", summary="查询交换机接口实时流量")
|
||||
async def get_realtime_traffic(request: TrafficQueryRequest):
|
||||
"""
|
||||
查询交换机接口实时流量速率(Kbps)
|
||||
|
||||
- 支持多个接口同时查询
|
||||
- 首次查询速率返回 0
|
||||
- 单个接口查询失败不影响其他接口
|
||||
"""
|
||||
# 提取认证信息
|
||||
username = request.username or settings.SWITCH_USERNAME
|
||||
password = request.password or settings.SWITCH_PASSWORD
|
||||
|
||||
# 创建配置器(复用连接池)
|
||||
configurator = SwitchConfigurator(
|
||||
username=username,
|
||||
password=password,
|
||||
timeout=settings.SWITCH_TIMEOUT
|
||||
)
|
||||
|
||||
results = []
|
||||
current_time = time.time()
|
||||
|
||||
# 遍历所有接口
|
||||
for interface in request.interfaces:
|
||||
interface_data = {
|
||||
"interface": interface,
|
||||
"status": "unknown",
|
||||
"in_speed_kbps": 0.0,
|
||||
"out_speed_kbps": 0.0,
|
||||
"in_bytes": 0,
|
||||
"out_bytes": 0,
|
||||
"error": None
|
||||
}
|
||||
|
||||
try:
|
||||
# 获取查询命令
|
||||
command = traffic_monitor.get_query_command(request.vendor, interface)
|
||||
if not command:
|
||||
interface_data["error"] = f"不支持的厂商: {request.vendor}"
|
||||
results.append(interface_data)
|
||||
continue
|
||||
|
||||
# 执行查询命令
|
||||
try:
|
||||
output = await configurator.execute_raw_commands(
|
||||
ip=request.switch_ip,
|
||||
commands=[command]
|
||||
)
|
||||
|
||||
# 解析输出
|
||||
stats = traffic_monitor.parse_interface_stats(request.vendor, str(output))
|
||||
if stats is None:
|
||||
interface_data["error"] = "解析接口统计失败"
|
||||
results.append(interface_data)
|
||||
continue
|
||||
|
||||
in_bytes, out_bytes, status = stats
|
||||
|
||||
# 计算速率
|
||||
in_speed_kbps, out_speed_kbps = traffic_monitor.calculate_speed(
|
||||
request.switch_ip,
|
||||
interface,
|
||||
in_bytes,
|
||||
out_bytes,
|
||||
current_time
|
||||
)
|
||||
|
||||
# 更新结果
|
||||
interface_data.update({
|
||||
"status": status,
|
||||
"in_speed_kbps": round(in_speed_kbps, 2),
|
||||
"out_speed_kbps": round(out_speed_kbps, 2),
|
||||
"in_bytes": in_bytes,
|
||||
"out_bytes": out_bytes
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
interface_data["error"] = f"查询失败: {str(e)}"
|
||||
logger.error(f"查询接口 {interface} 失败: {e}")
|
||||
|
||||
except Exception as e:
|
||||
interface_data["error"] = f"未知错误: {str(e)}"
|
||||
logger.error(f"处理接口 {interface} 时发生异常: {e}", exc_info=True)
|
||||
|
||||
results.append(interface_data)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"switch_ip": request.switch_ip,
|
||||
"vendor": request.vendor,
|
||||
"timestamp": datetime.utcnow().isoformat() + "Z",
|
||||
"data": results
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/traffic/cache/clear", summary="清除流量监控缓存")
|
||||
async def clear_traffic_cache(switch_ip: Optional[str] = None):
|
||||
"""
|
||||
清除流量监控缓存
|
||||
|
||||
- 不指定 switch_ip: 清除所有缓存
|
||||
- 指定 switch_ip: 只清除该交换机的缓存
|
||||
"""
|
||||
try:
|
||||
count = traffic_monitor.clear_cache(switch_ip)
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"已清除 {count} 条缓存记录",
|
||||
"switch_ip": switch_ip or "all"
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(500, f"清除缓存失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/traffic/cache/stats", summary="获取缓存统计信息")
|
||||
async def get_cache_stats():
|
||||
"""获取流量监控缓存统计信息"""
|
||||
try:
|
||||
stats = traffic_monitor.get_cache_stats()
|
||||
return {
|
||||
"success": True,
|
||||
"stats": stats
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(500, f"获取缓存统计失败: {str(e)}")
|
@ -24,11 +24,3 @@ class CLICommandRequest(BaseModel):
|
||||
|
||||
def extract_credentials(self):
|
||||
return self.username or "NONE", self.password or "NONE"
|
||||
|
||||
class TrafficQueryRequest(BaseModel):
|
||||
"""实时流量查询请求"""
|
||||
switch_ip: str
|
||||
vendor: str # huawei/cisco/h3c/ruijie/zte
|
||||
interfaces: List[str] # 例如: ["GigabitEthernet0/0/1", "GigabitEthernet0/0/2"]
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
|
51
src/backend/app/models/traffic_models.py
Normal file
51
src/backend/app/models/traffic_models.py
Normal file
@ -0,0 +1,51 @@
|
||||
from src.backend.app.api.database import Base # 修复:导入 Base
|
||||
from sqlalchemy import Column, Integer, String, DateTime, BigInteger, Float
|
||||
|
||||
|
||||
class TrafficRecord(Base):
|
||||
"""网络流量记录模型"""
|
||||
__tablename__ = "traffic_records"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
interface = Column(String(50), index=True)
|
||||
bytes_sent = Column(Integer)
|
||||
bytes_recv = Column(Integer)
|
||||
packets_sent = Column(Integer)
|
||||
packets_recv = Column(Integer)
|
||||
timestamp = Column(DateTime)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"interface": self.interface,
|
||||
"bytes_sent": self.bytes_sent,
|
||||
"bytes_recv": self.bytes_recv,
|
||||
"packets_sent": self.packets_sent,
|
||||
"packets_recv": self.packets_recv,
|
||||
"timestamp": self.timestamp.isoformat()
|
||||
}
|
||||
|
||||
|
||||
class SwitchTrafficRecord(Base):
|
||||
__tablename__ = "switch_traffic_records"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
switch_ip = Column(String(50), index=True)
|
||||
interface = Column(String(50))
|
||||
bytes_in = Column(BigInteger)
|
||||
bytes_out = Column(BigInteger)
|
||||
rate_in = Column(Float)
|
||||
rate_out = Column(Float)
|
||||
timestamp = Column(DateTime)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"switch_ip": self.switch_ip,
|
||||
"interface": self.interface,
|
||||
"bytes_in": self.bytes_in,
|
||||
"bytes_out": self.bytes_out,
|
||||
"rate_in": self.rate_in,
|
||||
"rate_out": self.rate_out,
|
||||
"timestamp": self.timestamp.isoformat()
|
||||
}
|
182
src/backend/app/services/switch_traffic_monitor.py
Normal file
182
src/backend/app/services/switch_traffic_monitor.py
Normal file
@ -0,0 +1,182 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from collections import deque
|
||||
from typing import Optional, List, Dict
|
||||
from pysnmp.hlapi import *
|
||||
from ..models.traffic_models import SwitchTrafficRecord
|
||||
from src.backend.app.api.database import SessionLocal
|
||||
from ..utils.logger import logger
|
||||
|
||||
#V=ΔQ'-ΔQ/Δt (B/s)
|
||||
class SwitchTrafficMonitor:
|
||||
def __init__(
|
||||
self,
|
||||
switch_ip: str,
|
||||
community: str = 'public',
|
||||
update_interval: int = 5,
|
||||
interfaces: Optional[List[str]] = None
|
||||
):
|
||||
self.switch_ip = switch_ip
|
||||
self.community = community
|
||||
self.update_interval = update_interval
|
||||
self.running = False
|
||||
self.task = None
|
||||
self.interface_history = {}
|
||||
self.history = {
|
||||
"in": deque(maxlen=300),
|
||||
"out": deque(maxlen=300),
|
||||
"time": deque(maxlen=300)
|
||||
}
|
||||
|
||||
self.interface_oids = {
|
||||
"GigabitEthernet0/0/1": {
|
||||
"in": '1.3.6.1.2.1.2.2.1.10.1',
|
||||
"out": '1.3.6.1.2.1.2.2.1.16.1'
|
||||
},
|
||||
"GigabitEthernet0/0/24": {
|
||||
"in": '1.3.6.1.2.1.2.2.1.10.24',
|
||||
"out": '1.3.6.1.2.1.2.2.1.16.24'
|
||||
}
|
||||
}
|
||||
|
||||
if interfaces:
|
||||
self.interface_oids = {
|
||||
iface: oid for iface, oid in self.interface_oids.items()
|
||||
if iface in interfaces
|
||||
}
|
||||
logger.info(f"监控指定接口: {', '.join(interfaces)}")
|
||||
else:
|
||||
logger.info("监控所有接口")
|
||||
|
||||
def start_monitoring(self):
|
||||
"""启动交换机流量监控"""
|
||||
if not self.running:
|
||||
self.running = True
|
||||
self.task = asyncio.create_task(self._monitor_loop())
|
||||
logger.success(f"交换机流量监控已启动: {self.switch_ip}")
|
||||
|
||||
async def stop_monitoring(self):
|
||||
"""停止监控"""
|
||||
if self.running:
|
||||
self.running = False
|
||||
if self.task:
|
||||
self.task.cancel()
|
||||
try:
|
||||
await self.task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info(f"交换机流量监控已停止: {self.switch_ip}")
|
||||
|
||||
async def _monitor_loop(self):
|
||||
"""监控主循环"""
|
||||
last_values = {iface: {"in": 0, "out": 0} for iface in self.interface_oids}
|
||||
last_time = datetime.now()
|
||||
|
||||
while self.running:
|
||||
await asyncio.sleep(self.update_interval)
|
||||
|
||||
try:
|
||||
current_time = datetime.now()
|
||||
elapsed = (current_time - last_time).total_seconds()
|
||||
|
||||
for iface, oids in self.interface_oids.items():
|
||||
in_octets = self._snmp_get(oids["in"])
|
||||
out_octets = self._snmp_get(oids["out"])
|
||||
|
||||
if in_octets is not None and out_octets is not None:
|
||||
|
||||
iface_values = last_values[iface]
|
||||
in_rate = (in_octets - iface_values["in"]) / elapsed if iface_values["in"] > 0 else 0
|
||||
out_rate = (out_octets - iface_values["out"]) / elapsed if iface_values["out"] > 0 else 0
|
||||
|
||||
self.history["in"].append(in_rate)
|
||||
self.history["out"].append(out_rate)
|
||||
self.history["time"].append(current_time)
|
||||
|
||||
self._save_to_db(iface, in_octets, out_octets, in_rate, out_rate, current_time)
|
||||
|
||||
iface_values["in"] = in_octets
|
||||
iface_values["out"] = out_octets
|
||||
|
||||
last_time = current_time
|
||||
except Exception as e:
|
||||
logger.error(f"监控交换机流量出错: {str(e)}")
|
||||
|
||||
def _snmp_get(self, oid) -> Optional[int]:
|
||||
"""执行SNMP GET请求"""
|
||||
try:
|
||||
cmd = getCmd(
|
||||
SnmpEngine(),
|
||||
CommunityData(self.community),
|
||||
UdpTransportTarget((self.switch_ip, 161)),
|
||||
ContextData(),
|
||||
ObjectType(ObjectIdentity(oid)))
|
||||
|
||||
errorIndication, errorStatus, errorIndex, varBinds = next(cmd)
|
||||
except Exception as e:
|
||||
logger.error(f"SNMP请求失败: {str(e)}")
|
||||
return None
|
||||
|
||||
if errorIndication:
|
||||
logger.error(f"SNMP错误: {errorIndication}")
|
||||
return None
|
||||
elif errorStatus:
|
||||
try:
|
||||
if errorIndex:
|
||||
index_val = int(errorIndex) - 1
|
||||
error_item = varBinds[index_val] if index_val < len(varBinds) else '?'
|
||||
else:
|
||||
error_item = '?'
|
||||
|
||||
error_msg = f"SNMP错误: {errorStatus.prettyPrint()} at {error_item}"
|
||||
logger.error(error_msg)
|
||||
except Exception as e:
|
||||
logger.error(f"解析SNMP错误失败: {str(e)}")
|
||||
return None
|
||||
else:
|
||||
for varBind in varBinds:
|
||||
try:
|
||||
return int(varBind[1])
|
||||
except Exception as e:
|
||||
logger.error(f"转换SNMP值失败: {str(e)}")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def _save_to_db(self, interface: str, in_octets: int, out_octets: int,
|
||||
in_rate: float, out_rate: float, timestamp: datetime):
|
||||
"""保存流量数据到数据库"""
|
||||
try:
|
||||
with SessionLocal() as session:
|
||||
record = SwitchTrafficRecord(
|
||||
switch_ip=self.switch_ip,
|
||||
interface=interface,
|
||||
bytes_in=in_octets,
|
||||
bytes_out=out_octets,
|
||||
rate_in=in_rate,
|
||||
rate_out=out_rate,
|
||||
timestamp=timestamp
|
||||
)
|
||||
session.add(record)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"保存流量数据到数据库失败: {str(e)}")
|
||||
|
||||
def get_traffic_history(self) -> Dict[str, List]:
|
||||
"""获取流量历史数据"""
|
||||
return {
|
||||
"in": list(self.history["in"]),
|
||||
"out": list(self.history["out"]),
|
||||
"time": list(self.history["time"])
|
||||
}
|
||||
|
||||
switch_monitors = {}
|
||||
def get_switch_monitor(switch_ip: str, community: str = 'public', interfaces: Optional[List[str]] = None):
|
||||
"""获取或创建交换机监控器"""
|
||||
if switch_ip not in switch_monitors:
|
||||
switch_monitors[switch_ip] = SwitchTrafficMonitor(
|
||||
switch_ip,
|
||||
community,
|
||||
interfaces=interfaces
|
||||
)
|
||||
return switch_monitors[switch_ip]
|
9
src/backend/app/services/test.py
Normal file
9
src/backend/app/services/test.py
Normal file
@ -0,0 +1,9 @@
|
||||
# test_linprog.py
|
||||
import numpy as np
|
||||
from scipy.optimize import linprog
|
||||
|
||||
c = np.array([-1, -2])
|
||||
A_ub = np.array([[1, 1]])
|
||||
b_ub = np.array([3])
|
||||
res = linprog(c, A_ub=A_ub, b_ub=b_ub, method='highs')
|
||||
print(res)
|
@ -1,195 +1,145 @@
|
||||
import re
|
||||
import psutil
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from datetime import datetime
|
||||
from collections import deque
|
||||
from typing import Dict, Optional, List
|
||||
|
||||
from src.backend.app.utils.logger import logger
|
||||
|
||||
from ..models.traffic_models import TrafficRecord
|
||||
from src.backend.app.api.database import SessionLocal
|
||||
from ..utils.logger import logger
|
||||
|
||||
|
||||
class TrafficMonitor:
|
||||
"""
|
||||
交换机流量监控服务
|
||||
通过 Telnet CLI 查询接口流量统计,计算实时速率
|
||||
"""
|
||||
|
||||
# 各厂商查询接口流量的CLI命令模板
|
||||
VENDOR_COMMANDS = {
|
||||
"huawei": "display interface {interface}",
|
||||
"cisco": "show interface {interface}",
|
||||
"h3c": "display interface {interface}",
|
||||
"ruijie": "show interface {interface}",
|
||||
"zte": "show interface {interface}",
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
# 流量计数器缓存: key = "switch_ip:interface", value = {"timestamp": float, "in_bytes": int, "out_bytes": int}
|
||||
self.traffic_cache: Dict[str, Dict] = {}
|
||||
|
||||
# 缓存TTL(秒), 30分钟未访问则自动清理
|
||||
self.cache_ttl = 1800
|
||||
|
||||
def get_query_command(self, vendor: str, interface: str) -> Optional[str]:
|
||||
"""根据厂商和接口名生成查询命令"""
|
||||
vendor_lower = vendor.lower()
|
||||
if vendor_lower not in self.VENDOR_COMMANDS:
|
||||
logger.warning(f"不支持的厂商: {vendor}")
|
||||
return None
|
||||
|
||||
return self.VENDOR_COMMANDS[vendor_lower].format(interface=interface)
|
||||
|
||||
def parse_interface_stats(self, vendor: str, output: str) -> Optional[Tuple[int, int, str]]:
|
||||
"""
|
||||
解析CLI输出,提取入/出方向字节数和接口状态
|
||||
|
||||
返回: (in_bytes, out_bytes, status) 或 None
|
||||
"""
|
||||
vendor_lower = vendor.lower()
|
||||
|
||||
try:
|
||||
# 提取接口状态
|
||||
status = "unknown"
|
||||
if re.search(r'(current state|line protocol).*?(UP|up)', output, re.IGNORECASE):
|
||||
status = "up"
|
||||
elif re.search(r'(current state|line protocol).*?(DOWN|down)', output, re.IGNORECASE):
|
||||
status = "down"
|
||||
|
||||
# 华为/H3C格式: "Input: 12345 packets, 1048576000 bytes"
|
||||
if vendor_lower in ["huawei", "h3c"]:
|
||||
match_in = re.search(r'Input:.*?(\d+)\s+bytes', output, re.IGNORECASE)
|
||||
match_out = re.search(r'Output:.*?(\d+)\s+bytes', output, re.IGNORECASE)
|
||||
|
||||
if match_in and match_out:
|
||||
in_bytes = int(match_in.group(1))
|
||||
out_bytes = int(match_out.group(1))
|
||||
return (in_bytes, out_bytes, status)
|
||||
|
||||
# Cisco/锐捷/中兴格式: "12345 packets input, 1048576000 bytes"
|
||||
elif vendor_lower in ["cisco", "ruijie", "zte"]:
|
||||
match_in = re.search(r'(\d+)\s+packets input,\s+(\d+)\s+bytes', output, re.IGNORECASE)
|
||||
match_out = re.search(r'(\d+)\s+packets output,\s+(\d+)\s+bytes', output, re.IGNORECASE)
|
||||
|
||||
if match_in and match_out:
|
||||
in_bytes = int(match_in.group(2))
|
||||
out_bytes = int(match_out.group(2))
|
||||
return (in_bytes, out_bytes, status)
|
||||
|
||||
logger.warning(f"无法解析 {vendor} 厂商的输出")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析接口统计失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def calculate_speed(
|
||||
self,
|
||||
switch_ip: str,
|
||||
interface: str,
|
||||
current_in: int,
|
||||
current_out: int,
|
||||
current_time: float
|
||||
) -> Tuple[float, float]:
|
||||
"""
|
||||
计算接口速率(Kbps)
|
||||
|
||||
返回: (in_speed_kbps, out_speed_kbps)
|
||||
"""
|
||||
cache_key = f"{switch_ip}:{interface}"
|
||||
|
||||
# 检查是否有历史数据
|
||||
if cache_key not in self.traffic_cache:
|
||||
# 首次查询,保存数据但返回0速率
|
||||
self.traffic_cache[cache_key] = {
|
||||
"timestamp": current_time,
|
||||
"in_bytes": current_in,
|
||||
"out_bytes": current_out
|
||||
}
|
||||
logger.info(f"首次查询 {cache_key}, 速率返回 0")
|
||||
return (0.0, 0.0)
|
||||
|
||||
# 获取历史数据
|
||||
cached = self.traffic_cache[cache_key]
|
||||
time_diff = current_time - cached["timestamp"]
|
||||
|
||||
# 时间间隔太短,避免除零
|
||||
if time_diff < 0.1:
|
||||
logger.warning(f"{cache_key} 查询间隔过短 ({time_diff}s), 返回上次速率")
|
||||
return (0.0, 0.0)
|
||||
|
||||
# 计算字节差(处理计数器溢出)
|
||||
in_diff = self._calculate_diff(current_in, cached["in_bytes"])
|
||||
out_diff = self._calculate_diff(current_out, cached["out_bytes"])
|
||||
|
||||
# 计算速率: (字节差 * 8 bits/byte) / 时间差(秒) / 1000 = Kbps
|
||||
in_speed_kbps = (in_diff * 8) / time_diff / 1000
|
||||
out_speed_kbps = (out_diff * 8) / time_diff / 1000
|
||||
|
||||
# 更新缓存
|
||||
self.traffic_cache[cache_key] = {
|
||||
"timestamp": current_time,
|
||||
"in_bytes": current_in,
|
||||
"out_bytes": current_out
|
||||
def __init__(self, history_size: int = 300):
|
||||
self.history_size = history_size
|
||||
self.history = {
|
||||
"sent": deque(maxlen=history_size),
|
||||
"recv": deque(maxlen=history_size),
|
||||
"time": deque(maxlen=history_size),
|
||||
"interfaces": {}
|
||||
}
|
||||
self.running = False
|
||||
self.task = None
|
||||
self.update_interval = 1.0 # 秒
|
||||
|
||||
logger.debug(f"{cache_key} 速率: IN={in_speed_kbps:.2f} Kbps, OUT={out_speed_kbps:.2f} Kbps")
|
||||
return (in_speed_kbps, out_speed_kbps)
|
||||
@staticmethod
|
||||
def get_interfaces() -> List[str]:
|
||||
"""获取所有网络接口名称"""
|
||||
return list(psutil.net_io_counters(pernic=True).keys())
|
||||
|
||||
def _calculate_diff(self, current: int, previous: int) -> int:
|
||||
"""
|
||||
计算字节差,处理32位计数器溢出
|
||||
def start_monitoring(self):
|
||||
"""启动流量监控"""
|
||||
if not self.running:
|
||||
self.running = True
|
||||
self.task = asyncio.create_task(self._monitor_loop())
|
||||
logger.info("流量监控已启动")
|
||||
|
||||
Reason: 交换机的流量计数器通常是32位,超过4GB会回绕到0
|
||||
"""
|
||||
if current >= previous:
|
||||
return current - previous
|
||||
else:
|
||||
# 计数器溢出,假设32位
|
||||
return (2**32 - previous) + current
|
||||
async def stop_monitoring(self):
|
||||
"""停止流量监控"""
|
||||
if self.running:
|
||||
self.running = False
|
||||
self.task.cancel()
|
||||
try:
|
||||
await self.task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("流量监控已停止")
|
||||
|
||||
def cleanup_expired_cache(self):
|
||||
"""清理过期的缓存数据"""
|
||||
current_time = time.time()
|
||||
expired_keys = []
|
||||
async def _monitor_loop(self):
|
||||
"""监控主循环"""
|
||||
last_stats = psutil.net_io_counters(pernic=True)
|
||||
last_time = time.time()
|
||||
|
||||
for key, data in self.traffic_cache.items():
|
||||
if current_time - data["timestamp"] > self.cache_ttl:
|
||||
expired_keys.append(key)
|
||||
while self.running:
|
||||
await asyncio.sleep(self.update_interval)
|
||||
|
||||
for key in expired_keys:
|
||||
del self.traffic_cache[key]
|
||||
logger.info(f"清理过期缓存: {key}")
|
||||
current_time = time.time()
|
||||
current_stats = psutil.net_io_counters(pernic=True)
|
||||
elapsed = current_time - last_time
|
||||
|
||||
return len(expired_keys)
|
||||
for iface in current_stats:
|
||||
if iface not in self.history["interfaces"]:
|
||||
|
||||
def clear_cache(self, switch_ip: Optional[str] = None):
|
||||
"""
|
||||
清除缓存
|
||||
self.history["interfaces"][iface] = {
|
||||
"sent": deque(maxlen=self.history_size),
|
||||
"recv": deque(maxlen=self.history_size)
|
||||
}
|
||||
|
||||
Args:
|
||||
switch_ip: 如果指定,只清除该交换机的缓存;否则清除所有
|
||||
"""
|
||||
if switch_ip:
|
||||
# 清除指定交换机的缓存
|
||||
keys_to_remove = [k for k in self.traffic_cache.keys() if k.startswith(f"{switch_ip}:")]
|
||||
for key in keys_to_remove:
|
||||
del self.traffic_cache[key]
|
||||
logger.info(f"清除交换机 {switch_ip} 的缓存,共 {len(keys_to_remove)} 条")
|
||||
return len(keys_to_remove)
|
||||
else:
|
||||
# 清除所有缓存
|
||||
count = len(self.traffic_cache)
|
||||
self.traffic_cache.clear()
|
||||
logger.info(f"清除所有缓存,共 {count} 条")
|
||||
return count
|
||||
if iface in last_stats:
|
||||
sent_rate = (current_stats[iface].bytes_sent - last_stats[iface].bytes_sent) / elapsed
|
||||
recv_rate = (current_stats[iface].bytes_recv - last_stats[iface].bytes_recv) / elapsed
|
||||
|
||||
def get_cache_stats(self) -> Dict:
|
||||
"""获取缓存统计信息"""
|
||||
|
||||
self.history["sent"].append(sent_rate)
|
||||
self.history["recv"].append(recv_rate)
|
||||
self.history["time"].append(datetime.now())
|
||||
|
||||
|
||||
self.history["interfaces"][iface]["sent"].append(sent_rate)
|
||||
self.history["interfaces"][iface]["recv"].append(recv_rate)
|
||||
|
||||
|
||||
self._save_to_db(current_stats)
|
||||
|
||||
last_stats = current_stats
|
||||
last_time = current_time
|
||||
|
||||
@staticmethod
|
||||
def _save_to_db(stats):
|
||||
"""保存流量数据到数据库"""
|
||||
with SessionLocal() as session:
|
||||
for iface, counters in stats.items():
|
||||
record = TrafficRecord(
|
||||
interface=iface,
|
||||
bytes_sent=counters.bytes_sent,
|
||||
bytes_recv=counters.bytes_recv,
|
||||
packets_sent=counters.packets_sent,
|
||||
packets_recv=counters.packets_recv,
|
||||
timestamp=datetime.now()
|
||||
)
|
||||
session.add(record)
|
||||
session.commit()
|
||||
|
||||
def get_current_traffic(self, interface: Optional[str] = None) -> Dict:
|
||||
"""获取当前流量数据"""
|
||||
stats = psutil.net_io_counters(pernic=True)
|
||||
|
||||
if interface:
|
||||
if interface in stats:
|
||||
return self._format_interface_stats(stats[interface])
|
||||
return {}
|
||||
|
||||
return {iface: self._format_interface_stats(data) for iface, data in stats.items()}
|
||||
|
||||
@staticmethod
|
||||
def _format_interface_stats(counters) -> Dict:
|
||||
"""格式化接口统计数据"""
|
||||
return {
|
||||
"total_entries": len(self.traffic_cache),
|
||||
"cache_ttl_seconds": self.cache_ttl,
|
||||
"entries": list(self.traffic_cache.keys())
|
||||
"bytes_sent": counters.bytes_sent,
|
||||
"bytes_recv": counters.bytes_recv,
|
||||
"packets_sent": counters.packets_sent,
|
||||
"packets_recv": counters.packets_recv,
|
||||
"errin": counters.errin,
|
||||
"errout": counters.errout,
|
||||
"dropin": counters.dropin,
|
||||
"dropout": counters.dropout
|
||||
}
|
||||
|
||||
def get_traffic_history(self, interface: Optional[str] = None) -> Dict:
|
||||
"""获取流量历史数据"""
|
||||
if interface and interface in self.history["interfaces"]:
|
||||
return {
|
||||
"sent": list(self.history["interfaces"][interface]["sent"]),
|
||||
"recv": list(self.history["interfaces"][interface]["recv"]),
|
||||
"time": list(self.history["time"])
|
||||
}
|
||||
|
||||
return {
|
||||
"sent": list(self.history["sent"]),
|
||||
"recv": list(self.history["recv"]),
|
||||
"time": list(self.history["time"])
|
||||
}
|
||||
|
||||
# 全局单例
|
||||
traffic_monitor = TrafficMonitor()
|
||||
|
Loading…
x
Reference in New Issue
Block a user