mirror of
https://github.com/Jerryplusy/AI-powered-switches.git
synced 2025-10-14 01:39:18 +00:00
删除待重构文件
This commit is contained in:
parent
d11decae6a
commit
9d540c77b7
3
.idea/AI-powered-switches.iml
generated
3
.idea/AI-powered-switches.iml
generated
@ -2,7 +2,7 @@
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="FacetManager">
|
||||
<facet type="Python" name="Python facet">
|
||||
<configuration sdkName="Python 3.13" />
|
||||
<configuration sdkName="Python 3.12 (AI-powered-switches)" />
|
||||
</facet>
|
||||
</component>
|
||||
<component name="NewModuleRootManager">
|
||||
@ -11,5 +11,6 @@
|
||||
</content>
|
||||
<orderEntry type="jdk" jdkName="Python 3.13" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
<orderEntry type="library" name="Python 3.12 (AI-powered-switches) interpreter library" level="application" />
|
||||
</component>
|
||||
</module>
|
@ -4,13 +4,8 @@ from starlette.middleware import Middleware # 新增导入
|
||||
from src.backend.app.api.endpoints import router
|
||||
from src.backend.app.utils.logger import setup_logging
|
||||
from src.backend.config import settings
|
||||
from .services.switch_traffic_monitor import get_switch_monitor
|
||||
from .services.traffic_monitor import traffic_monitor
|
||||
from src.backend.app.api.database import init_db
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
init_db()
|
||||
traffic_monitor.start_monitoring()
|
||||
setup_logging()
|
||||
|
||||
app = FastAPI(
|
||||
|
@ -1,16 +0,0 @@
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
SQLALCHEMY_DATABASE_URL = "sqlite:///./traffic_monitor.db"
|
||||
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
|
||||
)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
def init_db():
|
||||
"""初始化数据库"""
|
||||
Base.metadata.create_all(bind=engine)
|
@ -1,26 +1,16 @@
|
||||
import socket
|
||||
from datetime import datetime, timedelta
|
||||
from fastapi import (APIRouter, HTTPException, Response, WebSocket, WebSocketDisconnect)
|
||||
from fastapi import (APIRouter, HTTPException, Response)
|
||||
from typing import List
|
||||
from pydantic import BaseModel
|
||||
import asyncio
|
||||
from fastapi.responses import HTMLResponse, JSONResponse
|
||||
import matplotlib.pyplot as plt
|
||||
import io
|
||||
import base64
|
||||
import psutil
|
||||
import ipaddress
|
||||
|
||||
from ..models.requests import CLICommandRequest, ConfigRequest
|
||||
from ..services.switch_traffic_monitor import get_switch_monitor
|
||||
from ..utils import logger
|
||||
from ...app.services.ai_services import AIService
|
||||
from ...app.api.network_config import SwitchConfigurator
|
||||
from ...config import settings
|
||||
from ..services.network_scanner import NetworkScanner
|
||||
from ...app.services.traffic_monitor import traffic_monitor
|
||||
from ...app.models.traffic_models import TrafficRecord, SwitchTrafficRecord
|
||||
from src.backend.app.api.database import SessionLocal
|
||||
|
||||
router = APIRouter(prefix="", tags=["API"])
|
||||
scanner = NetworkScanner()
|
||||
@ -151,234 +141,6 @@ async def execute_cli_commands(request: CLICommandRequest):
|
||||
except Exception as e:
|
||||
raise HTTPException(500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/traffic/interfaces", summary="获取所有网络接口")
|
||||
async def get_network_interfaces():
|
||||
return {
|
||||
"interfaces": await asyncio.to_thread(traffic_monitor.get_interfaces)
|
||||
}
|
||||
|
||||
|
||||
@router.get("/traffic/current", summary="获取当前流量数据")
|
||||
async def get_current_traffic(interface: str = None):
|
||||
return await asyncio.to_thread(traffic_monitor.get_current_traffic, interface)
|
||||
|
||||
|
||||
@router.get("/traffic/history", summary="获取流量历史数据")
|
||||
async def get_traffic_history(interface: str = None, limit: int = 100):
|
||||
history = await asyncio.to_thread(traffic_monitor.get_traffic_history, interface)
|
||||
return {
|
||||
"sent": history["sent"][-limit:],
|
||||
"recv": history["recv"][-limit:],
|
||||
"time": [t.isoformat() for t in history["time"]][-limit:]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/traffic/records", summary="获取流量记录")
|
||||
async def get_traffic_records(interface: str = None, limit: int = 100):
|
||||
def sync_get_records():
|
||||
with SessionLocal() as session:
|
||||
query = session.query(TrafficRecord)
|
||||
if interface:
|
||||
query = query.filter(TrafficRecord.interface == interface)
|
||||
records = query.order_by(TrafficRecord.timestamp.desc()).limit(limit).all()
|
||||
return [record.to_dict() for record in records]
|
||||
|
||||
return await asyncio.to_thread(sync_get_records)
|
||||
|
||||
|
||||
@router.websocket("/ws/traffic")
|
||||
async def websocket_traffic(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
try:
|
||||
while True:
|
||||
traffic_data = await asyncio.to_thread(traffic_monitor.get_current_traffic)
|
||||
await websocket.send_json(traffic_data)
|
||||
await asyncio.sleep(1)
|
||||
except WebSocketDisconnect:
|
||||
print("客户端断开连接")
|
||||
|
||||
|
||||
@router.get("/traffic/switch/interfaces", summary="获取交换机的网络接口")
|
||||
async def get_switch_interfaces(switch_ip: str):
|
||||
try:
|
||||
monitor = get_switch_monitor(switch_ip)
|
||||
interfaces = list(monitor.interface_oids.keys())
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"interfaces": interfaces
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取交换机接口失败: {str(e)}")
|
||||
raise HTTPException(500, f"获取接口失败: {str(e)}")
|
||||
|
||||
|
||||
async def get_interface_current_traffic(switch_ip: str, interface: str) -> dict:
|
||||
"""获取指定交换机接口的当前流量数据"""
|
||||
try:
|
||||
def sync_get_record():
|
||||
with SessionLocal() as session:
|
||||
record = session.query(SwitchTrafficRecord).filter(
|
||||
SwitchTrafficRecord.switch_ip == switch_ip,
|
||||
SwitchTrafficRecord.interface == interface
|
||||
).order_by(SwitchTrafficRecord.timestamp.desc()).first()
|
||||
|
||||
if not record:
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"interface": interface,
|
||||
"rate_in": 0.0,
|
||||
"rate_out": 0.0,
|
||||
"bytes_in": 0,
|
||||
"bytes_out": 0
|
||||
}
|
||||
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"interface": interface,
|
||||
"rate_in": record.rate_in,
|
||||
"rate_out": record.rate_out,
|
||||
"bytes_in": record.bytes_in,
|
||||
"bytes_out": record.bytes_out
|
||||
}
|
||||
|
||||
return await asyncio.to_thread(sync_get_record)
|
||||
except Exception as e:
|
||||
logger.error(f"获取接口流量失败: {str(e)}")
|
||||
raise HTTPException(500, f"获取接口流量失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/traffic/switch/current", summary="获取交换机的当前流量数据")
|
||||
async def get_switch_current_traffic(switch_ip: str, interface: str = None):
|
||||
try:
|
||||
monitor = get_switch_monitor(switch_ip)
|
||||
if not interface:
|
||||
traffic_data = {}
|
||||
for iface in monitor.interface_oids:
|
||||
traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface)
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"traffic": traffic_data
|
||||
}
|
||||
return await get_interface_current_traffic(switch_ip, interface)
|
||||
except Exception as e:
|
||||
logger.error(f"获取交换机流量失败: {str(e)}")
|
||||
raise HTTPException(500, f"获取流量失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/traffic/switch/history", summary="获取交换机的流量历史数据")
|
||||
async def get_switch_traffic_history(switch_ip: str, interface: str = None, minutes: int = 10):
|
||||
try:
|
||||
monitor = get_switch_monitor(switch_ip)
|
||||
if not interface:
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"history": await asyncio.to_thread(monitor.get_traffic_history)
|
||||
}
|
||||
|
||||
def sync_get_history():
|
||||
with SessionLocal() as session:
|
||||
time_threshold = datetime.now() - timedelta(minutes=minutes)
|
||||
records = session.query(SwitchTrafficRecord).filter(
|
||||
SwitchTrafficRecord.switch_ip == switch_ip,
|
||||
SwitchTrafficRecord.interface == interface,
|
||||
SwitchTrafficRecord.timestamp >= time_threshold
|
||||
).order_by(SwitchTrafficRecord.timestamp.asc()).all()
|
||||
|
||||
history_data = {
|
||||
"in": [record.rate_in for record in records],
|
||||
"out": [record.rate_out for record in records],
|
||||
"time": [record.timestamp.isoformat() for record in records]
|
||||
}
|
||||
return history_data
|
||||
|
||||
history_data = await asyncio.to_thread(sync_get_history)
|
||||
return {
|
||||
"switch_ip": switch_ip,
|
||||
"interface": interface,
|
||||
"history": history_data
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取历史流量失败: {str(e)}")
|
||||
raise HTTPException(500, f"获取历史流量失败: {str(e)}")
|
||||
|
||||
|
||||
@router.websocket("/ws/traffic/switch")
|
||||
async def websocket_switch_traffic(websocket: WebSocket, switch_ip: str, interface: str = None):
|
||||
await websocket.accept()
|
||||
try:
|
||||
monitor = get_switch_monitor(switch_ip)
|
||||
while True:
|
||||
if interface:
|
||||
traffic_data = await get_interface_current_traffic(switch_ip, interface)
|
||||
await websocket.send_json(traffic_data)
|
||||
else:
|
||||
traffic_data = {}
|
||||
for iface in monitor.interface_oids:
|
||||
traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface)
|
||||
await websocket.send_json({
|
||||
"switch_ip": switch_ip,
|
||||
"traffic": traffic_data
|
||||
})
|
||||
await asyncio.sleep(1)
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"客户端断开交换机流量连接: {switch_ip}")
|
||||
except Exception as e:
|
||||
logger.error(f"交换机流量WebSocket错误: {str(e)}")
|
||||
await websocket.close(code=1011, reason=str(e))
|
||||
|
||||
|
||||
@router.get("/traffic/switch/plot", response_class=HTMLResponse, summary="交换机流量可视化")
|
||||
async def plot_switch_traffic(switch_ip: str, interface: str, minutes: int = 10):
|
||||
try:
|
||||
history = await get_switch_traffic_history(switch_ip, interface, minutes)
|
||||
history_data = history["history"]
|
||||
time_points = [datetime.fromisoformat(t) for t in history_data["time"]]
|
||||
in_rates = history_data["in"]
|
||||
out_rates = history_data["out"]
|
||||
|
||||
def generate_plot():
|
||||
plt.figure(figsize=(12, 6))
|
||||
plt.plot(time_points, in_rates, label="流入流量 (B/s)")
|
||||
plt.plot(time_points, out_rates, label="流出流量 (B/s)")
|
||||
plt.title(f"交换机 {switch_ip} 接口 {interface} 流量监控 - 最近 {minutes} 分钟")
|
||||
plt.xlabel("时间")
|
||||
plt.ylabel("流量 (字节/秒)")
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
plt.xticks(rotation=45)
|
||||
plt.tight_layout()
|
||||
buf = io.BytesIO()
|
||||
plt.savefig(buf, format="png")
|
||||
buf.seek(0)
|
||||
image_base64 = base64.b64encode(buf.read()).decode("utf-8")
|
||||
plt.close()
|
||||
return image_base64
|
||||
|
||||
image_base64 = await asyncio.to_thread(generate_plot)
|
||||
return f"""
|
||||
<html>
|
||||
<head>
|
||||
<title>交换机流量监控</title>
|
||||
<style>
|
||||
body {{ font-family: Arial, sans-serif; margin: 20px; }}
|
||||
.container {{ max-width: 900px; margin: 0 auto; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>交换机 {switch_ip} 接口 {interface} 流量监控</h1>
|
||||
<img src="data:image/png;base64,{image_base64}" alt="流量图表">
|
||||
<p>更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
except Exception as e:
|
||||
logger.error(f"生成流量图表失败: {str(e)}")
|
||||
return HTMLResponse(content=f"<h1>错误</h1><p>{str(e)}</p>", status_code=500)
|
||||
|
||||
|
||||
@router.get("/network_adapters", summary="获取网络适配器网段")
|
||||
async def get_network_adapters():
|
||||
try:
|
||||
|
@ -1,51 +0,0 @@
|
||||
from src.backend.app.api.database import Base # 修复:导入 Base
|
||||
from sqlalchemy import Column, Integer, String, DateTime, BigInteger, Float
|
||||
|
||||
|
||||
class TrafficRecord(Base):
|
||||
"""网络流量记录模型"""
|
||||
__tablename__ = "traffic_records"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
interface = Column(String(50), index=True)
|
||||
bytes_sent = Column(Integer)
|
||||
bytes_recv = Column(Integer)
|
||||
packets_sent = Column(Integer)
|
||||
packets_recv = Column(Integer)
|
||||
timestamp = Column(DateTime)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"interface": self.interface,
|
||||
"bytes_sent": self.bytes_sent,
|
||||
"bytes_recv": self.bytes_recv,
|
||||
"packets_sent": self.packets_sent,
|
||||
"packets_recv": self.packets_recv,
|
||||
"timestamp": self.timestamp.isoformat()
|
||||
}
|
||||
|
||||
|
||||
class SwitchTrafficRecord(Base):
|
||||
__tablename__ = "switch_traffic_records"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
switch_ip = Column(String(50), index=True)
|
||||
interface = Column(String(50))
|
||||
bytes_in = Column(BigInteger)
|
||||
bytes_out = Column(BigInteger)
|
||||
rate_in = Column(Float)
|
||||
rate_out = Column(Float)
|
||||
timestamp = Column(DateTime)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"switch_ip": self.switch_ip,
|
||||
"interface": self.interface,
|
||||
"bytes_in": self.bytes_in,
|
||||
"bytes_out": self.bytes_out,
|
||||
"rate_in": self.rate_in,
|
||||
"rate_out": self.rate_out,
|
||||
"timestamp": self.timestamp.isoformat()
|
||||
}
|
@ -1,182 +0,0 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from collections import deque
|
||||
from typing import Optional, List, Dict
|
||||
from pysnmp.hlapi import *
|
||||
from ..models.traffic_models import SwitchTrafficRecord
|
||||
from src.backend.app.api.database import SessionLocal
|
||||
from ..utils.logger import logger
|
||||
|
||||
#V=ΔQ'-ΔQ/Δt (B/s)
|
||||
class SwitchTrafficMonitor:
|
||||
def __init__(
|
||||
self,
|
||||
switch_ip: str,
|
||||
community: str = 'public',
|
||||
update_interval: int = 5,
|
||||
interfaces: Optional[List[str]] = None
|
||||
):
|
||||
self.switch_ip = switch_ip
|
||||
self.community = community
|
||||
self.update_interval = update_interval
|
||||
self.running = False
|
||||
self.task = None
|
||||
self.interface_history = {}
|
||||
self.history = {
|
||||
"in": deque(maxlen=300),
|
||||
"out": deque(maxlen=300),
|
||||
"time": deque(maxlen=300)
|
||||
}
|
||||
|
||||
self.interface_oids = {
|
||||
"GigabitEthernet0/0/1": {
|
||||
"in": '1.3.6.1.2.1.2.2.1.10.1',
|
||||
"out": '1.3.6.1.2.1.2.2.1.16.1'
|
||||
},
|
||||
"GigabitEthernet0/0/24": {
|
||||
"in": '1.3.6.1.2.1.2.2.1.10.24',
|
||||
"out": '1.3.6.1.2.1.2.2.1.16.24'
|
||||
}
|
||||
}
|
||||
|
||||
if interfaces:
|
||||
self.interface_oids = {
|
||||
iface: oid for iface, oid in self.interface_oids.items()
|
||||
if iface in interfaces
|
||||
}
|
||||
logger.info(f"监控指定接口: {', '.join(interfaces)}")
|
||||
else:
|
||||
logger.info("监控所有接口")
|
||||
|
||||
def start_monitoring(self):
|
||||
"""启动交换机流量监控"""
|
||||
if not self.running:
|
||||
self.running = True
|
||||
self.task = asyncio.create_task(self._monitor_loop())
|
||||
logger.success(f"交换机流量监控已启动: {self.switch_ip}")
|
||||
|
||||
async def stop_monitoring(self):
|
||||
"""停止监控"""
|
||||
if self.running:
|
||||
self.running = False
|
||||
if self.task:
|
||||
self.task.cancel()
|
||||
try:
|
||||
await self.task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info(f"交换机流量监控已停止: {self.switch_ip}")
|
||||
|
||||
async def _monitor_loop(self):
|
||||
"""监控主循环"""
|
||||
last_values = {iface: {"in": 0, "out": 0} for iface in self.interface_oids}
|
||||
last_time = datetime.now()
|
||||
|
||||
while self.running:
|
||||
await asyncio.sleep(self.update_interval)
|
||||
|
||||
try:
|
||||
current_time = datetime.now()
|
||||
elapsed = (current_time - last_time).total_seconds()
|
||||
|
||||
for iface, oids in self.interface_oids.items():
|
||||
in_octets = self._snmp_get(oids["in"])
|
||||
out_octets = self._snmp_get(oids["out"])
|
||||
|
||||
if in_octets is not None and out_octets is not None:
|
||||
|
||||
iface_values = last_values[iface]
|
||||
in_rate = (in_octets - iface_values["in"]) / elapsed if iface_values["in"] > 0 else 0
|
||||
out_rate = (out_octets - iface_values["out"]) / elapsed if iface_values["out"] > 0 else 0
|
||||
|
||||
self.history["in"].append(in_rate)
|
||||
self.history["out"].append(out_rate)
|
||||
self.history["time"].append(current_time)
|
||||
|
||||
self._save_to_db(iface, in_octets, out_octets, in_rate, out_rate, current_time)
|
||||
|
||||
iface_values["in"] = in_octets
|
||||
iface_values["out"] = out_octets
|
||||
|
||||
last_time = current_time
|
||||
except Exception as e:
|
||||
logger.error(f"监控交换机流量出错: {str(e)}")
|
||||
|
||||
def _snmp_get(self, oid) -> Optional[int]:
|
||||
"""执行SNMP GET请求"""
|
||||
try:
|
||||
cmd = getCmd(
|
||||
SnmpEngine(),
|
||||
CommunityData(self.community),
|
||||
UdpTransportTarget((self.switch_ip, 161)),
|
||||
ContextData(),
|
||||
ObjectType(ObjectIdentity(oid)))
|
||||
|
||||
errorIndication, errorStatus, errorIndex, varBinds = next(cmd)
|
||||
except Exception as e:
|
||||
logger.error(f"SNMP请求失败: {str(e)}")
|
||||
return None
|
||||
|
||||
if errorIndication:
|
||||
logger.error(f"SNMP错误: {errorIndication}")
|
||||
return None
|
||||
elif errorStatus:
|
||||
try:
|
||||
if errorIndex:
|
||||
index_val = int(errorIndex) - 1
|
||||
error_item = varBinds[index_val] if index_val < len(varBinds) else '?'
|
||||
else:
|
||||
error_item = '?'
|
||||
|
||||
error_msg = f"SNMP错误: {errorStatus.prettyPrint()} at {error_item}"
|
||||
logger.error(error_msg)
|
||||
except Exception as e:
|
||||
logger.error(f"解析SNMP错误失败: {str(e)}")
|
||||
return None
|
||||
else:
|
||||
for varBind in varBinds:
|
||||
try:
|
||||
return int(varBind[1])
|
||||
except Exception as e:
|
||||
logger.error(f"转换SNMP值失败: {str(e)}")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def _save_to_db(self, interface: str, in_octets: int, out_octets: int,
|
||||
in_rate: float, out_rate: float, timestamp: datetime):
|
||||
"""保存流量数据到数据库"""
|
||||
try:
|
||||
with SessionLocal() as session:
|
||||
record = SwitchTrafficRecord(
|
||||
switch_ip=self.switch_ip,
|
||||
interface=interface,
|
||||
bytes_in=in_octets,
|
||||
bytes_out=out_octets,
|
||||
rate_in=in_rate,
|
||||
rate_out=out_rate,
|
||||
timestamp=timestamp
|
||||
)
|
||||
session.add(record)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"保存流量数据到数据库失败: {str(e)}")
|
||||
|
||||
def get_traffic_history(self) -> Dict[str, List]:
|
||||
"""获取流量历史数据"""
|
||||
return {
|
||||
"in": list(self.history["in"]),
|
||||
"out": list(self.history["out"]),
|
||||
"time": list(self.history["time"])
|
||||
}
|
||||
|
||||
switch_monitors = {}
|
||||
def get_switch_monitor(switch_ip: str, community: str = 'public', interfaces: Optional[List[str]] = None):
|
||||
"""获取或创建交换机监控器"""
|
||||
if switch_ip not in switch_monitors:
|
||||
switch_monitors[switch_ip] = SwitchTrafficMonitor(
|
||||
switch_ip,
|
||||
community,
|
||||
interfaces=interfaces
|
||||
)
|
||||
return switch_monitors[switch_ip]
|
@ -1,9 +0,0 @@
|
||||
# test_linprog.py
|
||||
import numpy as np
|
||||
from scipy.optimize import linprog
|
||||
|
||||
c = np.array([-1, -2])
|
||||
A_ub = np.array([[1, 1]])
|
||||
b_ub = np.array([3])
|
||||
res = linprog(c, A_ub=A_ub, b_ub=b_ub, method='highs')
|
||||
print(res)
|
@ -1,145 +0,0 @@
|
||||
import psutil
|
||||
import time
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from collections import deque
|
||||
from typing import Dict, Optional, List
|
||||
|
||||
|
||||
from ..models.traffic_models import TrafficRecord
|
||||
from src.backend.app.api.database import SessionLocal
|
||||
from ..utils.logger import logger
|
||||
|
||||
|
||||
class TrafficMonitor:
|
||||
def __init__(self, history_size: int = 300):
|
||||
self.history_size = history_size
|
||||
self.history = {
|
||||
"sent": deque(maxlen=history_size),
|
||||
"recv": deque(maxlen=history_size),
|
||||
"time": deque(maxlen=history_size),
|
||||
"interfaces": {}
|
||||
}
|
||||
self.running = False
|
||||
self.task = None
|
||||
self.update_interval = 1.0 # 秒
|
||||
|
||||
@staticmethod
|
||||
def get_interfaces() -> List[str]:
|
||||
"""获取所有网络接口名称"""
|
||||
return list(psutil.net_io_counters(pernic=True).keys())
|
||||
|
||||
def start_monitoring(self):
|
||||
"""启动流量监控"""
|
||||
if not self.running:
|
||||
self.running = True
|
||||
self.task = asyncio.create_task(self._monitor_loop())
|
||||
logger.info("流量监控已启动")
|
||||
|
||||
async def stop_monitoring(self):
|
||||
"""停止流量监控"""
|
||||
if self.running:
|
||||
self.running = False
|
||||
self.task.cancel()
|
||||
try:
|
||||
await self.task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("流量监控已停止")
|
||||
|
||||
async def _monitor_loop(self):
|
||||
"""监控主循环"""
|
||||
last_stats = psutil.net_io_counters(pernic=True)
|
||||
last_time = time.time()
|
||||
|
||||
while self.running:
|
||||
await asyncio.sleep(self.update_interval)
|
||||
|
||||
current_time = time.time()
|
||||
current_stats = psutil.net_io_counters(pernic=True)
|
||||
elapsed = current_time - last_time
|
||||
|
||||
for iface in current_stats:
|
||||
if iface not in self.history["interfaces"]:
|
||||
|
||||
self.history["interfaces"][iface] = {
|
||||
"sent": deque(maxlen=self.history_size),
|
||||
"recv": deque(maxlen=self.history_size)
|
||||
}
|
||||
|
||||
if iface in last_stats:
|
||||
sent_rate = (current_stats[iface].bytes_sent - last_stats[iface].bytes_sent) / elapsed
|
||||
recv_rate = (current_stats[iface].bytes_recv - last_stats[iface].bytes_recv) / elapsed
|
||||
|
||||
|
||||
self.history["sent"].append(sent_rate)
|
||||
self.history["recv"].append(recv_rate)
|
||||
self.history["time"].append(datetime.now())
|
||||
|
||||
|
||||
self.history["interfaces"][iface]["sent"].append(sent_rate)
|
||||
self.history["interfaces"][iface]["recv"].append(recv_rate)
|
||||
|
||||
|
||||
self._save_to_db(current_stats)
|
||||
|
||||
last_stats = current_stats
|
||||
last_time = current_time
|
||||
|
||||
@staticmethod
|
||||
def _save_to_db(stats):
|
||||
"""保存流量数据到数据库"""
|
||||
with SessionLocal() as session:
|
||||
for iface, counters in stats.items():
|
||||
record = TrafficRecord(
|
||||
interface=iface,
|
||||
bytes_sent=counters.bytes_sent,
|
||||
bytes_recv=counters.bytes_recv,
|
||||
packets_sent=counters.packets_sent,
|
||||
packets_recv=counters.packets_recv,
|
||||
timestamp=datetime.now()
|
||||
)
|
||||
session.add(record)
|
||||
session.commit()
|
||||
|
||||
def get_current_traffic(self, interface: Optional[str] = None) -> Dict:
|
||||
"""获取当前流量数据"""
|
||||
stats = psutil.net_io_counters(pernic=True)
|
||||
|
||||
if interface:
|
||||
if interface in stats:
|
||||
return self._format_interface_stats(stats[interface])
|
||||
return {}
|
||||
|
||||
return {iface: self._format_interface_stats(data) for iface, data in stats.items()}
|
||||
|
||||
@staticmethod
|
||||
def _format_interface_stats(counters) -> Dict:
|
||||
"""格式化接口统计数据"""
|
||||
return {
|
||||
"bytes_sent": counters.bytes_sent,
|
||||
"bytes_recv": counters.bytes_recv,
|
||||
"packets_sent": counters.packets_sent,
|
||||
"packets_recv": counters.packets_recv,
|
||||
"errin": counters.errin,
|
||||
"errout": counters.errout,
|
||||
"dropin": counters.dropin,
|
||||
"dropout": counters.dropout
|
||||
}
|
||||
|
||||
def get_traffic_history(self, interface: Optional[str] = None) -> Dict:
|
||||
"""获取流量历史数据"""
|
||||
if interface and interface in self.history["interfaces"]:
|
||||
return {
|
||||
"sent": list(self.history["interfaces"][interface]["sent"]),
|
||||
"recv": list(self.history["interfaces"][interface]["recv"]),
|
||||
"time": list(self.history["time"])
|
||||
}
|
||||
|
||||
return {
|
||||
"sent": list(self.history["sent"]),
|
||||
"recv": list(self.history["recv"]),
|
||||
"time": list(self.history["time"])
|
||||
}
|
||||
|
||||
traffic_monitor = TrafficMonitor()
|
Loading…
x
Reference in New Issue
Block a user