mirror of
https://github.com/Jerryplusy/AI-powered-switches.git
synced 2025-07-04 13:19:20 +00:00
114514+114514
This commit is contained in:
parent
6e5cd34da7
commit
2231b8cf82
@ -3,7 +3,18 @@ from src.backend.app.api.endpoints import router
|
||||
from src.backend.app.utils.logger import setup_logging
|
||||
from src.backend.config import settings
|
||||
|
||||
# 添加正确的导入
|
||||
from .services.traffic_monitor import traffic_monitor
|
||||
from src.backend.app.api.database import init_db # 修复:导入 init_db
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
# 初始化数据库
|
||||
init_db()
|
||||
|
||||
# 启动流量监控
|
||||
traffic_monitor.start_monitoring()
|
||||
|
||||
# 设置日志
|
||||
setup_logging()
|
||||
|
||||
@ -29,7 +40,7 @@ def create_app() -> FastAPI:
|
||||
# 添加API路由
|
||||
app.include_router(router, prefix=settings.API_PREFIX)
|
||||
|
||||
|
||||
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
17
src/backend/app/api/database.py
Normal file
17
src/backend/app/api/database.py
Normal file
@ -0,0 +1,17 @@
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
SQLALCHEMY_DATABASE_URL = "sqlite:///./traffic_monitor.db"
|
||||
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
|
||||
)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
def init_db():
|
||||
"""初始化数据库"""
|
||||
# 删除多余的导入
|
||||
Base.metadata.create_all(bind=engine)
|
@ -1,13 +1,24 @@
|
||||
from fastapi import (APIRouter, HTTPException, Response)
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import (APIRouter, HTTPException, Response,WebSocket, WebSocketDisconnect)
|
||||
from typing import List
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ...app.services.ai_services import AIService
|
||||
from ...app.api.network_config import SwitchConfigurator
|
||||
from ...config import settings
|
||||
from ..services.network_scanner import NetworkScanner
|
||||
from ...app.services.traffic_monitor import traffic_monitor
|
||||
from ...app.models.traffic_models import TrafficRecord
|
||||
from src.backend.app.api.database import (SessionLocal)
|
||||
import asyncio
|
||||
from fastapi.responses import HTMLResponse
|
||||
import matplotlib.pyplot as plt
|
||||
import io
|
||||
import base64
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["API"])
|
||||
|
||||
|
||||
router = APIRouter(prefix="", tags=["API"])
|
||||
scanner = NetworkScanner()
|
||||
|
||||
# 添加根路径路由
|
||||
@ -110,4 +121,84 @@ async def apply_config(request: ConfigRequest):
|
||||
status_code=500,
|
||||
detail=f"Failed to apply config: {str(e)}"
|
||||
)
|
||||
@router.get("/traffic/interfaces", summary="获取所有网络接口")
|
||||
async def get_network_interfaces():
|
||||
return {
|
||||
"interfaces": traffic_monitor.get_interfaces()
|
||||
}
|
||||
|
||||
@router.get("/traffic/current", summary="获取当前流量数据")
|
||||
async def get_current_traffic(interface: str = None):
|
||||
return traffic_monitor.get_current_traffic(interface)
|
||||
|
||||
@router.get("/traffic/history", summary="获取流量历史数据")
|
||||
async def get_traffic_history(interface: str = None, limit: int = 100):
|
||||
history = traffic_monitor.get_traffic_history(interface)
|
||||
return {
|
||||
"sent": history["sent"][-limit:],
|
||||
"recv": history["recv"][-limit:],
|
||||
"time": [t.isoformat() for t in history["time"]][-limit:]
|
||||
}
|
||||
|
||||
@router.get("/traffic/records", summary="获取流量记录")
|
||||
async def get_traffic_records(interface: str = None, limit: int = 100):
|
||||
with SessionLocal() as session:
|
||||
query = session.query(TrafficRecord)
|
||||
if interface:
|
||||
query = query.filter(TrafficRecord.interface == interface)
|
||||
records = query.order_by(TrafficRecord.timestamp.desc()).limit(limit).all()
|
||||
return [record.to_dict() for record in records]
|
||||
|
||||
@router.websocket("/ws/traffic")
|
||||
async def websocket_traffic(websocket: WebSocket):
|
||||
"""实时流量WebSocket"""
|
||||
await websocket.accept()
|
||||
try:
|
||||
while True:
|
||||
# 获取所有接口的当前流量
|
||||
traffic_data = traffic_monitor.get_current_traffic()
|
||||
await websocket.send_json(traffic_data)
|
||||
await asyncio.sleep(1) # 每秒更新一次
|
||||
except WebSocketDisconnect:
|
||||
print("客户端断开连接")
|
||||
|
||||
|
||||
@router.get("/traffic/plot", response_class=HTMLResponse, summary="流量可视化图表")
|
||||
async def plot_traffic(interface: str = "eth0", minutes: int = 10):
|
||||
# 获取历史数据
|
||||
history = traffic_monitor.get_traffic_history(interface)
|
||||
time_points = history["time"][-minutes * 60:]
|
||||
sent = history["sent"][-minutes * 60:]
|
||||
recv = history["recv"][-minutes * 60:]
|
||||
|
||||
# 创建图表
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.plot(time_points, sent, label="发送流量 (B/s)")
|
||||
plt.plot(time_points, recv, label="接收流量 (B/s)")
|
||||
plt.title(f"{interface} 流量监控 - 最近 {minutes} 分钟")
|
||||
plt.xlabel("时间")
|
||||
plt.ylabel("流量 (字节/秒)")
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
plt.xticks(rotation=45)
|
||||
plt.tight_layout()
|
||||
|
||||
# 转换为HTML图像
|
||||
buf = io.BytesIO()
|
||||
plt.savefig(buf, format="png")
|
||||
buf.seek(0)
|
||||
image_base64 = base64.b64encode(buf.read()).decode("utf-8")
|
||||
plt.close()
|
||||
|
||||
return f"""
|
||||
<html>
|
||||
<head>
|
||||
<title>网络流量监控</title>
|
||||
</head>
|
||||
<body>
|
||||
<h1>{interface} 网络流量监控</h1>
|
||||
<img src="data:image/png;base64,{image_base64}" alt="流量图表">
|
||||
<p>更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
27
src/backend/app/models/traffic_models.py
Normal file
27
src/backend/app/models/traffic_models.py
Normal file
@ -0,0 +1,27 @@
|
||||
# 添加正确的导入
|
||||
from src.backend.app.api.database import Base # 修复:导入 Base
|
||||
from sqlalchemy import Column, Integer, String, DateTime
|
||||
|
||||
|
||||
class TrafficRecord(Base):
|
||||
"""网络流量记录模型"""
|
||||
__tablename__ = "traffic_records"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
interface = Column(String(50), index=True)
|
||||
bytes_sent = Column(Integer)
|
||||
bytes_recv = Column(Integer)
|
||||
packets_sent = Column(Integer)
|
||||
packets_recv = Column(Integer)
|
||||
timestamp = Column(DateTime)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"interface": self.interface,
|
||||
"bytes_sent": self.bytes_sent,
|
||||
"bytes_recv": self.bytes_recv,
|
||||
"packets_sent": self.packets_sent,
|
||||
"packets_recv": self.packets_recv,
|
||||
"timestamp": self.timestamp.isoformat()
|
||||
}
|
147
src/backend/app/services/traffic_monitor.py
Normal file
147
src/backend/app/services/traffic_monitor.py
Normal file
@ -0,0 +1,147 @@
|
||||
import psutil
|
||||
import time
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from collections import deque
|
||||
from typing import Dict, Optional, List
|
||||
|
||||
|
||||
from ..models.traffic_models import TrafficRecord
|
||||
from src.backend.app.api.database import SessionLocal # 修复:导入 SessionLocal
|
||||
|
||||
|
||||
class TrafficMonitor:
|
||||
def __init__(self, history_size: int = 300):
|
||||
self.history_size = history_size # 保存历史大小
|
||||
self.history = {
|
||||
"sent": deque(maxlen=history_size),
|
||||
"recv": deque(maxlen=history_size),
|
||||
"time": deque(maxlen=history_size),
|
||||
"interfaces": {}
|
||||
}
|
||||
self.running = False
|
||||
self.task = None
|
||||
self.update_interval = 1.0 # 秒
|
||||
|
||||
@staticmethod
|
||||
def get_interfaces() -> List[str]:
|
||||
"""获取所有网络接口名称"""
|
||||
return list(psutil.net_io_counters(pernic=True).keys())
|
||||
|
||||
def start_monitoring(self):
|
||||
"""启动流量监控"""
|
||||
if not self.running:
|
||||
self.running = True
|
||||
self.task = asyncio.create_task(self._monitor_loop())
|
||||
print("流量监控已启动")
|
||||
|
||||
async def stop_monitoring(self):
|
||||
"""停止流量监控"""
|
||||
if self.running:
|
||||
self.running = False
|
||||
self.task.cancel()
|
||||
try:
|
||||
await self.task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
print("流量监控已停止")
|
||||
|
||||
async def _monitor_loop(self):
|
||||
"""监控主循环"""
|
||||
last_stats = psutil.net_io_counters(pernic=True)
|
||||
last_time = time.time()
|
||||
|
||||
while self.running:
|
||||
await asyncio.sleep(self.update_interval)
|
||||
|
||||
current_time = time.time()
|
||||
current_stats = psutil.net_io_counters(pernic=True)
|
||||
elapsed = current_time - last_time
|
||||
|
||||
# 计算每个接口的流量速率
|
||||
for iface in current_stats:
|
||||
if iface not in self.history["interfaces"]:
|
||||
# 修复:使用 self.history_size
|
||||
self.history["interfaces"][iface] = {
|
||||
"sent": deque(maxlen=self.history_size),
|
||||
"recv": deque(maxlen=self.history_size)
|
||||
}
|
||||
|
||||
if iface in last_stats:
|
||||
sent_rate = (current_stats[iface].bytes_sent - last_stats[iface].bytes_sent) / elapsed
|
||||
recv_rate = (current_stats[iface].bytes_recv - last_stats[iface].bytes_recv) / elapsed
|
||||
|
||||
# 保存到历史数据
|
||||
self.history["sent"].append(sent_rate)
|
||||
self.history["recv"].append(recv_rate)
|
||||
self.history["time"].append(datetime.now())
|
||||
|
||||
# 保存到接口特定历史
|
||||
self.history["interfaces"][iface]["sent"].append(sent_rate)
|
||||
self.history["interfaces"][iface]["recv"].append(recv_rate)
|
||||
|
||||
# 保存到数据库
|
||||
self._save_to_db(current_stats)
|
||||
|
||||
last_stats = current_stats
|
||||
last_time = current_time
|
||||
|
||||
@staticmethod
|
||||
def _save_to_db(stats):
|
||||
"""保存流量数据到数据库"""
|
||||
with SessionLocal() as session:
|
||||
for iface, counters in stats.items():
|
||||
record = TrafficRecord(
|
||||
interface=iface,
|
||||
bytes_sent=counters.bytes_sent,
|
||||
bytes_recv=counters.bytes_recv,
|
||||
packets_sent=counters.packets_sent,
|
||||
packets_recv=counters.packets_recv,
|
||||
timestamp=datetime.now()
|
||||
)
|
||||
session.add(record)
|
||||
session.commit()
|
||||
|
||||
def get_current_traffic(self, interface: Optional[str] = None) -> Dict:
|
||||
"""获取当前流量数据"""
|
||||
stats = psutil.net_io_counters(pernic=True)
|
||||
|
||||
if interface:
|
||||
if interface in stats:
|
||||
return self._format_interface_stats(stats[interface])
|
||||
return {}
|
||||
|
||||
return {iface: self._format_interface_stats(data) for iface, data in stats.items()}
|
||||
|
||||
@staticmethod
|
||||
def _format_interface_stats(counters) -> Dict:
|
||||
"""格式化接口统计数据"""
|
||||
return {
|
||||
"bytes_sent": counters.bytes_sent,
|
||||
"bytes_recv": counters.bytes_recv,
|
||||
"packets_sent": counters.packets_sent,
|
||||
"packets_recv": counters.packets_recv,
|
||||
"errin": counters.errin,
|
||||
"errout": counters.errout,
|
||||
"dropin": counters.dropin,
|
||||
"dropout": counters.dropout
|
||||
}
|
||||
|
||||
def get_traffic_history(self, interface: Optional[str] = None) -> Dict:
|
||||
"""获取流量历史数据"""
|
||||
if interface and interface in self.history["interfaces"]:
|
||||
return {
|
||||
"sent": list(self.history["interfaces"][interface]["sent"]),
|
||||
"recv": list(self.history["interfaces"][interface]["recv"]),
|
||||
"time": list(self.history["time"])
|
||||
}
|
||||
|
||||
return {
|
||||
"sent": list(self.history["sent"]),
|
||||
"recv": list(self.history["recv"]),
|
||||
"time": list(self.history["time"])
|
||||
}
|
||||
|
||||
|
||||
# 全局流量监控实例
|
||||
traffic_monitor = TrafficMonitor()
|
@ -1,3 +1,4 @@
|
||||
from pydantic import BaseModel
|
||||
from pydantic_settings import BaseSettings
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
@ -25,3 +25,9 @@ tenacity==8.2.3
|
||||
# 其他工具
|
||||
asyncio==3.4.3
|
||||
typing_extensions==4.10.0
|
||||
|
||||
#监控依赖
|
||||
psutil==5.9.8
|
||||
matplotlib==3.8.3
|
||||
sqlalchemy==2.0.28
|
||||
fastapi_utils==0.2.1
|
Loading…
x
Reference in New Issue
Block a user