Compare commits

...

2 Commits

Author SHA1 Message Date
08a6ac56b8 流量监控 2025-10-08 02:19:19 +08:00
9d540c77b7 删除待重构文件 2025-10-08 01:15:33 +08:00
9 changed files with 314 additions and 624 deletions

View File

@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4">
<component name="FacetManager">
<facet type="Python" name="Python facet">
<configuration sdkName="Python 3.13" />
<configuration sdkName="Python 3.12 (AI-powered-switches)" />
</facet>
</component>
<component name="NewModuleRootManager">
@ -11,5 +11,6 @@
</content>
<orderEntry type="jdk" jdkName="Python 3.13" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
<orderEntry type="library" name="Python 3.12 (AI-powered-switches) interpreter library" level="application" />
</component>
</module>

View File

@ -4,13 +4,8 @@ from starlette.middleware import Middleware # 新增导入
from src.backend.app.api.endpoints import router
from src.backend.app.utils.logger import setup_logging
from src.backend.config import settings
from .services.switch_traffic_monitor import get_switch_monitor
from .services.traffic_monitor import traffic_monitor
from src.backend.app.api.database import init_db
def create_app() -> FastAPI:
init_db()
traffic_monitor.start_monitoring()
setup_logging()
app = FastAPI(

View File

@ -1,16 +0,0 @@
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
SQLALCHEMY_DATABASE_URL = "sqlite:///./traffic_monitor.db"
engine = create_engine(
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
def init_db():
"""初始化数据库"""
Base.metadata.create_all(bind=engine)

View File

@ -1,26 +1,20 @@
import socket
from datetime import datetime, timedelta
from fastapi import (APIRouter, HTTPException, Response, WebSocket, WebSocketDisconnect)
from typing import List
from fastapi import (APIRouter, HTTPException, Response)
from typing import List, Optional
from pydantic import BaseModel
import asyncio
from fastapi.responses import HTMLResponse, JSONResponse
import matplotlib.pyplot as plt
import io
import base64
import psutil
import ipaddress
import time
from datetime import datetime
from ..models.requests import CLICommandRequest, ConfigRequest
from ..services.switch_traffic_monitor import get_switch_monitor
from ..utils import logger
from ..models.requests import CLICommandRequest, ConfigRequest, TrafficQueryRequest
from ...app.services.ai_services import AIService
from ...app.api.network_config import SwitchConfigurator
from ...config import settings
from ..services.network_scanner import NetworkScanner
from ...app.services.traffic_monitor import traffic_monitor
from ...app.models.traffic_models import TrafficRecord, SwitchTrafficRecord
from src.backend.app.api.database import SessionLocal
from ..services.traffic_monitor import traffic_monitor
from ..utils.logger import logger
router = APIRouter(prefix="", tags=["API"])
scanner = NetworkScanner()
@ -38,8 +32,8 @@ async def root():
"/scan_network",
"/list_devices",
"/batch_apply_config",
"/traffic/switch/current",
"/traffic/switch/history"
"/traffic/realtime",
"/traffic/cache/clear"
]
}
@ -151,234 +145,6 @@ async def execute_cli_commands(request: CLICommandRequest):
except Exception as e:
raise HTTPException(500, detail=str(e))
@router.get("/traffic/interfaces", summary="获取所有网络接口")
async def get_network_interfaces():
return {
"interfaces": await asyncio.to_thread(traffic_monitor.get_interfaces)
}
@router.get("/traffic/current", summary="获取当前流量数据")
async def get_current_traffic(interface: str = None):
return await asyncio.to_thread(traffic_monitor.get_current_traffic, interface)
@router.get("/traffic/history", summary="获取流量历史数据")
async def get_traffic_history(interface: str = None, limit: int = 100):
history = await asyncio.to_thread(traffic_monitor.get_traffic_history, interface)
return {
"sent": history["sent"][-limit:],
"recv": history["recv"][-limit:],
"time": [t.isoformat() for t in history["time"]][-limit:]
}
@router.get("/traffic/records", summary="获取流量记录")
async def get_traffic_records(interface: str = None, limit: int = 100):
def sync_get_records():
with SessionLocal() as session:
query = session.query(TrafficRecord)
if interface:
query = query.filter(TrafficRecord.interface == interface)
records = query.order_by(TrafficRecord.timestamp.desc()).limit(limit).all()
return [record.to_dict() for record in records]
return await asyncio.to_thread(sync_get_records)
@router.websocket("/ws/traffic")
async def websocket_traffic(websocket: WebSocket):
await websocket.accept()
try:
while True:
traffic_data = await asyncio.to_thread(traffic_monitor.get_current_traffic)
await websocket.send_json(traffic_data)
await asyncio.sleep(1)
except WebSocketDisconnect:
print("客户端断开连接")
@router.get("/traffic/switch/interfaces", summary="获取交换机的网络接口")
async def get_switch_interfaces(switch_ip: str):
try:
monitor = get_switch_monitor(switch_ip)
interfaces = list(monitor.interface_oids.keys())
return {
"switch_ip": switch_ip,
"interfaces": interfaces
}
except Exception as e:
logger.error(f"获取交换机接口失败: {str(e)}")
raise HTTPException(500, f"获取接口失败: {str(e)}")
async def get_interface_current_traffic(switch_ip: str, interface: str) -> dict:
"""获取指定交换机接口的当前流量数据"""
try:
def sync_get_record():
with SessionLocal() as session:
record = session.query(SwitchTrafficRecord).filter(
SwitchTrafficRecord.switch_ip == switch_ip,
SwitchTrafficRecord.interface == interface
).order_by(SwitchTrafficRecord.timestamp.desc()).first()
if not record:
return {
"switch_ip": switch_ip,
"interface": interface,
"rate_in": 0.0,
"rate_out": 0.0,
"bytes_in": 0,
"bytes_out": 0
}
return {
"switch_ip": switch_ip,
"interface": interface,
"rate_in": record.rate_in,
"rate_out": record.rate_out,
"bytes_in": record.bytes_in,
"bytes_out": record.bytes_out
}
return await asyncio.to_thread(sync_get_record)
except Exception as e:
logger.error(f"获取接口流量失败: {str(e)}")
raise HTTPException(500, f"获取接口流量失败: {str(e)}")
@router.get("/traffic/switch/current", summary="获取交换机的当前流量数据")
async def get_switch_current_traffic(switch_ip: str, interface: str = None):
try:
monitor = get_switch_monitor(switch_ip)
if not interface:
traffic_data = {}
for iface in monitor.interface_oids:
traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface)
return {
"switch_ip": switch_ip,
"traffic": traffic_data
}
return await get_interface_current_traffic(switch_ip, interface)
except Exception as e:
logger.error(f"获取交换机流量失败: {str(e)}")
raise HTTPException(500, f"获取流量失败: {str(e)}")
@router.get("/traffic/switch/history", summary="获取交换机的流量历史数据")
async def get_switch_traffic_history(switch_ip: str, interface: str = None, minutes: int = 10):
try:
monitor = get_switch_monitor(switch_ip)
if not interface:
return {
"switch_ip": switch_ip,
"history": await asyncio.to_thread(monitor.get_traffic_history)
}
def sync_get_history():
with SessionLocal() as session:
time_threshold = datetime.now() - timedelta(minutes=minutes)
records = session.query(SwitchTrafficRecord).filter(
SwitchTrafficRecord.switch_ip == switch_ip,
SwitchTrafficRecord.interface == interface,
SwitchTrafficRecord.timestamp >= time_threshold
).order_by(SwitchTrafficRecord.timestamp.asc()).all()
history_data = {
"in": [record.rate_in for record in records],
"out": [record.rate_out for record in records],
"time": [record.timestamp.isoformat() for record in records]
}
return history_data
history_data = await asyncio.to_thread(sync_get_history)
return {
"switch_ip": switch_ip,
"interface": interface,
"history": history_data
}
except Exception as e:
logger.error(f"获取历史流量失败: {str(e)}")
raise HTTPException(500, f"获取历史流量失败: {str(e)}")
@router.websocket("/ws/traffic/switch")
async def websocket_switch_traffic(websocket: WebSocket, switch_ip: str, interface: str = None):
await websocket.accept()
try:
monitor = get_switch_monitor(switch_ip)
while True:
if interface:
traffic_data = await get_interface_current_traffic(switch_ip, interface)
await websocket.send_json(traffic_data)
else:
traffic_data = {}
for iface in monitor.interface_oids:
traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface)
await websocket.send_json({
"switch_ip": switch_ip,
"traffic": traffic_data
})
await asyncio.sleep(1)
except WebSocketDisconnect:
logger.info(f"客户端断开交换机流量连接: {switch_ip}")
except Exception as e:
logger.error(f"交换机流量WebSocket错误: {str(e)}")
await websocket.close(code=1011, reason=str(e))
@router.get("/traffic/switch/plot", response_class=HTMLResponse, summary="交换机流量可视化")
async def plot_switch_traffic(switch_ip: str, interface: str, minutes: int = 10):
try:
history = await get_switch_traffic_history(switch_ip, interface, minutes)
history_data = history["history"]
time_points = [datetime.fromisoformat(t) for t in history_data["time"]]
in_rates = history_data["in"]
out_rates = history_data["out"]
def generate_plot():
plt.figure(figsize=(12, 6))
plt.plot(time_points, in_rates, label="流入流量 (B/s)")
plt.plot(time_points, out_rates, label="流出流量 (B/s)")
plt.title(f"交换机 {switch_ip} 接口 {interface} 流量监控 - 最近 {minutes} 分钟")
plt.xlabel("时间")
plt.ylabel("流量 (字节/秒)")
plt.legend()
plt.grid(True)
plt.xticks(rotation=45)
plt.tight_layout()
buf = io.BytesIO()
plt.savefig(buf, format="png")
buf.seek(0)
image_base64 = base64.b64encode(buf.read()).decode("utf-8")
plt.close()
return image_base64
image_base64 = await asyncio.to_thread(generate_plot)
return f"""
<html>
<head>
<title>交换机流量监控</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 20px; }}
.container {{ max-width: 900px; margin: 0 auto; }}
</style>
</head>
<body>
<div class="container">
<h1>交换机 {switch_ip} 接口 {interface} 流量监控</h1>
<img src="data:image/png;base64,{image_base64}" alt="流量图表">
<p>更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
</div>
</body>
</html>
"""
except Exception as e:
logger.error(f"生成流量图表失败: {str(e)}")
return HTMLResponse(content=f"<h1>错误</h1><p>{str(e)}</p>", status_code=500)
@router.get("/network_adapters", summary="获取网络适配器网段")
async def get_network_adapters():
try:
@ -403,3 +169,131 @@ async def get_network_adapters():
return {"networks": networks}
except Exception as e:
return {"error": f"获取网络适配器信息失败: {str(e)}"}
@router.post("/traffic/realtime", summary="查询交换机接口实时流量")
async def get_realtime_traffic(request: TrafficQueryRequest):
"""
查询交换机接口实时流量速率(Kbps)
- 支持多个接口同时查询
- 首次查询速率返回 0
- 单个接口查询失败不影响其他接口
"""
# 提取认证信息
username = request.username or settings.SWITCH_USERNAME
password = request.password or settings.SWITCH_PASSWORD
# 创建配置器(复用连接池)
configurator = SwitchConfigurator(
username=username,
password=password,
timeout=settings.SWITCH_TIMEOUT
)
results = []
current_time = time.time()
# 遍历所有接口
for interface in request.interfaces:
interface_data = {
"interface": interface,
"status": "unknown",
"in_speed_kbps": 0.0,
"out_speed_kbps": 0.0,
"in_bytes": 0,
"out_bytes": 0,
"error": None
}
try:
# 获取查询命令
command = traffic_monitor.get_query_command(request.vendor, interface)
if not command:
interface_data["error"] = f"不支持的厂商: {request.vendor}"
results.append(interface_data)
continue
# 执行查询命令
try:
output = await configurator.execute_raw_commands(
ip=request.switch_ip,
commands=[command]
)
# 解析输出
stats = traffic_monitor.parse_interface_stats(request.vendor, str(output))
if stats is None:
interface_data["error"] = "解析接口统计失败"
results.append(interface_data)
continue
in_bytes, out_bytes, status = stats
# 计算速率
in_speed_kbps, out_speed_kbps = traffic_monitor.calculate_speed(
request.switch_ip,
interface,
in_bytes,
out_bytes,
current_time
)
# 更新结果
interface_data.update({
"status": status,
"in_speed_kbps": round(in_speed_kbps, 2),
"out_speed_kbps": round(out_speed_kbps, 2),
"in_bytes": in_bytes,
"out_bytes": out_bytes
})
except Exception as e:
interface_data["error"] = f"查询失败: {str(e)}"
logger.error(f"查询接口 {interface} 失败: {e}")
except Exception as e:
interface_data["error"] = f"未知错误: {str(e)}"
logger.error(f"处理接口 {interface} 时发生异常: {e}", exc_info=True)
results.append(interface_data)
return {
"success": True,
"switch_ip": request.switch_ip,
"vendor": request.vendor,
"timestamp": datetime.utcnow().isoformat() + "Z",
"data": results
}
@router.delete("/traffic/cache/clear", summary="清除流量监控缓存")
async def clear_traffic_cache(switch_ip: Optional[str] = None):
"""
清除流量监控缓存
- 不指定 switch_ip: 清除所有缓存
- 指定 switch_ip: 只清除该交换机的缓存
"""
try:
count = traffic_monitor.clear_cache(switch_ip)
return {
"success": True,
"message": f"已清除 {count} 条缓存记录",
"switch_ip": switch_ip or "all"
}
except Exception as e:
raise HTTPException(500, f"清除缓存失败: {str(e)}")
@router.get("/traffic/cache/stats", summary="获取缓存统计信息")
async def get_cache_stats():
"""获取流量监控缓存统计信息"""
try:
stats = traffic_monitor.get_cache_stats()
return {
"success": True,
"stats": stats
}
except Exception as e:
raise HTTPException(500, f"获取缓存统计失败: {str(e)}")

View File

@ -24,3 +24,11 @@ class CLICommandRequest(BaseModel):
def extract_credentials(self):
return self.username or "NONE", self.password or "NONE"
class TrafficQueryRequest(BaseModel):
"""实时流量查询请求"""
switch_ip: str
vendor: str # huawei/cisco/h3c/ruijie/zte
interfaces: List[str] # 例如: ["GigabitEthernet0/0/1", "GigabitEthernet0/0/2"]
username: Optional[str] = None
password: Optional[str] = None

View File

@ -1,51 +0,0 @@
from src.backend.app.api.database import Base # 修复:导入 Base
from sqlalchemy import Column, Integer, String, DateTime, BigInteger, Float
class TrafficRecord(Base):
"""网络流量记录模型"""
__tablename__ = "traffic_records"
id = Column(Integer, primary_key=True, index=True)
interface = Column(String(50), index=True)
bytes_sent = Column(Integer)
bytes_recv = Column(Integer)
packets_sent = Column(Integer)
packets_recv = Column(Integer)
timestamp = Column(DateTime)
def to_dict(self):
return {
"id": self.id,
"interface": self.interface,
"bytes_sent": self.bytes_sent,
"bytes_recv": self.bytes_recv,
"packets_sent": self.packets_sent,
"packets_recv": self.packets_recv,
"timestamp": self.timestamp.isoformat()
}
class SwitchTrafficRecord(Base):
__tablename__ = "switch_traffic_records"
id = Column(Integer, primary_key=True, index=True)
switch_ip = Column(String(50), index=True)
interface = Column(String(50))
bytes_in = Column(BigInteger)
bytes_out = Column(BigInteger)
rate_in = Column(Float)
rate_out = Column(Float)
timestamp = Column(DateTime)
def to_dict(self):
return {
"id": self.id,
"switch_ip": self.switch_ip,
"interface": self.interface,
"bytes_in": self.bytes_in,
"bytes_out": self.bytes_out,
"rate_in": self.rate_in,
"rate_out": self.rate_out,
"timestamp": self.timestamp.isoformat()
}

View File

@ -1,182 +0,0 @@
import asyncio
from datetime import datetime
from collections import deque
from typing import Optional, List, Dict
from pysnmp.hlapi import *
from ..models.traffic_models import SwitchTrafficRecord
from src.backend.app.api.database import SessionLocal
from ..utils.logger import logger
#V=ΔQ'-ΔQ/Δt (B/s)
class SwitchTrafficMonitor:
def __init__(
self,
switch_ip: str,
community: str = 'public',
update_interval: int = 5,
interfaces: Optional[List[str]] = None
):
self.switch_ip = switch_ip
self.community = community
self.update_interval = update_interval
self.running = False
self.task = None
self.interface_history = {}
self.history = {
"in": deque(maxlen=300),
"out": deque(maxlen=300),
"time": deque(maxlen=300)
}
self.interface_oids = {
"GigabitEthernet0/0/1": {
"in": '1.3.6.1.2.1.2.2.1.10.1',
"out": '1.3.6.1.2.1.2.2.1.16.1'
},
"GigabitEthernet0/0/24": {
"in": '1.3.6.1.2.1.2.2.1.10.24',
"out": '1.3.6.1.2.1.2.2.1.16.24'
}
}
if interfaces:
self.interface_oids = {
iface: oid for iface, oid in self.interface_oids.items()
if iface in interfaces
}
logger.info(f"监控指定接口: {', '.join(interfaces)}")
else:
logger.info("监控所有接口")
def start_monitoring(self):
"""启动交换机流量监控"""
if not self.running:
self.running = True
self.task = asyncio.create_task(self._monitor_loop())
logger.success(f"交换机流量监控已启动: {self.switch_ip}")
async def stop_monitoring(self):
"""停止监控"""
if self.running:
self.running = False
if self.task:
self.task.cancel()
try:
await self.task
except asyncio.CancelledError:
pass
logger.info(f"交换机流量监控已停止: {self.switch_ip}")
async def _monitor_loop(self):
"""监控主循环"""
last_values = {iface: {"in": 0, "out": 0} for iface in self.interface_oids}
last_time = datetime.now()
while self.running:
await asyncio.sleep(self.update_interval)
try:
current_time = datetime.now()
elapsed = (current_time - last_time).total_seconds()
for iface, oids in self.interface_oids.items():
in_octets = self._snmp_get(oids["in"])
out_octets = self._snmp_get(oids["out"])
if in_octets is not None and out_octets is not None:
iface_values = last_values[iface]
in_rate = (in_octets - iface_values["in"]) / elapsed if iface_values["in"] > 0 else 0
out_rate = (out_octets - iface_values["out"]) / elapsed if iface_values["out"] > 0 else 0
self.history["in"].append(in_rate)
self.history["out"].append(out_rate)
self.history["time"].append(current_time)
self._save_to_db(iface, in_octets, out_octets, in_rate, out_rate, current_time)
iface_values["in"] = in_octets
iface_values["out"] = out_octets
last_time = current_time
except Exception as e:
logger.error(f"监控交换机流量出错: {str(e)}")
def _snmp_get(self, oid) -> Optional[int]:
"""执行SNMP GET请求"""
try:
cmd = getCmd(
SnmpEngine(),
CommunityData(self.community),
UdpTransportTarget((self.switch_ip, 161)),
ContextData(),
ObjectType(ObjectIdentity(oid)))
errorIndication, errorStatus, errorIndex, varBinds = next(cmd)
except Exception as e:
logger.error(f"SNMP请求失败: {str(e)}")
return None
if errorIndication:
logger.error(f"SNMP错误: {errorIndication}")
return None
elif errorStatus:
try:
if errorIndex:
index_val = int(errorIndex) - 1
error_item = varBinds[index_val] if index_val < len(varBinds) else '?'
else:
error_item = '?'
error_msg = f"SNMP错误: {errorStatus.prettyPrint()} at {error_item}"
logger.error(error_msg)
except Exception as e:
logger.error(f"解析SNMP错误失败: {str(e)}")
return None
else:
for varBind in varBinds:
try:
return int(varBind[1])
except Exception as e:
logger.error(f"转换SNMP值失败: {str(e)}")
return None
return None
def _save_to_db(self, interface: str, in_octets: int, out_octets: int,
in_rate: float, out_rate: float, timestamp: datetime):
"""保存流量数据到数据库"""
try:
with SessionLocal() as session:
record = SwitchTrafficRecord(
switch_ip=self.switch_ip,
interface=interface,
bytes_in=in_octets,
bytes_out=out_octets,
rate_in=in_rate,
rate_out=out_rate,
timestamp=timestamp
)
session.add(record)
session.commit()
except Exception as e:
logger.error(f"保存流量数据到数据库失败: {str(e)}")
def get_traffic_history(self) -> Dict[str, List]:
"""获取流量历史数据"""
return {
"in": list(self.history["in"]),
"out": list(self.history["out"]),
"time": list(self.history["time"])
}
switch_monitors = {}
def get_switch_monitor(switch_ip: str, community: str = 'public', interfaces: Optional[List[str]] = None):
"""获取或创建交换机监控器"""
if switch_ip not in switch_monitors:
switch_monitors[switch_ip] = SwitchTrafficMonitor(
switch_ip,
community,
interfaces=interfaces
)
return switch_monitors[switch_ip]

View File

@ -1,9 +0,0 @@
# test_linprog.py
import numpy as np
from scipy.optimize import linprog
c = np.array([-1, -2])
A_ub = np.array([[1, 1]])
b_ub = np.array([3])
res = linprog(c, A_ub=A_ub, b_ub=b_ub, method='highs')
print(res)

View File

@ -1,145 +1,195 @@
import psutil
import re
import time
import asyncio
from typing import Dict, List, Optional, Tuple
from datetime import datetime
from collections import deque
from typing import Dict, Optional, List
from ..models.traffic_models import TrafficRecord
from src.backend.app.api.database import SessionLocal
from ..utils.logger import logger
from src.backend.app.utils.logger import logger
class TrafficMonitor:
def __init__(self, history_size: int = 300):
self.history_size = history_size
self.history = {
"sent": deque(maxlen=history_size),
"recv": deque(maxlen=history_size),
"time": deque(maxlen=history_size),
"interfaces": {}
"""
交换机流量监控服务
通过 Telnet CLI 查询接口流量统计,计算实时速率
"""
# 各厂商查询接口流量的CLI命令模板
VENDOR_COMMANDS = {
"huawei": "display interface {interface}",
"cisco": "show interface {interface}",
"h3c": "display interface {interface}",
"ruijie": "show interface {interface}",
"zte": "show interface {interface}",
}
self.running = False
self.task = None
self.update_interval = 1.0 # 秒
@staticmethod
def get_interfaces() -> List[str]:
"""获取所有网络接口名称"""
return list(psutil.net_io_counters(pernic=True).keys())
def __init__(self):
# 流量计数器缓存: key = "switch_ip:interface", value = {"timestamp": float, "in_bytes": int, "out_bytes": int}
self.traffic_cache: Dict[str, Dict] = {}
def start_monitoring(self):
"""启动流量监控"""
if not self.running:
self.running = True
self.task = asyncio.create_task(self._monitor_loop())
logger.info("流量监控已启动")
# 缓存TTL(秒), 30分钟未访问则自动清理
self.cache_ttl = 1800
def get_query_command(self, vendor: str, interface: str) -> Optional[str]:
"""根据厂商和接口名生成查询命令"""
vendor_lower = vendor.lower()
if vendor_lower not in self.VENDOR_COMMANDS:
logger.warning(f"不支持的厂商: {vendor}")
return None
return self.VENDOR_COMMANDS[vendor_lower].format(interface=interface)
def parse_interface_stats(self, vendor: str, output: str) -> Optional[Tuple[int, int, str]]:
"""
解析CLI输出,提取入/出方向字节数和接口状态
返回: (in_bytes, out_bytes, status) None
"""
vendor_lower = vendor.lower()
async def stop_monitoring(self):
"""停止流量监控"""
if self.running:
self.running = False
self.task.cancel()
try:
await self.task
except asyncio.CancelledError:
pass
logger.info("流量监控已停止")
# 提取接口状态
status = "unknown"
if re.search(r'(current state|line protocol).*?(UP|up)', output, re.IGNORECASE):
status = "up"
elif re.search(r'(current state|line protocol).*?(DOWN|down)', output, re.IGNORECASE):
status = "down"
async def _monitor_loop(self):
"""监控主循环"""
last_stats = psutil.net_io_counters(pernic=True)
last_time = time.time()
# 华为/H3C格式: "Input: 12345 packets, 1048576000 bytes"
if vendor_lower in ["huawei", "h3c"]:
match_in = re.search(r'Input:.*?(\d+)\s+bytes', output, re.IGNORECASE)
match_out = re.search(r'Output:.*?(\d+)\s+bytes', output, re.IGNORECASE)
while self.running:
await asyncio.sleep(self.update_interval)
if match_in and match_out:
in_bytes = int(match_in.group(1))
out_bytes = int(match_out.group(1))
return (in_bytes, out_bytes, status)
# Cisco/锐捷/中兴格式: "12345 packets input, 1048576000 bytes"
elif vendor_lower in ["cisco", "ruijie", "zte"]:
match_in = re.search(r'(\d+)\s+packets input,\s+(\d+)\s+bytes', output, re.IGNORECASE)
match_out = re.search(r'(\d+)\s+packets output,\s+(\d+)\s+bytes', output, re.IGNORECASE)
if match_in and match_out:
in_bytes = int(match_in.group(2))
out_bytes = int(match_out.group(2))
return (in_bytes, out_bytes, status)
logger.warning(f"无法解析 {vendor} 厂商的输出")
return None
except Exception as e:
logger.error(f"解析接口统计失败: {e}", exc_info=True)
return None
def calculate_speed(
self,
switch_ip: str,
interface: str,
current_in: int,
current_out: int,
current_time: float
) -> Tuple[float, float]:
"""
计算接口速率(Kbps)
返回: (in_speed_kbps, out_speed_kbps)
"""
cache_key = f"{switch_ip}:{interface}"
# 检查是否有历史数据
if cache_key not in self.traffic_cache:
# 首次查询,保存数据但返回0速率
self.traffic_cache[cache_key] = {
"timestamp": current_time,
"in_bytes": current_in,
"out_bytes": current_out
}
logger.info(f"首次查询 {cache_key}, 速率返回 0")
return (0.0, 0.0)
# 获取历史数据
cached = self.traffic_cache[cache_key]
time_diff = current_time - cached["timestamp"]
# 时间间隔太短,避免除零
if time_diff < 0.1:
logger.warning(f"{cache_key} 查询间隔过短 ({time_diff}s), 返回上次速率")
return (0.0, 0.0)
# 计算字节差(处理计数器溢出)
in_diff = self._calculate_diff(current_in, cached["in_bytes"])
out_diff = self._calculate_diff(current_out, cached["out_bytes"])
# 计算速率: (字节差 * 8 bits/byte) / 时间差(秒) / 1000 = Kbps
in_speed_kbps = (in_diff * 8) / time_diff / 1000
out_speed_kbps = (out_diff * 8) / time_diff / 1000
# 更新缓存
self.traffic_cache[cache_key] = {
"timestamp": current_time,
"in_bytes": current_in,
"out_bytes": current_out
}
logger.debug(f"{cache_key} 速率: IN={in_speed_kbps:.2f} Kbps, OUT={out_speed_kbps:.2f} Kbps")
return (in_speed_kbps, out_speed_kbps)
def _calculate_diff(self, current: int, previous: int) -> int:
"""
计算字节差,处理32位计数器溢出
Reason: 交换机的流量计数器通常是32位,超过4GB会回绕到0
"""
if current >= previous:
return current - previous
else:
# 计数器溢出,假设32位
return (2**32 - previous) + current
def cleanup_expired_cache(self):
"""清理过期的缓存数据"""
current_time = time.time()
current_stats = psutil.net_io_counters(pernic=True)
elapsed = current_time - last_time
expired_keys = []
for iface in current_stats:
if iface not in self.history["interfaces"]:
for key, data in self.traffic_cache.items():
if current_time - data["timestamp"] > self.cache_ttl:
expired_keys.append(key)
self.history["interfaces"][iface] = {
"sent": deque(maxlen=self.history_size),
"recv": deque(maxlen=self.history_size)
}
for key in expired_keys:
del self.traffic_cache[key]
logger.info(f"清理过期缓存: {key}")
if iface in last_stats:
sent_rate = (current_stats[iface].bytes_sent - last_stats[iface].bytes_sent) / elapsed
recv_rate = (current_stats[iface].bytes_recv - last_stats[iface].bytes_recv) / elapsed
return len(expired_keys)
def clear_cache(self, switch_ip: Optional[str] = None):
"""
清除缓存
self.history["sent"].append(sent_rate)
self.history["recv"].append(recv_rate)
self.history["time"].append(datetime.now())
Args:
switch_ip: 如果指定,只清除该交换机的缓存;否则清除所有
"""
if switch_ip:
# 清除指定交换机的缓存
keys_to_remove = [k for k in self.traffic_cache.keys() if k.startswith(f"{switch_ip}:")]
for key in keys_to_remove:
del self.traffic_cache[key]
logger.info(f"清除交换机 {switch_ip} 的缓存,共 {len(keys_to_remove)}")
return len(keys_to_remove)
else:
# 清除所有缓存
count = len(self.traffic_cache)
self.traffic_cache.clear()
logger.info(f"清除所有缓存,共 {count}")
return count
self.history["interfaces"][iface]["sent"].append(sent_rate)
self.history["interfaces"][iface]["recv"].append(recv_rate)
self._save_to_db(current_stats)
last_stats = current_stats
last_time = current_time
@staticmethod
def _save_to_db(stats):
"""保存流量数据到数据库"""
with SessionLocal() as session:
for iface, counters in stats.items():
record = TrafficRecord(
interface=iface,
bytes_sent=counters.bytes_sent,
bytes_recv=counters.bytes_recv,
packets_sent=counters.packets_sent,
packets_recv=counters.packets_recv,
timestamp=datetime.now()
)
session.add(record)
session.commit()
def get_current_traffic(self, interface: Optional[str] = None) -> Dict:
"""获取当前流量数据"""
stats = psutil.net_io_counters(pernic=True)
if interface:
if interface in stats:
return self._format_interface_stats(stats[interface])
return {}
return {iface: self._format_interface_stats(data) for iface, data in stats.items()}
@staticmethod
def _format_interface_stats(counters) -> Dict:
"""格式化接口统计数据"""
def get_cache_stats(self) -> Dict:
"""获取缓存统计信息"""
return {
"bytes_sent": counters.bytes_sent,
"bytes_recv": counters.bytes_recv,
"packets_sent": counters.packets_sent,
"packets_recv": counters.packets_recv,
"errin": counters.errin,
"errout": counters.errout,
"dropin": counters.dropin,
"dropout": counters.dropout
"total_entries": len(self.traffic_cache),
"cache_ttl_seconds": self.cache_ttl,
"entries": list(self.traffic_cache.keys())
}
def get_traffic_history(self, interface: Optional[str] = None) -> Dict:
"""获取流量历史数据"""
if interface and interface in self.history["interfaces"]:
return {
"sent": list(self.history["interfaces"][interface]["sent"]),
"recv": list(self.history["interfaces"][interface]["recv"]),
"time": list(self.history["time"])
}
return {
"sent": list(self.history["sent"]),
"recv": list(self.history["recv"]),
"time": list(self.history["time"])
}
# 全局单例
traffic_monitor = TrafficMonitor()