114514+114514

This commit is contained in:
3 2025-06-18 18:36:52 +08:00
parent 6e5cd34da7
commit 2231b8cf82
7 changed files with 306 additions and 6 deletions

View File

@ -3,7 +3,18 @@ from src.backend.app.api.endpoints import router
from src.backend.app.utils.logger import setup_logging
from src.backend.config import settings
# 添加正确的导入
from .services.traffic_monitor import traffic_monitor
from src.backend.app.api.database import init_db # 修复:导入 init_db
def create_app() -> FastAPI:
# 初始化数据库
init_db()
# 启动流量监控
traffic_monitor.start_monitoring()
# 设置日志
setup_logging()
@ -29,7 +40,7 @@ def create_app() -> FastAPI:
# 添加API路由
app.include_router(router, prefix=settings.API_PREFIX)
return app
app = create_app()

View File

@ -0,0 +1,17 @@
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
SQLALCHEMY_DATABASE_URL = "sqlite:///./traffic_monitor.db"
engine = create_engine(
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
def init_db():
"""初始化数据库"""
# 删除多余的导入
Base.metadata.create_all(bind=engine)

View File

@ -1,13 +1,24 @@
from fastapi import (APIRouter, HTTPException, Response)
from datetime import datetime
from fastapi import (APIRouter, HTTPException, Response,WebSocket, WebSocketDisconnect)
from typing import List
from pydantic import BaseModel
from ...app.services.ai_services import AIService
from ...app.api.network_config import SwitchConfigurator
from ...config import settings
from ..services.network_scanner import NetworkScanner
from ...app.services.traffic_monitor import traffic_monitor
from ...app.models.traffic_models import TrafficRecord
from src.backend.app.api.database import (SessionLocal)
import asyncio
from fastapi.responses import HTMLResponse
import matplotlib.pyplot as plt
import io
import base64
router = APIRouter(prefix="/api", tags=["API"])
router = APIRouter(prefix="", tags=["API"])
scanner = NetworkScanner()
# 添加根路径路由
@ -110,4 +121,84 @@ async def apply_config(request: ConfigRequest):
status_code=500,
detail=f"Failed to apply config: {str(e)}"
)
@router.get("/traffic/interfaces", summary="获取所有网络接口")
async def get_network_interfaces():
return {
"interfaces": traffic_monitor.get_interfaces()
}
@router.get("/traffic/current", summary="获取当前流量数据")
async def get_current_traffic(interface: str = None):
return traffic_monitor.get_current_traffic(interface)
@router.get("/traffic/history", summary="获取流量历史数据")
async def get_traffic_history(interface: str = None, limit: int = 100):
history = traffic_monitor.get_traffic_history(interface)
return {
"sent": history["sent"][-limit:],
"recv": history["recv"][-limit:],
"time": [t.isoformat() for t in history["time"]][-limit:]
}
@router.get("/traffic/records", summary="获取流量记录")
async def get_traffic_records(interface: str = None, limit: int = 100):
with SessionLocal() as session:
query = session.query(TrafficRecord)
if interface:
query = query.filter(TrafficRecord.interface == interface)
records = query.order_by(TrafficRecord.timestamp.desc()).limit(limit).all()
return [record.to_dict() for record in records]
@router.websocket("/ws/traffic")
async def websocket_traffic(websocket: WebSocket):
"""实时流量WebSocket"""
await websocket.accept()
try:
while True:
# 获取所有接口的当前流量
traffic_data = traffic_monitor.get_current_traffic()
await websocket.send_json(traffic_data)
await asyncio.sleep(1) # 每秒更新一次
except WebSocketDisconnect:
print("客户端断开连接")
@router.get("/traffic/plot", response_class=HTMLResponse, summary="流量可视化图表")
async def plot_traffic(interface: str = "eth0", minutes: int = 10):
# 获取历史数据
history = traffic_monitor.get_traffic_history(interface)
time_points = history["time"][-minutes * 60:]
sent = history["sent"][-minutes * 60:]
recv = history["recv"][-minutes * 60:]
# 创建图表
plt.figure(figsize=(10, 6))
plt.plot(time_points, sent, label="发送流量 (B/s)")
plt.plot(time_points, recv, label="接收流量 (B/s)")
plt.title(f"{interface} 流量监控 - 最近 {minutes} 分钟")
plt.xlabel("时间")
plt.ylabel("流量 (字节/秒)")
plt.legend()
plt.grid(True)
plt.xticks(rotation=45)
plt.tight_layout()
# 转换为HTML图像
buf = io.BytesIO()
plt.savefig(buf, format="png")
buf.seek(0)
image_base64 = base64.b64encode(buf.read()).decode("utf-8")
plt.close()
return f"""
<html>
<head>
<title>网络流量监控</title>
</head>
<body>
<h1>{interface} 网络流量监控</h1>
<img src="data:image/png;base64,{image_base64}" alt="流量图表">
<p>更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
</body>
</html>
"""

View File

@ -0,0 +1,27 @@
# 添加正确的导入
from src.backend.app.api.database import Base # 修复:导入 Base
from sqlalchemy import Column, Integer, String, DateTime
class TrafficRecord(Base):
"""网络流量记录模型"""
__tablename__ = "traffic_records"
id = Column(Integer, primary_key=True, index=True)
interface = Column(String(50), index=True)
bytes_sent = Column(Integer)
bytes_recv = Column(Integer)
packets_sent = Column(Integer)
packets_recv = Column(Integer)
timestamp = Column(DateTime)
def to_dict(self):
return {
"id": self.id,
"interface": self.interface,
"bytes_sent": self.bytes_sent,
"bytes_recv": self.bytes_recv,
"packets_sent": self.packets_sent,
"packets_recv": self.packets_recv,
"timestamp": self.timestamp.isoformat()
}

View File

@ -0,0 +1,147 @@
import psutil
import time
import asyncio
from datetime import datetime
from collections import deque
from typing import Dict, Optional, List
from ..models.traffic_models import TrafficRecord
from src.backend.app.api.database import SessionLocal # 修复:导入 SessionLocal
class TrafficMonitor:
def __init__(self, history_size: int = 300):
self.history_size = history_size # 保存历史大小
self.history = {
"sent": deque(maxlen=history_size),
"recv": deque(maxlen=history_size),
"time": deque(maxlen=history_size),
"interfaces": {}
}
self.running = False
self.task = None
self.update_interval = 1.0 # 秒
@staticmethod
def get_interfaces() -> List[str]:
"""获取所有网络接口名称"""
return list(psutil.net_io_counters(pernic=True).keys())
def start_monitoring(self):
"""启动流量监控"""
if not self.running:
self.running = True
self.task = asyncio.create_task(self._monitor_loop())
print("流量监控已启动")
async def stop_monitoring(self):
"""停止流量监控"""
if self.running:
self.running = False
self.task.cancel()
try:
await self.task
except asyncio.CancelledError:
pass
print("流量监控已停止")
async def _monitor_loop(self):
"""监控主循环"""
last_stats = psutil.net_io_counters(pernic=True)
last_time = time.time()
while self.running:
await asyncio.sleep(self.update_interval)
current_time = time.time()
current_stats = psutil.net_io_counters(pernic=True)
elapsed = current_time - last_time
# 计算每个接口的流量速率
for iface in current_stats:
if iface not in self.history["interfaces"]:
# 修复:使用 self.history_size
self.history["interfaces"][iface] = {
"sent": deque(maxlen=self.history_size),
"recv": deque(maxlen=self.history_size)
}
if iface in last_stats:
sent_rate = (current_stats[iface].bytes_sent - last_stats[iface].bytes_sent) / elapsed
recv_rate = (current_stats[iface].bytes_recv - last_stats[iface].bytes_recv) / elapsed
# 保存到历史数据
self.history["sent"].append(sent_rate)
self.history["recv"].append(recv_rate)
self.history["time"].append(datetime.now())
# 保存到接口特定历史
self.history["interfaces"][iface]["sent"].append(sent_rate)
self.history["interfaces"][iface]["recv"].append(recv_rate)
# 保存到数据库
self._save_to_db(current_stats)
last_stats = current_stats
last_time = current_time
@staticmethod
def _save_to_db(stats):
"""保存流量数据到数据库"""
with SessionLocal() as session:
for iface, counters in stats.items():
record = TrafficRecord(
interface=iface,
bytes_sent=counters.bytes_sent,
bytes_recv=counters.bytes_recv,
packets_sent=counters.packets_sent,
packets_recv=counters.packets_recv,
timestamp=datetime.now()
)
session.add(record)
session.commit()
def get_current_traffic(self, interface: Optional[str] = None) -> Dict:
"""获取当前流量数据"""
stats = psutil.net_io_counters(pernic=True)
if interface:
if interface in stats:
return self._format_interface_stats(stats[interface])
return {}
return {iface: self._format_interface_stats(data) for iface, data in stats.items()}
@staticmethod
def _format_interface_stats(counters) -> Dict:
"""格式化接口统计数据"""
return {
"bytes_sent": counters.bytes_sent,
"bytes_recv": counters.bytes_recv,
"packets_sent": counters.packets_sent,
"packets_recv": counters.packets_recv,
"errin": counters.errin,
"errout": counters.errout,
"dropin": counters.dropin,
"dropout": counters.dropout
}
def get_traffic_history(self, interface: Optional[str] = None) -> Dict:
"""获取流量历史数据"""
if interface and interface in self.history["interfaces"]:
return {
"sent": list(self.history["interfaces"][interface]["sent"]),
"recv": list(self.history["interfaces"][interface]["recv"]),
"time": list(self.history["time"])
}
return {
"sent": list(self.history["sent"]),
"recv": list(self.history["recv"]),
"time": list(self.history["time"])
}
# 全局流量监控实例
traffic_monitor = TrafficMonitor()

View File

@ -1,3 +1,4 @@
from pydantic import BaseModel
from pydantic_settings import BaseSettings
from dotenv import load_dotenv

View File

@ -24,4 +24,10 @@ tenacity==8.2.3
# 其他工具
asyncio==3.4.3
typing_extensions==4.10.0
typing_extensions==4.10.0
#监控依赖
psutil==5.9.8
matplotlib==3.8.3
sqlalchemy==2.0.28
fastapi_utils==0.2.1