114514+114514

This commit is contained in:
3 2025-06-18 18:36:52 +08:00
parent 6e5cd34da7
commit 2231b8cf82
7 changed files with 306 additions and 6 deletions

View File

@ -3,7 +3,18 @@ from src.backend.app.api.endpoints import router
from src.backend.app.utils.logger import setup_logging from src.backend.app.utils.logger import setup_logging
from src.backend.config import settings from src.backend.config import settings
# 添加正确的导入
from .services.traffic_monitor import traffic_monitor
from src.backend.app.api.database import init_db # 修复:导入 init_db
def create_app() -> FastAPI: def create_app() -> FastAPI:
# 初始化数据库
init_db()
# 启动流量监控
traffic_monitor.start_monitoring()
# 设置日志 # 设置日志
setup_logging() setup_logging()
@ -29,7 +40,7 @@ def create_app() -> FastAPI:
# 添加API路由 # 添加API路由
app.include_router(router, prefix=settings.API_PREFIX) app.include_router(router, prefix=settings.API_PREFIX)
return app return app
app = create_app() app = create_app()

View File

@ -0,0 +1,17 @@
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
SQLALCHEMY_DATABASE_URL = "sqlite:///./traffic_monitor.db"
engine = create_engine(
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
def init_db():
"""初始化数据库"""
# 删除多余的导入
Base.metadata.create_all(bind=engine)

View File

@ -1,13 +1,24 @@
from fastapi import (APIRouter, HTTPException, Response) from datetime import datetime
from fastapi import (APIRouter, HTTPException, Response,WebSocket, WebSocketDisconnect)
from typing import List from typing import List
from pydantic import BaseModel from pydantic import BaseModel
from ...app.services.ai_services import AIService from ...app.services.ai_services import AIService
from ...app.api.network_config import SwitchConfigurator from ...app.api.network_config import SwitchConfigurator
from ...config import settings from ...config import settings
from ..services.network_scanner import NetworkScanner from ..services.network_scanner import NetworkScanner
from ...app.services.traffic_monitor import traffic_monitor
from ...app.models.traffic_models import TrafficRecord
from src.backend.app.api.database import (SessionLocal)
import asyncio
from fastapi.responses import HTMLResponse
import matplotlib.pyplot as plt
import io
import base64
router = APIRouter(prefix="/api", tags=["API"])
router = APIRouter(prefix="", tags=["API"])
scanner = NetworkScanner() scanner = NetworkScanner()
# 添加根路径路由 # 添加根路径路由
@ -110,4 +121,84 @@ async def apply_config(request: ConfigRequest):
status_code=500, status_code=500,
detail=f"Failed to apply config: {str(e)}" detail=f"Failed to apply config: {str(e)}"
) )
@router.get("/traffic/interfaces", summary="获取所有网络接口")
async def get_network_interfaces():
return {
"interfaces": traffic_monitor.get_interfaces()
}
@router.get("/traffic/current", summary="获取当前流量数据")
async def get_current_traffic(interface: str = None):
return traffic_monitor.get_current_traffic(interface)
@router.get("/traffic/history", summary="获取流量历史数据")
async def get_traffic_history(interface: str = None, limit: int = 100):
history = traffic_monitor.get_traffic_history(interface)
return {
"sent": history["sent"][-limit:],
"recv": history["recv"][-limit:],
"time": [t.isoformat() for t in history["time"]][-limit:]
}
@router.get("/traffic/records", summary="获取流量记录")
async def get_traffic_records(interface: str = None, limit: int = 100):
with SessionLocal() as session:
query = session.query(TrafficRecord)
if interface:
query = query.filter(TrafficRecord.interface == interface)
records = query.order_by(TrafficRecord.timestamp.desc()).limit(limit).all()
return [record.to_dict() for record in records]
@router.websocket("/ws/traffic")
async def websocket_traffic(websocket: WebSocket):
"""实时流量WebSocket"""
await websocket.accept()
try:
while True:
# 获取所有接口的当前流量
traffic_data = traffic_monitor.get_current_traffic()
await websocket.send_json(traffic_data)
await asyncio.sleep(1) # 每秒更新一次
except WebSocketDisconnect:
print("客户端断开连接")
@router.get("/traffic/plot", response_class=HTMLResponse, summary="流量可视化图表")
async def plot_traffic(interface: str = "eth0", minutes: int = 10):
# 获取历史数据
history = traffic_monitor.get_traffic_history(interface)
time_points = history["time"][-minutes * 60:]
sent = history["sent"][-minutes * 60:]
recv = history["recv"][-minutes * 60:]
# 创建图表
plt.figure(figsize=(10, 6))
plt.plot(time_points, sent, label="发送流量 (B/s)")
plt.plot(time_points, recv, label="接收流量 (B/s)")
plt.title(f"{interface} 流量监控 - 最近 {minutes} 分钟")
plt.xlabel("时间")
plt.ylabel("流量 (字节/秒)")
plt.legend()
plt.grid(True)
plt.xticks(rotation=45)
plt.tight_layout()
# 转换为HTML图像
buf = io.BytesIO()
plt.savefig(buf, format="png")
buf.seek(0)
image_base64 = base64.b64encode(buf.read()).decode("utf-8")
plt.close()
return f"""
<html>
<head>
<title>网络流量监控</title>
</head>
<body>
<h1>{interface} 网络流量监控</h1>
<img src="data:image/png;base64,{image_base64}" alt="流量图表">
<p>更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
</body>
</html>
"""

View File

@ -0,0 +1,27 @@
# 添加正确的导入
from src.backend.app.api.database import Base # 修复:导入 Base
from sqlalchemy import Column, Integer, String, DateTime
class TrafficRecord(Base):
"""网络流量记录模型"""
__tablename__ = "traffic_records"
id = Column(Integer, primary_key=True, index=True)
interface = Column(String(50), index=True)
bytes_sent = Column(Integer)
bytes_recv = Column(Integer)
packets_sent = Column(Integer)
packets_recv = Column(Integer)
timestamp = Column(DateTime)
def to_dict(self):
return {
"id": self.id,
"interface": self.interface,
"bytes_sent": self.bytes_sent,
"bytes_recv": self.bytes_recv,
"packets_sent": self.packets_sent,
"packets_recv": self.packets_recv,
"timestamp": self.timestamp.isoformat()
}

View File

@ -0,0 +1,147 @@
import psutil
import time
import asyncio
from datetime import datetime
from collections import deque
from typing import Dict, Optional, List
from ..models.traffic_models import TrafficRecord
from src.backend.app.api.database import SessionLocal # 修复:导入 SessionLocal
class TrafficMonitor:
def __init__(self, history_size: int = 300):
self.history_size = history_size # 保存历史大小
self.history = {
"sent": deque(maxlen=history_size),
"recv": deque(maxlen=history_size),
"time": deque(maxlen=history_size),
"interfaces": {}
}
self.running = False
self.task = None
self.update_interval = 1.0 # 秒
@staticmethod
def get_interfaces() -> List[str]:
"""获取所有网络接口名称"""
return list(psutil.net_io_counters(pernic=True).keys())
def start_monitoring(self):
"""启动流量监控"""
if not self.running:
self.running = True
self.task = asyncio.create_task(self._monitor_loop())
print("流量监控已启动")
async def stop_monitoring(self):
"""停止流量监控"""
if self.running:
self.running = False
self.task.cancel()
try:
await self.task
except asyncio.CancelledError:
pass
print("流量监控已停止")
async def _monitor_loop(self):
"""监控主循环"""
last_stats = psutil.net_io_counters(pernic=True)
last_time = time.time()
while self.running:
await asyncio.sleep(self.update_interval)
current_time = time.time()
current_stats = psutil.net_io_counters(pernic=True)
elapsed = current_time - last_time
# 计算每个接口的流量速率
for iface in current_stats:
if iface not in self.history["interfaces"]:
# 修复:使用 self.history_size
self.history["interfaces"][iface] = {
"sent": deque(maxlen=self.history_size),
"recv": deque(maxlen=self.history_size)
}
if iface in last_stats:
sent_rate = (current_stats[iface].bytes_sent - last_stats[iface].bytes_sent) / elapsed
recv_rate = (current_stats[iface].bytes_recv - last_stats[iface].bytes_recv) / elapsed
# 保存到历史数据
self.history["sent"].append(sent_rate)
self.history["recv"].append(recv_rate)
self.history["time"].append(datetime.now())
# 保存到接口特定历史
self.history["interfaces"][iface]["sent"].append(sent_rate)
self.history["interfaces"][iface]["recv"].append(recv_rate)
# 保存到数据库
self._save_to_db(current_stats)
last_stats = current_stats
last_time = current_time
@staticmethod
def _save_to_db(stats):
"""保存流量数据到数据库"""
with SessionLocal() as session:
for iface, counters in stats.items():
record = TrafficRecord(
interface=iface,
bytes_sent=counters.bytes_sent,
bytes_recv=counters.bytes_recv,
packets_sent=counters.packets_sent,
packets_recv=counters.packets_recv,
timestamp=datetime.now()
)
session.add(record)
session.commit()
def get_current_traffic(self, interface: Optional[str] = None) -> Dict:
"""获取当前流量数据"""
stats = psutil.net_io_counters(pernic=True)
if interface:
if interface in stats:
return self._format_interface_stats(stats[interface])
return {}
return {iface: self._format_interface_stats(data) for iface, data in stats.items()}
@staticmethod
def _format_interface_stats(counters) -> Dict:
"""格式化接口统计数据"""
return {
"bytes_sent": counters.bytes_sent,
"bytes_recv": counters.bytes_recv,
"packets_sent": counters.packets_sent,
"packets_recv": counters.packets_recv,
"errin": counters.errin,
"errout": counters.errout,
"dropin": counters.dropin,
"dropout": counters.dropout
}
def get_traffic_history(self, interface: Optional[str] = None) -> Dict:
"""获取流量历史数据"""
if interface and interface in self.history["interfaces"]:
return {
"sent": list(self.history["interfaces"][interface]["sent"]),
"recv": list(self.history["interfaces"][interface]["recv"]),
"time": list(self.history["time"])
}
return {
"sent": list(self.history["sent"]),
"recv": list(self.history["recv"]),
"time": list(self.history["time"])
}
# 全局流量监控实例
traffic_monitor = TrafficMonitor()

View File

@ -1,3 +1,4 @@
from pydantic import BaseModel
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
from dotenv import load_dotenv from dotenv import load_dotenv

View File

@ -25,3 +25,9 @@ tenacity==8.2.3
# 其他工具 # 其他工具
asyncio==3.4.3 asyncio==3.4.3
typing_extensions==4.10.0 typing_extensions==4.10.0
#监控依赖
psutil==5.9.8
matplotlib==3.8.3
sqlalchemy==2.0.28
fastapi_utils==0.2.1