diff --git a/src/backend/app/__init__.py b/src/backend/app/__init__.py index 55b9b05..f614c63 100644 --- a/src/backend/app/__init__.py +++ b/src/backend/app/__init__.py @@ -3,7 +3,18 @@ from src.backend.app.api.endpoints import router from src.backend.app.utils.logger import setup_logging from src.backend.config import settings +# 添加正确的导入 +from .services.traffic_monitor import traffic_monitor +from src.backend.app.api.database import init_db # 修复:导入 init_db + + def create_app() -> FastAPI: + # 初始化数据库 + init_db() + + # 启动流量监控 + traffic_monitor.start_monitoring() + # 设置日志 setup_logging() @@ -29,7 +40,7 @@ def create_app() -> FastAPI: # 添加API路由 app.include_router(router, prefix=settings.API_PREFIX) - - return app + + app = create_app() \ No newline at end of file diff --git a/src/backend/app/api/database.py b/src/backend/app/api/database.py new file mode 100644 index 0000000..d02e658 --- /dev/null +++ b/src/backend/app/api/database.py @@ -0,0 +1,17 @@ +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +SQLALCHEMY_DATABASE_URL = "sqlite:///./traffic_monitor.db" + +engine = create_engine( + SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} +) +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +Base = declarative_base() + +def init_db(): + """初始化数据库""" + # 删除多余的导入 + Base.metadata.create_all(bind=engine) \ No newline at end of file diff --git a/src/backend/app/api/endpoints.py b/src/backend/app/api/endpoints.py index 8e77b5d..20ab3d6 100644 --- a/src/backend/app/api/endpoints.py +++ b/src/backend/app/api/endpoints.py @@ -1,13 +1,24 @@ -from fastapi import (APIRouter, HTTPException, Response) +from datetime import datetime + +from fastapi import (APIRouter, HTTPException, Response,WebSocket, WebSocketDisconnect) from typing import List from pydantic import BaseModel - from ...app.services.ai_services import AIService from ...app.api.network_config import SwitchConfigurator from ...config import settings from ..services.network_scanner import NetworkScanner +from ...app.services.traffic_monitor import traffic_monitor +from ...app.models.traffic_models import TrafficRecord +from src.backend.app.api.database import (SessionLocal) +import asyncio +from fastapi.responses import HTMLResponse +import matplotlib.pyplot as plt +import io +import base64 -router = APIRouter(prefix="/api", tags=["API"]) + + +router = APIRouter(prefix="", tags=["API"]) scanner = NetworkScanner() # 添加根路径路由 @@ -110,4 +121,84 @@ async def apply_config(request: ConfigRequest): status_code=500, detail=f"Failed to apply config: {str(e)}" ) +@router.get("/traffic/interfaces", summary="获取所有网络接口") +async def get_network_interfaces(): + return { + "interfaces": traffic_monitor.get_interfaces() + } +@router.get("/traffic/current", summary="获取当前流量数据") +async def get_current_traffic(interface: str = None): + return traffic_monitor.get_current_traffic(interface) + +@router.get("/traffic/history", summary="获取流量历史数据") +async def get_traffic_history(interface: str = None, limit: int = 100): + history = traffic_monitor.get_traffic_history(interface) + return { + "sent": history["sent"][-limit:], + "recv": history["recv"][-limit:], + "time": [t.isoformat() for t in history["time"]][-limit:] + } + +@router.get("/traffic/records", summary="获取流量记录") +async def get_traffic_records(interface: str = None, limit: int = 100): + with SessionLocal() as session: + query = session.query(TrafficRecord) + if interface: + query = query.filter(TrafficRecord.interface == interface) + records = query.order_by(TrafficRecord.timestamp.desc()).limit(limit).all() + return [record.to_dict() for record in records] + +@router.websocket("/ws/traffic") +async def websocket_traffic(websocket: WebSocket): + """实时流量WebSocket""" + await websocket.accept() + try: + while True: + # 获取所有接口的当前流量 + traffic_data = traffic_monitor.get_current_traffic() + await websocket.send_json(traffic_data) + await asyncio.sleep(1) # 每秒更新一次 + except WebSocketDisconnect: + print("客户端断开连接") + + +@router.get("/traffic/plot", response_class=HTMLResponse, summary="流量可视化图表") +async def plot_traffic(interface: str = "eth0", minutes: int = 10): + # 获取历史数据 + history = traffic_monitor.get_traffic_history(interface) + time_points = history["time"][-minutes * 60:] + sent = history["sent"][-minutes * 60:] + recv = history["recv"][-minutes * 60:] + + # 创建图表 + plt.figure(figsize=(10, 6)) + plt.plot(time_points, sent, label="发送流量 (B/s)") + plt.plot(time_points, recv, label="接收流量 (B/s)") + plt.title(f"{interface} 流量监控 - 最近 {minutes} 分钟") + plt.xlabel("时间") + plt.ylabel("流量 (字节/秒)") + plt.legend() + plt.grid(True) + plt.xticks(rotation=45) + plt.tight_layout() + + # 转换为HTML图像 + buf = io.BytesIO() + plt.savefig(buf, format="png") + buf.seek(0) + image_base64 = base64.b64encode(buf.read()).decode("utf-8") + plt.close() + + return f""" + + + 网络流量监控 + + +

{interface} 网络流量监控

+ 流量图表 +

更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

+ + + """ \ No newline at end of file diff --git a/src/backend/app/models/traffic_models.py b/src/backend/app/models/traffic_models.py new file mode 100644 index 0000000..f6900f2 --- /dev/null +++ b/src/backend/app/models/traffic_models.py @@ -0,0 +1,27 @@ +# 添加正确的导入 +from src.backend.app.api.database import Base # 修复:导入 Base +from sqlalchemy import Column, Integer, String, DateTime + + +class TrafficRecord(Base): + """网络流量记录模型""" + __tablename__ = "traffic_records" + + id = Column(Integer, primary_key=True, index=True) + interface = Column(String(50), index=True) + bytes_sent = Column(Integer) + bytes_recv = Column(Integer) + packets_sent = Column(Integer) + packets_recv = Column(Integer) + timestamp = Column(DateTime) + + def to_dict(self): + return { + "id": self.id, + "interface": self.interface, + "bytes_sent": self.bytes_sent, + "bytes_recv": self.bytes_recv, + "packets_sent": self.packets_sent, + "packets_recv": self.packets_recv, + "timestamp": self.timestamp.isoformat() + } \ No newline at end of file diff --git a/src/backend/app/services/traffic_monitor.py b/src/backend/app/services/traffic_monitor.py new file mode 100644 index 0000000..b068321 --- /dev/null +++ b/src/backend/app/services/traffic_monitor.py @@ -0,0 +1,147 @@ +import psutil +import time +import asyncio +from datetime import datetime +from collections import deque +from typing import Dict, Optional, List + + +from ..models.traffic_models import TrafficRecord +from src.backend.app.api.database import SessionLocal # 修复:导入 SessionLocal + + +class TrafficMonitor: + def __init__(self, history_size: int = 300): + self.history_size = history_size # 保存历史大小 + self.history = { + "sent": deque(maxlen=history_size), + "recv": deque(maxlen=history_size), + "time": deque(maxlen=history_size), + "interfaces": {} + } + self.running = False + self.task = None + self.update_interval = 1.0 # 秒 + + @staticmethod + def get_interfaces() -> List[str]: + """获取所有网络接口名称""" + return list(psutil.net_io_counters(pernic=True).keys()) + + def start_monitoring(self): + """启动流量监控""" + if not self.running: + self.running = True + self.task = asyncio.create_task(self._monitor_loop()) + print("流量监控已启动") + + async def stop_monitoring(self): + """停止流量监控""" + if self.running: + self.running = False + self.task.cancel() + try: + await self.task + except asyncio.CancelledError: + pass + print("流量监控已停止") + + async def _monitor_loop(self): + """监控主循环""" + last_stats = psutil.net_io_counters(pernic=True) + last_time = time.time() + + while self.running: + await asyncio.sleep(self.update_interval) + + current_time = time.time() + current_stats = psutil.net_io_counters(pernic=True) + elapsed = current_time - last_time + + # 计算每个接口的流量速率 + for iface in current_stats: + if iface not in self.history["interfaces"]: + # 修复:使用 self.history_size + self.history["interfaces"][iface] = { + "sent": deque(maxlen=self.history_size), + "recv": deque(maxlen=self.history_size) + } + + if iface in last_stats: + sent_rate = (current_stats[iface].bytes_sent - last_stats[iface].bytes_sent) / elapsed + recv_rate = (current_stats[iface].bytes_recv - last_stats[iface].bytes_recv) / elapsed + + # 保存到历史数据 + self.history["sent"].append(sent_rate) + self.history["recv"].append(recv_rate) + self.history["time"].append(datetime.now()) + + # 保存到接口特定历史 + self.history["interfaces"][iface]["sent"].append(sent_rate) + self.history["interfaces"][iface]["recv"].append(recv_rate) + + # 保存到数据库 + self._save_to_db(current_stats) + + last_stats = current_stats + last_time = current_time + + @staticmethod + def _save_to_db(stats): + """保存流量数据到数据库""" + with SessionLocal() as session: + for iface, counters in stats.items(): + record = TrafficRecord( + interface=iface, + bytes_sent=counters.bytes_sent, + bytes_recv=counters.bytes_recv, + packets_sent=counters.packets_sent, + packets_recv=counters.packets_recv, + timestamp=datetime.now() + ) + session.add(record) + session.commit() + + def get_current_traffic(self, interface: Optional[str] = None) -> Dict: + """获取当前流量数据""" + stats = psutil.net_io_counters(pernic=True) + + if interface: + if interface in stats: + return self._format_interface_stats(stats[interface]) + return {} + + return {iface: self._format_interface_stats(data) for iface, data in stats.items()} + + @staticmethod + def _format_interface_stats(counters) -> Dict: + """格式化接口统计数据""" + return { + "bytes_sent": counters.bytes_sent, + "bytes_recv": counters.bytes_recv, + "packets_sent": counters.packets_sent, + "packets_recv": counters.packets_recv, + "errin": counters.errin, + "errout": counters.errout, + "dropin": counters.dropin, + "dropout": counters.dropout + } + + def get_traffic_history(self, interface: Optional[str] = None) -> Dict: + """获取流量历史数据""" + if interface and interface in self.history["interfaces"]: + return { + "sent": list(self.history["interfaces"][interface]["sent"]), + "recv": list(self.history["interfaces"][interface]["recv"]), + "time": list(self.history["time"]) + } + + return { + "sent": list(self.history["sent"]), + "recv": list(self.history["recv"]), + "time": list(self.history["time"]) + } + + +# 全局流量监控实例 +traffic_monitor = TrafficMonitor() \ No newline at end of file diff --git a/src/backend/config.py b/src/backend/config.py index fd68e88..e79d088 100644 --- a/src/backend/config.py +++ b/src/backend/config.py @@ -1,3 +1,4 @@ +from pydantic import BaseModel from pydantic_settings import BaseSettings from dotenv import load_dotenv diff --git a/src/backend/requirements.txt b/src/backend/requirements.txt index 76723fb..24c8fed 100644 --- a/src/backend/requirements.txt +++ b/src/backend/requirements.txt @@ -24,4 +24,10 @@ tenacity==8.2.3 # 其他工具 asyncio==3.4.3 -typing_extensions==4.10.0 \ No newline at end of file +typing_extensions==4.10.0 + +#监控依赖 +psutil==5.9.8 +matplotlib==3.8.3 +sqlalchemy==2.0.28 +fastapi_utils==0.2.1 \ No newline at end of file