mirror of
https://github.com/Jerryplusy/AI-powered-switches.git
synced 2025-07-04 21:29:18 +00:00
114514+114514
This commit is contained in:
parent
6e5cd34da7
commit
2231b8cf82
@ -3,7 +3,18 @@ from src.backend.app.api.endpoints import router
|
|||||||
from src.backend.app.utils.logger import setup_logging
|
from src.backend.app.utils.logger import setup_logging
|
||||||
from src.backend.config import settings
|
from src.backend.config import settings
|
||||||
|
|
||||||
|
# 添加正确的导入
|
||||||
|
from .services.traffic_monitor import traffic_monitor
|
||||||
|
from src.backend.app.api.database import init_db # 修复:导入 init_db
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> FastAPI:
|
def create_app() -> FastAPI:
|
||||||
|
# 初始化数据库
|
||||||
|
init_db()
|
||||||
|
|
||||||
|
# 启动流量监控
|
||||||
|
traffic_monitor.start_monitoring()
|
||||||
|
|
||||||
# 设置日志
|
# 设置日志
|
||||||
setup_logging()
|
setup_logging()
|
||||||
|
|
||||||
@ -29,7 +40,7 @@ def create_app() -> FastAPI:
|
|||||||
# 添加API路由
|
# 添加API路由
|
||||||
app.include_router(router, prefix=settings.API_PREFIX)
|
app.include_router(router, prefix=settings.API_PREFIX)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
app = create_app()
|
app = create_app()
|
17
src/backend/app/api/database.py
Normal file
17
src/backend/app/api/database.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
SQLALCHEMY_DATABASE_URL = "sqlite:///./traffic_monitor.db"
|
||||||
|
|
||||||
|
engine = create_engine(
|
||||||
|
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
|
||||||
|
)
|
||||||
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
def init_db():
|
||||||
|
"""初始化数据库"""
|
||||||
|
# 删除多余的导入
|
||||||
|
Base.metadata.create_all(bind=engine)
|
@ -1,13 +1,24 @@
|
|||||||
from fastapi import (APIRouter, HTTPException, Response)
|
from datetime import datetime
|
||||||
|
|
||||||
|
from fastapi import (APIRouter, HTTPException, Response,WebSocket, WebSocketDisconnect)
|
||||||
from typing import List
|
from typing import List
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from ...app.services.ai_services import AIService
|
from ...app.services.ai_services import AIService
|
||||||
from ...app.api.network_config import SwitchConfigurator
|
from ...app.api.network_config import SwitchConfigurator
|
||||||
from ...config import settings
|
from ...config import settings
|
||||||
from ..services.network_scanner import NetworkScanner
|
from ..services.network_scanner import NetworkScanner
|
||||||
|
from ...app.services.traffic_monitor import traffic_monitor
|
||||||
|
from ...app.models.traffic_models import TrafficRecord
|
||||||
|
from src.backend.app.api.database import (SessionLocal)
|
||||||
|
import asyncio
|
||||||
|
from fastapi.responses import HTMLResponse
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import io
|
||||||
|
import base64
|
||||||
|
|
||||||
router = APIRouter(prefix="/api", tags=["API"])
|
|
||||||
|
|
||||||
|
router = APIRouter(prefix="", tags=["API"])
|
||||||
scanner = NetworkScanner()
|
scanner = NetworkScanner()
|
||||||
|
|
||||||
# 添加根路径路由
|
# 添加根路径路由
|
||||||
@ -110,4 +121,84 @@ async def apply_config(request: ConfigRequest):
|
|||||||
status_code=500,
|
status_code=500,
|
||||||
detail=f"Failed to apply config: {str(e)}"
|
detail=f"Failed to apply config: {str(e)}"
|
||||||
)
|
)
|
||||||
|
@router.get("/traffic/interfaces", summary="获取所有网络接口")
|
||||||
|
async def get_network_interfaces():
|
||||||
|
return {
|
||||||
|
"interfaces": traffic_monitor.get_interfaces()
|
||||||
|
}
|
||||||
|
|
||||||
|
@router.get("/traffic/current", summary="获取当前流量数据")
|
||||||
|
async def get_current_traffic(interface: str = None):
|
||||||
|
return traffic_monitor.get_current_traffic(interface)
|
||||||
|
|
||||||
|
@router.get("/traffic/history", summary="获取流量历史数据")
|
||||||
|
async def get_traffic_history(interface: str = None, limit: int = 100):
|
||||||
|
history = traffic_monitor.get_traffic_history(interface)
|
||||||
|
return {
|
||||||
|
"sent": history["sent"][-limit:],
|
||||||
|
"recv": history["recv"][-limit:],
|
||||||
|
"time": [t.isoformat() for t in history["time"]][-limit:]
|
||||||
|
}
|
||||||
|
|
||||||
|
@router.get("/traffic/records", summary="获取流量记录")
|
||||||
|
async def get_traffic_records(interface: str = None, limit: int = 100):
|
||||||
|
with SessionLocal() as session:
|
||||||
|
query = session.query(TrafficRecord)
|
||||||
|
if interface:
|
||||||
|
query = query.filter(TrafficRecord.interface == interface)
|
||||||
|
records = query.order_by(TrafficRecord.timestamp.desc()).limit(limit).all()
|
||||||
|
return [record.to_dict() for record in records]
|
||||||
|
|
||||||
|
@router.websocket("/ws/traffic")
|
||||||
|
async def websocket_traffic(websocket: WebSocket):
|
||||||
|
"""实时流量WebSocket"""
|
||||||
|
await websocket.accept()
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
# 获取所有接口的当前流量
|
||||||
|
traffic_data = traffic_monitor.get_current_traffic()
|
||||||
|
await websocket.send_json(traffic_data)
|
||||||
|
await asyncio.sleep(1) # 每秒更新一次
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
print("客户端断开连接")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/traffic/plot", response_class=HTMLResponse, summary="流量可视化图表")
|
||||||
|
async def plot_traffic(interface: str = "eth0", minutes: int = 10):
|
||||||
|
# 获取历史数据
|
||||||
|
history = traffic_monitor.get_traffic_history(interface)
|
||||||
|
time_points = history["time"][-minutes * 60:]
|
||||||
|
sent = history["sent"][-minutes * 60:]
|
||||||
|
recv = history["recv"][-minutes * 60:]
|
||||||
|
|
||||||
|
# 创建图表
|
||||||
|
plt.figure(figsize=(10, 6))
|
||||||
|
plt.plot(time_points, sent, label="发送流量 (B/s)")
|
||||||
|
plt.plot(time_points, recv, label="接收流量 (B/s)")
|
||||||
|
plt.title(f"{interface} 流量监控 - 最近 {minutes} 分钟")
|
||||||
|
plt.xlabel("时间")
|
||||||
|
plt.ylabel("流量 (字节/秒)")
|
||||||
|
plt.legend()
|
||||||
|
plt.grid(True)
|
||||||
|
plt.xticks(rotation=45)
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
# 转换为HTML图像
|
||||||
|
buf = io.BytesIO()
|
||||||
|
plt.savefig(buf, format="png")
|
||||||
|
buf.seek(0)
|
||||||
|
image_base64 = base64.b64encode(buf.read()).decode("utf-8")
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
return f"""
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>网络流量监控</title>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h1>{interface} 网络流量监控</h1>
|
||||||
|
<img src="data:image/png;base64,{image_base64}" alt="流量图表">
|
||||||
|
<p>更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"""
|
27
src/backend/app/models/traffic_models.py
Normal file
27
src/backend/app/models/traffic_models.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
# 添加正确的导入
|
||||||
|
from src.backend.app.api.database import Base # 修复:导入 Base
|
||||||
|
from sqlalchemy import Column, Integer, String, DateTime
|
||||||
|
|
||||||
|
|
||||||
|
class TrafficRecord(Base):
|
||||||
|
"""网络流量记录模型"""
|
||||||
|
__tablename__ = "traffic_records"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
|
interface = Column(String(50), index=True)
|
||||||
|
bytes_sent = Column(Integer)
|
||||||
|
bytes_recv = Column(Integer)
|
||||||
|
packets_sent = Column(Integer)
|
||||||
|
packets_recv = Column(Integer)
|
||||||
|
timestamp = Column(DateTime)
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"interface": self.interface,
|
||||||
|
"bytes_sent": self.bytes_sent,
|
||||||
|
"bytes_recv": self.bytes_recv,
|
||||||
|
"packets_sent": self.packets_sent,
|
||||||
|
"packets_recv": self.packets_recv,
|
||||||
|
"timestamp": self.timestamp.isoformat()
|
||||||
|
}
|
147
src/backend/app/services/traffic_monitor.py
Normal file
147
src/backend/app/services/traffic_monitor.py
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
import psutil
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime
|
||||||
|
from collections import deque
|
||||||
|
from typing import Dict, Optional, List
|
||||||
|
|
||||||
|
|
||||||
|
from ..models.traffic_models import TrafficRecord
|
||||||
|
from src.backend.app.api.database import SessionLocal # 修复:导入 SessionLocal
|
||||||
|
|
||||||
|
|
||||||
|
class TrafficMonitor:
|
||||||
|
def __init__(self, history_size: int = 300):
|
||||||
|
self.history_size = history_size # 保存历史大小
|
||||||
|
self.history = {
|
||||||
|
"sent": deque(maxlen=history_size),
|
||||||
|
"recv": deque(maxlen=history_size),
|
||||||
|
"time": deque(maxlen=history_size),
|
||||||
|
"interfaces": {}
|
||||||
|
}
|
||||||
|
self.running = False
|
||||||
|
self.task = None
|
||||||
|
self.update_interval = 1.0 # 秒
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_interfaces() -> List[str]:
|
||||||
|
"""获取所有网络接口名称"""
|
||||||
|
return list(psutil.net_io_counters(pernic=True).keys())
|
||||||
|
|
||||||
|
def start_monitoring(self):
|
||||||
|
"""启动流量监控"""
|
||||||
|
if not self.running:
|
||||||
|
self.running = True
|
||||||
|
self.task = asyncio.create_task(self._monitor_loop())
|
||||||
|
print("流量监控已启动")
|
||||||
|
|
||||||
|
async def stop_monitoring(self):
|
||||||
|
"""停止流量监控"""
|
||||||
|
if self.running:
|
||||||
|
self.running = False
|
||||||
|
self.task.cancel()
|
||||||
|
try:
|
||||||
|
await self.task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
print("流量监控已停止")
|
||||||
|
|
||||||
|
async def _monitor_loop(self):
|
||||||
|
"""监控主循环"""
|
||||||
|
last_stats = psutil.net_io_counters(pernic=True)
|
||||||
|
last_time = time.time()
|
||||||
|
|
||||||
|
while self.running:
|
||||||
|
await asyncio.sleep(self.update_interval)
|
||||||
|
|
||||||
|
current_time = time.time()
|
||||||
|
current_stats = psutil.net_io_counters(pernic=True)
|
||||||
|
elapsed = current_time - last_time
|
||||||
|
|
||||||
|
# 计算每个接口的流量速率
|
||||||
|
for iface in current_stats:
|
||||||
|
if iface not in self.history["interfaces"]:
|
||||||
|
# 修复:使用 self.history_size
|
||||||
|
self.history["interfaces"][iface] = {
|
||||||
|
"sent": deque(maxlen=self.history_size),
|
||||||
|
"recv": deque(maxlen=self.history_size)
|
||||||
|
}
|
||||||
|
|
||||||
|
if iface in last_stats:
|
||||||
|
sent_rate = (current_stats[iface].bytes_sent - last_stats[iface].bytes_sent) / elapsed
|
||||||
|
recv_rate = (current_stats[iface].bytes_recv - last_stats[iface].bytes_recv) / elapsed
|
||||||
|
|
||||||
|
# 保存到历史数据
|
||||||
|
self.history["sent"].append(sent_rate)
|
||||||
|
self.history["recv"].append(recv_rate)
|
||||||
|
self.history["time"].append(datetime.now())
|
||||||
|
|
||||||
|
# 保存到接口特定历史
|
||||||
|
self.history["interfaces"][iface]["sent"].append(sent_rate)
|
||||||
|
self.history["interfaces"][iface]["recv"].append(recv_rate)
|
||||||
|
|
||||||
|
# 保存到数据库
|
||||||
|
self._save_to_db(current_stats)
|
||||||
|
|
||||||
|
last_stats = current_stats
|
||||||
|
last_time = current_time
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _save_to_db(stats):
|
||||||
|
"""保存流量数据到数据库"""
|
||||||
|
with SessionLocal() as session:
|
||||||
|
for iface, counters in stats.items():
|
||||||
|
record = TrafficRecord(
|
||||||
|
interface=iface,
|
||||||
|
bytes_sent=counters.bytes_sent,
|
||||||
|
bytes_recv=counters.bytes_recv,
|
||||||
|
packets_sent=counters.packets_sent,
|
||||||
|
packets_recv=counters.packets_recv,
|
||||||
|
timestamp=datetime.now()
|
||||||
|
)
|
||||||
|
session.add(record)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
def get_current_traffic(self, interface: Optional[str] = None) -> Dict:
|
||||||
|
"""获取当前流量数据"""
|
||||||
|
stats = psutil.net_io_counters(pernic=True)
|
||||||
|
|
||||||
|
if interface:
|
||||||
|
if interface in stats:
|
||||||
|
return self._format_interface_stats(stats[interface])
|
||||||
|
return {}
|
||||||
|
|
||||||
|
return {iface: self._format_interface_stats(data) for iface, data in stats.items()}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _format_interface_stats(counters) -> Dict:
|
||||||
|
"""格式化接口统计数据"""
|
||||||
|
return {
|
||||||
|
"bytes_sent": counters.bytes_sent,
|
||||||
|
"bytes_recv": counters.bytes_recv,
|
||||||
|
"packets_sent": counters.packets_sent,
|
||||||
|
"packets_recv": counters.packets_recv,
|
||||||
|
"errin": counters.errin,
|
||||||
|
"errout": counters.errout,
|
||||||
|
"dropin": counters.dropin,
|
||||||
|
"dropout": counters.dropout
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_traffic_history(self, interface: Optional[str] = None) -> Dict:
|
||||||
|
"""获取流量历史数据"""
|
||||||
|
if interface and interface in self.history["interfaces"]:
|
||||||
|
return {
|
||||||
|
"sent": list(self.history["interfaces"][interface]["sent"]),
|
||||||
|
"recv": list(self.history["interfaces"][interface]["recv"]),
|
||||||
|
"time": list(self.history["time"])
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"sent": list(self.history["sent"]),
|
||||||
|
"recv": list(self.history["recv"]),
|
||||||
|
"time": list(self.history["time"])
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 全局流量监控实例
|
||||||
|
traffic_monitor = TrafficMonitor()
|
@ -1,3 +1,4 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
@ -25,3 +25,9 @@ tenacity==8.2.3
|
|||||||
# 其他工具
|
# 其他工具
|
||||||
asyncio==3.4.3
|
asyncio==3.4.3
|
||||||
typing_extensions==4.10.0
|
typing_extensions==4.10.0
|
||||||
|
|
||||||
|
#监控依赖
|
||||||
|
psutil==5.9.8
|
||||||
|
matplotlib==3.8.3
|
||||||
|
sqlalchemy==2.0.28
|
||||||
|
fastapi_utils==0.2.1
|
Loading…
x
Reference in New Issue
Block a user