import asyncio
import time
from collections import deque
from datetime import datetime
from typing import Dict, List, Optional, Tuple

from pysnmp.hlapi import (
    CommunityData,
    ContextData,
    ObjectIdentity,
    ObjectType,
    SnmpEngine,
    UdpTransportTarget,
    getCmd,
)

from ..models.traffic_models import SwitchTrafficRecord
from src.backend.app.api.database import SessionLocal
from ..utils.logger import logger

# The polled OIDs (IF-MIB ifInOctets / ifOutOctets) are Counter32 values,
# which wrap back to zero after reaching 2**32 - 1.
_COUNTER32_MODULUS = 2 ** 32

# Maximum number of samples kept in each in-memory history deque.
_HISTORY_LENGTH = 300


def _counter_rate(
    previous: Optional[int],
    current: int,
    elapsed: float,
    modulus: int = _COUNTER32_MODULUS,
) -> float:
    """Return the per-second increase of a wrapping counter between two reads.

    Args:
        previous: Counter value from the previous read, or ``None`` when this
            is the first read (the rate is then reported as 0.0).
        current: Counter value from the current read.
        elapsed: Seconds between the two reads; a non-positive value yields 0.0.
        modulus: Value at which the counter wraps back to zero.

    A ``current`` lower than ``previous`` is treated as exactly one wrap.
    NOTE: a counter reset (e.g. the switch rebooting) is indistinguishable
    from a wrap here and will produce one spurious sample.
    """
    if previous is None or elapsed <= 0:
        return 0.0
    # Python's % with a positive modulus is always non-negative, so a single
    # wrap becomes the correct forward distance instead of a negative delta.
    return ((current - previous) % modulus) / elapsed


class SwitchTrafficMonitor:
    """Poll a switch's interface octet counters over SNMP and derive byte rates.

    Every ``update_interval`` seconds each monitored interface is queried, its
    in/out rate (bytes per second) is computed from the counter delta, the
    rates are appended to in-memory histories, and a ``SwitchTrafficRecord``
    row is persisted per interface.
    """

    def __init__(
        self,
        switch_ip: str,
        community: str = 'public',
        update_interval: int = 5,
        interfaces: Optional[List[str]] = None
    ):
        """Configure the monitor; polling only begins in ``start_monitoring``.

        Args:
            switch_ip: Address of the switch's SNMP agent (queried on UDP 161).
            community: SNMP community string.
            update_interval: Seconds to sleep between polls.
            interfaces: Interface names to monitor. Names missing from the
                built-in OID table are ignored with a warning; ``None`` or an
                empty list monitors every known interface.
        """
        self.switch_ip = switch_ip
        self.community = community
        self.update_interval = update_interval
        self.running = False
        self.task: Optional[asyncio.Task] = None
        # Switch-wide history: one entry per poll, with rates summed over the
        # interfaces that answered that poll.
        self.history = self._new_history()
        # ifInOctets (...2.2.1.10.<ifIndex>) / ifOutOctets (...2.2.1.16.<ifIndex>).
        # NOTE(review): the ifIndex -> port-name mapping is device-specific;
        # confirm it against the switch's ifDescr table.
        self.interface_oids = {
            "GigabitEthernet0/0/1": {
                "in": '1.3.6.1.2.1.2.2.1.10.1',
                "out": '1.3.6.1.2.1.2.2.1.16.1'
            },
            "GigabitEthernet0/0/24": {
                "in": '1.3.6.1.2.1.2.2.1.10.24',
                "out": '1.3.6.1.2.1.2.2.1.16.24'
            }
        }
        if interfaces:
            unknown = [iface for iface in interfaces if iface not in self.interface_oids]
            if unknown:
                logger.warning(f"未知接口将被忽略: {', '.join(unknown)}")
            self.interface_oids = {
                iface: oid for iface, oid in self.interface_oids.items()
                if iface in interfaces
            }
            logger.info(f"监控指定接口: {', '.join(interfaces)}")
        else:
            logger.info("监控所有接口")
        # Per-interface histories, same layout as ``history``.
        self.interface_history: Dict[str, Dict[str, deque]] = {
            iface: self._new_history() for iface in self.interface_oids
        }

    @staticmethod
    def _new_history() -> Dict[str, deque]:
        """Return an empty bounded history with "in", "out" and "time" series."""
        return {
            "in": deque(maxlen=_HISTORY_LENGTH),
            "out": deque(maxlen=_HISTORY_LENGTH),
            "time": deque(maxlen=_HISTORY_LENGTH)
        }

    def start_monitoring(self):
        """Start the background polling task (must be called inside a running loop)."""
        if not self.running:
            self.running = True
            self.task = asyncio.create_task(self._monitor_loop())
            logger.success(f"交换机流量监控已启动: {self.switch_ip}")

    async def stop_monitoring(self):
        """Stop polling: cancel the background task and wait for it to exit."""
        if self.running:
            self.running = False
            if self.task:
                self.task.cancel()
                try:
                    await self.task
                except asyncio.CancelledError:
                    pass
                self.task = None
            logger.info(f"交换机流量监控已停止: {self.switch_ip}")

    async def _monitor_loop(self):
        """Poll every monitored interface each ``update_interval`` seconds until stopped."""
        # Previous raw counters and the monotonic time they were read, tracked
        # per interface so an interface that missed a poll still gets a rate
        # over its real elapsed time. Monotonic time is immune to clock jumps.
        last_samples: Dict[str, dict] = {
            iface: {"in": None, "out": None, "time": None}
            for iface in self.interface_oids
        }
        while self.running:
            await asyncio.sleep(self.update_interval)
            current_time = datetime.now()
            total_in = total_out = 0.0
            answered = 0
            for iface, oids in self.interface_oids.items():
                try:
                    rates = await self._poll_interface(
                        iface, oids, last_samples[iface], current_time
                    )
                except asyncio.CancelledError:
                    # Never swallow cancellation (it subclasses Exception on 3.7).
                    raise
                except Exception as e:
                    # One failing interface must not abort the others or the loop.
                    logger.error(f"监控交换机流量出错: {str(e)}")
                    continue
                if rates is not None:
                    total_in += rates[0]
                    total_out += rates[1]
                    answered += 1
            if answered:
                self.history["in"].append(total_in)
                self.history["out"].append(total_out)
                self.history["time"].append(current_time)

    async def _poll_interface(
        self,
        iface: str,
        oids: Dict[str, str],
        last: dict,
        current_time: datetime,
    ) -> Optional[Tuple[float, float]]:
        """Read one interface's counters, record its rates and update ``last``.

        Returns:
            ``(in_rate, out_rate)`` in bytes per second, or ``None`` if either
            SNMP read failed (``last`` is then left untouched).
        """
        loop = asyncio.get_running_loop()
        # The synchronous pysnmp call can block for seconds on timeouts and
        # retries; run it in a worker thread to keep the event loop responsive.
        in_octets = await loop.run_in_executor(None, self._snmp_get, oids["in"])
        out_octets = await loop.run_in_executor(None, self._snmp_get, oids["out"])
        if in_octets is None or out_octets is None:
            return None

        now = time.monotonic()
        elapsed = now - last["time"] if last["time"] is not None else 0.0
        in_rate = _counter_rate(last["in"], in_octets, elapsed)
        out_rate = _counter_rate(last["out"], out_octets, elapsed)

        history = self.interface_history[iface]
        history["in"].append(in_rate)
        history["out"].append(out_rate)
        history["time"].append(current_time)

        # The database write is blocking I/O as well.
        await loop.run_in_executor(
            None, self._save_to_db,
            iface, in_octets, out_octets, in_rate, out_rate, current_time
        )

        last["in"], last["out"], last["time"] = in_octets, out_octets, now
        return in_rate, out_rate

    def _snmp_get(self, oid) -> Optional[int]:
        """Perform a blocking SNMP GET of ``oid`` and return its value as an int.

        Returns ``None`` (after logging) on a request failure, an SNMP error
        indication/status, or a value that cannot be converted to ``int``.
        A fresh ``SnmpEngine`` is created per call, so calls made from worker
        threads never share engine state.
        """
        try:
            error_indication, error_status, error_index, var_binds = next(
                getCmd(
                    SnmpEngine(),
                    CommunityData(self.community),
                    UdpTransportTarget((self.switch_ip, 161)),
                    ContextData(),
                    ObjectType(ObjectIdentity(oid)),
                )
            )
        except Exception as e:
            logger.error(f"SNMP请求失败: {str(e)}")
            return None

        if error_indication:
            logger.error(f"SNMP错误: {error_indication}")
            return None

        if error_status:
            try:
                # error_index is 1-based; 0 means "no specific varbind".
                if error_index:
                    index_val = int(error_index) - 1
                    error_item = var_binds[index_val] if index_val < len(var_binds) else '?'
                else:
                    error_item = '?'
                error_msg = f"SNMP错误: {error_status.prettyPrint()} at {error_item}"
                logger.error(error_msg)
            except Exception as e:
                logger.error(f"解析SNMP错误失败: {str(e)}")
            return None

        if not var_binds:
            return None
        try:
            # Only one OID was requested, so only the first varbind matters.
            return int(var_binds[0][1])
        except Exception as e:
            logger.error(f"转换SNMP值失败: {str(e)}")
            return None

    def _save_to_db(self, interface: str, in_octets: int, out_octets: int,
                    in_rate: float, out_rate: float, timestamp: datetime):
        """Persist one ``SwitchTrafficRecord`` sample; failures are logged, not raised.

        Opens its own session, so it is safe to call from a worker thread.
        """
        try:
            with SessionLocal() as session:
                record = SwitchTrafficRecord(
                    switch_ip=self.switch_ip,
                    interface=interface,
                    bytes_in=in_octets,
                    bytes_out=out_octets,
                    rate_in=in_rate,
                    rate_out=out_rate,
                    timestamp=timestamp
                )
                session.add(record)
                session.commit()
        except Exception as e:
            logger.error(f"保存流量数据到数据库失败: {str(e)}")

    def get_traffic_history(self, interface: Optional[str] = None) -> Dict[str, List]:
        """Return a snapshot of the traffic history as plain lists.

        Args:
            interface: When ``None`` (default), return the switch-wide history
                (one entry per poll, rates summed over answering interfaces).
                Otherwise return that interface's own history; an unknown
                interface yields empty lists.

        Returns:
            ``{"in": [...], "out": [...], "time": [...]}`` with rates in bytes
            per second and ``datetime`` timestamps.
        """
        source = self.history if interface is None else self.interface_history.get(interface)
        if source is None:
            return {"in": [], "out": [], "time": []}
        return {
            "in": list(source["in"]),
            "out": list(source["out"]),
            "time": list(source["time"])
        }


# Process-wide cache of monitors, keyed by switch IP.
switch_monitors: Dict[str, SwitchTrafficMonitor] = {}


def get_switch_monitor(
    switch_ip: str,
    community: str = 'public',
    interfaces: Optional[List[str]] = None,
    update_interval: int = 5,
) -> SwitchTrafficMonitor:
    """Return the cached monitor for ``switch_ip``, creating it on first use.

    NOTE: monitors are cached by IP only. ``community``, ``interfaces`` and
    ``update_interval`` apply only when the monitor is first created and are
    ignored for an existing one.
    """
    monitor = switch_monitors.get(switch_ip)
    if monitor is None:
        monitor = SwitchTrafficMonitor(
            switch_ip, community, update_interval=update_interval, interfaces=interfaces
        )
        switch_monitors[switch_ip] = monitor
    return monitor