diff --git a/src/backend/README.md b/src/backend/README.md index 0a34790..d0a536e 100644 --- a/src/backend/README.md +++ b/src/backend/README.md @@ -10,7 +10,6 @@ src/backend/ │ ├── __init__.py # 创建 Flask 应用实例 │ ├── api/ # API 路由模块 │ │ ├—── __init__.py # 注册 API 蓝图 -│ │ ├── command_parser.py # /api/parse_command 接口 │ │ └── network_config.py # /api/apply_config 接口 │ └── services/ # 核心服务逻辑 │ └── ai_services.py # 调用外部 AI 服务生成配置 diff --git a/src/backend/app/api/command_parser.py b/src/backend/app/api/command_parser.py deleted file mode 100644 index f198d10..0000000 --- a/src/backend/app/api/command_parser.py +++ /dev/null @@ -1,14 +0,0 @@ -from typing import Dict, Any -from src.backend.app.services.ai_services import AIService -from src.backend.config import settings - - -class CommandParser: - def __init__(self): - self.ai_service = AIService(settings.SILICONFLOW_API_KEY, settings.SILICONFLOW_API_URL) - - async def parse(self, command: str) -> Dict[str, Any]: - """ - 解析中文命令并返回配置 - """ - return await self.ai_service.parse_command(command) diff --git a/src/backend/app/api/endpoints.py b/src/backend/app/api/endpoints.py index 4bcd616..20ed723 100644 --- a/src/backend/app/api/endpoints.py +++ b/src/backend/app/api/endpoints.py @@ -1,19 +1,30 @@ import socket -from fastapi import (APIRouter, HTTPException, Response) +from datetime import datetime, timedelta +from fastapi import (APIRouter, HTTPException, Response, WebSocket, WebSocketDisconnect) from typing import List from pydantic import BaseModel +import asyncio +from fastapi.responses import HTMLResponse, JSONResponse +import matplotlib.pyplot as plt +import io +import base64 import psutil import ipaddress + +from ..services.switch_traffic_monitor import get_switch_monitor +from ..utils import logger from ...app.services.ai_services import AIService from ...app.api.network_config import SwitchConfigurator from ...config import settings from ..services.network_scanner import NetworkScanner - - +from ...app.services.traffic_monitor import traffic_monitor +from ...app.models.traffic_models import TrafficRecord, SwitchTrafficRecord +from src.backend.app.api.database import SessionLocal router = APIRouter(prefix="", tags=["API"]) scanner = NetworkScanner() + @router.get("/", include_in_schema=False) async def root(): return { @@ -25,16 +36,18 @@ async def root(): "/apply_config", "/scan_network", "/list_devices", - "/batch_apply_config" + "/batch_apply_config", "/traffic/switch/current", "/traffic/switch/history" ] } + @router.get("/favicon.ico", include_in_schema=False) async def favicon(): return Response(status_code=204) + class BatchConfigRequest(BaseModel): config: dict switch_ips: List[str] @@ -42,14 +55,31 @@ class BatchConfigRequest(BaseModel): password: str = None timeout: int = None + +@router.post("/batch_apply_config") +async def batch_apply_config(request: BatchConfigRequest): + results = {} + for ip in request.switch_ips: + try: + configurator = SwitchConfigurator( + username=request.username, + password=request.password, + timeout=request.timeout) + results[ip] = await configurator.apply_config(ip, request.config) + except Exception as e: + results[ip] = str(e) + return {"results": results} + + @router.get("/test") async def test_endpoint(): return {"message": "Hello World"} + @router.get("/scan_network", summary="扫描网络中的交换机") async def scan_network(subnet: str = "192.168.1.0/24"): try: - devices = await scanner.scan_subnet(subnet) + devices = await asyncio.to_thread(scanner.scan_subnet, subnet) return { "success": True, "devices": devices, @@ -58,14 +88,18 @@ async def scan_network(subnet: str = "192.168.1.0/24"): except Exception as e: raise HTTPException(500, f"扫描失败: {str(e)}") + @router.get("/list_devices", summary="列出已发现的交换机") async def list_devices(): return { - "devices": await scanner.load_cached_devices() + "devices": await asyncio.to_thread(scanner.load_cached_devices) } + class CommandRequest(BaseModel): command: str + vendor: str = "huawei" + class ConfigRequest(BaseModel): config: dict @@ -73,15 +107,15 @@ class ConfigRequest(BaseModel): username: str = None password: str = None timeout: int = None + vendor: str = "huawei" + @router.post("/parse_command", response_model=dict) async def parse_command(request: CommandRequest): - """ - 解析中文命令并返回JSON配置 - """ + """解析中文命令并返回JSON配置""" try: ai_service = AIService(settings.SILICONFLOW_API_KEY, settings.SILICONFLOW_API_URL) - config = await ai_service.parse_command(request.command) + config = await ai_service.parse_command(request.command, request.vendor) return {"success": True, "config": config} except Exception as e: raise HTTPException( @@ -89,16 +123,16 @@ async def parse_command(request: CommandRequest): detail=f"Failed to parse command: {str(e)}" ) + @router.post("/apply_config", response_model=dict) async def apply_config(request: ConfigRequest): - """ - 应用配置到交换机(弃用) - """ + """应用配置到交换机""" try: configurator = SwitchConfigurator( username=request.username, password=request.password, - timeout=request.timeout + timeout=request.timeout, + vendor=request.vendor ) result = await configurator.safe_apply(request.switch_ip, request.config) return {"success": True, "result": result} @@ -132,14 +166,10 @@ class CLICommandRequest(BaseModel): return [cmd for cmd in self.commands if not (cmd.startswith("!username=") or cmd.startswith("!password="))] + @router.post("/execute_cli_commands", response_model=dict) async def execute_cli_commands(request: CLICommandRequest): - """ - 执行前端生成的CLI命令 - 支持在commands中嵌入凭据: - !username=admin - !password=cisco123 - """ + """执行前端生成的CLI命令""" try: username, password = request.extract_credentials() clean_commands = request.get_clean_commands() @@ -163,43 +193,255 @@ async def execute_cli_commands(request: CLICommandRequest): except Exception as e: raise HTTPException(500, detail=str(e)) -@router.get("/", include_in_schema=False) -async def root(): + +@router.get("/traffic/interfaces", summary="获取所有网络接口") +async def get_network_interfaces(): return { - "message": "欢迎使用AI交换机配置系统", - "docs": f"{settings.API_PREFIX}/docs", - "redoc": f"{settings.API_PREFIX}/redoc", - "endpoints": [ - "/parse_command", - "/apply_config", - "/scan_network", - "/list_devices", - "/batch_apply_config", - "/traffic/switch/current", - "/traffic/switch/history" - ] + "interfaces": await asyncio.to_thread(traffic_monitor.get_interfaces) } + + +@router.get("/traffic/current", summary="获取当前流量数据") +async def get_current_traffic(interface: str = None): + return await asyncio.to_thread(traffic_monitor.get_current_traffic, interface) + + +@router.get("/traffic/history", summary="获取流量历史数据") +async def get_traffic_history(interface: str = None, limit: int = 100): + history = await asyncio.to_thread(traffic_monitor.get_traffic_history, interface) + return { + "sent": history["sent"][-limit:], + "recv": history["recv"][-limit:], + "time": [t.isoformat() for t in history["time"]][-limit:] + } + + +@router.get("/traffic/records", summary="获取流量记录") +async def get_traffic_records(interface: str = None, limit: int = 100): + def sync_get_records(): + with SessionLocal() as session: + query = session.query(TrafficRecord) + if interface: + query = query.filter(TrafficRecord.interface == interface) + records = query.order_by(TrafficRecord.timestamp.desc()).limit(limit).all() + return [record.to_dict() for record in records] + + return await asyncio.to_thread(sync_get_records) + + +@router.websocket("/ws/traffic") +async def websocket_traffic(websocket: WebSocket): + await websocket.accept() + try: + while True: + traffic_data = await asyncio.to_thread(traffic_monitor.get_current_traffic) + await websocket.send_json(traffic_data) + await asyncio.sleep(1) + except WebSocketDisconnect: + print("客户端断开连接") + + +@router.get("/traffic/switch/interfaces", summary="获取交换机的网络接口") +async def get_switch_interfaces(switch_ip: str): + try: + monitor = get_switch_monitor(switch_ip) + interfaces = list(monitor.interface_oids.keys()) + return { + "switch_ip": switch_ip, + "interfaces": interfaces + } + except Exception as e: + logger.error(f"获取交换机接口失败: {str(e)}") + raise HTTPException(500, f"获取接口失败: {str(e)}") + + +async def get_interface_current_traffic(switch_ip: str, interface: str) -> dict: + """获取指定交换机接口的当前流量数据""" + try: + def sync_get_record(): + with SessionLocal() as session: + record = session.query(SwitchTrafficRecord).filter( + SwitchTrafficRecord.switch_ip == switch_ip, + SwitchTrafficRecord.interface == interface + ).order_by(SwitchTrafficRecord.timestamp.desc()).first() + + if not record: + return { + "switch_ip": switch_ip, + "interface": interface, + "rate_in": 0.0, + "rate_out": 0.0, + "bytes_in": 0, + "bytes_out": 0 + } + + return { + "switch_ip": switch_ip, + "interface": interface, + "rate_in": record.rate_in, + "rate_out": record.rate_out, + "bytes_in": record.bytes_in, + "bytes_out": record.bytes_out + } + + return await asyncio.to_thread(sync_get_record) + except Exception as e: + logger.error(f"获取接口流量失败: {str(e)}") + raise HTTPException(500, f"获取接口流量失败: {str(e)}") + + +@router.get("/traffic/switch/current", summary="获取交换机的当前流量数据") +async def get_switch_current_traffic(switch_ip: str, interface: str = None): + try: + monitor = get_switch_monitor(switch_ip) + if not interface: + traffic_data = {} + for iface in monitor.interface_oids: + traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface) + return { + "switch_ip": switch_ip, + "traffic": traffic_data + } + return await get_interface_current_traffic(switch_ip, interface) + except Exception as e: + logger.error(f"获取交换机流量失败: {str(e)}") + raise HTTPException(500, f"获取流量失败: {str(e)}") + + +@router.get("/traffic/switch/history", summary="获取交换机的流量历史数据") +async def get_switch_traffic_history(switch_ip: str, interface: str = None, minutes: int = 10): + try: + monitor = get_switch_monitor(switch_ip) + if not interface: + return { + "switch_ip": switch_ip, + "history": await asyncio.to_thread(monitor.get_traffic_history) + } + + def sync_get_history(): + with SessionLocal() as session: + time_threshold = datetime.now() - timedelta(minutes=minutes) + records = session.query(SwitchTrafficRecord).filter( + SwitchTrafficRecord.switch_ip == switch_ip, + SwitchTrafficRecord.interface == interface, + SwitchTrafficRecord.timestamp >= time_threshold + ).order_by(SwitchTrafficRecord.timestamp.asc()).all() + + history_data = { + "in": [record.rate_in for record in records], + "out": [record.rate_out for record in records], + "time": [record.timestamp.isoformat() for record in records] + } + return history_data + + history_data = await asyncio.to_thread(sync_get_history) + return { + "switch_ip": switch_ip, + "interface": interface, + "history": history_data + } + except Exception as e: + logger.error(f"获取历史流量失败: {str(e)}") + raise HTTPException(500, f"获取历史流量失败: {str(e)}") + + +@router.websocket("/ws/traffic/switch") +async def websocket_switch_traffic(websocket: WebSocket, switch_ip: str, interface: str = None): + await websocket.accept() + try: + monitor = get_switch_monitor(switch_ip) + while True: + if interface: + traffic_data = await get_interface_current_traffic(switch_ip, interface) + await websocket.send_json(traffic_data) + else: + traffic_data = {} + for iface in monitor.interface_oids: + traffic_data[iface] = await get_interface_current_traffic(switch_ip, iface) + await websocket.send_json({ + "switch_ip": switch_ip, + "traffic": traffic_data + }) + await asyncio.sleep(1) + except WebSocketDisconnect: + logger.info(f"客户端断开交换机流量连接: {switch_ip}") + except Exception as e: + logger.error(f"交换机流量WebSocket错误: {str(e)}") + await websocket.close(code=1011, reason=str(e)) + + +@router.get("/traffic/switch/plot", response_class=HTMLResponse, summary="交换机流量可视化") +async def plot_switch_traffic(switch_ip: str, interface: str, minutes: int = 10): + try: + history = await get_switch_traffic_history(switch_ip, interface, minutes) + history_data = history["history"] + time_points = [datetime.fromisoformat(t) for t in history_data["time"]] + in_rates = history_data["in"] + out_rates = history_data["out"] + + def generate_plot(): + plt.figure(figsize=(12, 6)) + plt.plot(time_points, in_rates, label="流入流量 (B/s)") + plt.plot(time_points, out_rates, label="流出流量 (B/s)") + plt.title(f"交换机 {switch_ip} 接口 {interface} 流量监控 - 最近 {minutes} 分钟") + plt.xlabel("时间") + plt.ylabel("流量 (字节/秒)") + plt.legend() + plt.grid(True) + plt.xticks(rotation=45) + plt.tight_layout() + buf = io.BytesIO() + plt.savefig(buf, format="png") + buf.seek(0) + image_base64 = base64.b64encode(buf.read()).decode("utf-8") + plt.close() + return image_base64 + + image_base64 = await asyncio.to_thread(generate_plot) + return f""" + + + 交换机流量监控 + + + +
+

交换机 {switch_ip} 接口 {interface} 流量监控

+ 流量图表 +

更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

+
+ + + """ + except Exception as e: + logger.error(f"生成流量图表失败: {str(e)}") + return HTMLResponse(content=f"

错误

{str(e)}

", status_code=500) + + @router.get("/network_adapters", summary="获取网络适配器网段") async def get_network_adapters(): try: - net_if_addrs = psutil.net_if_addrs() - - networks = [] - for interface, addrs in net_if_addrs.items(): - for addr in addrs: - if addr.family == socket.AF_INET: - ip = addr.address - netmask = addr.netmask - - network = ipaddress.IPv4Network(f"{ip}/{netmask}", strict=False) - networks.append({ - "adapter": interface, - "network": str(network), - "ip": ip, - "subnet_mask": netmask - }) + def sync_get_adapters(): + net_if_addrs = psutil.net_if_addrs() + networks = [] + for interface, addrs in net_if_addrs.items(): + for addr in addrs: + if addr.family == socket.AF_INET: + ip = addr.address + netmask = addr.netmask + network = ipaddress.IPv4Network(f"{ip}/{netmask}", strict=False) + networks.append({ + "adapter": interface, + "network": str(network), + "ip": ip, + "subnet_mask": netmask + }) + return networks + networks = await asyncio.to_thread(sync_get_adapters) return {"networks": networks} - except Exception as e: return {"error": f"获取网络适配器信息失败: {str(e)}"} \ No newline at end of file diff --git a/src/backend/app/services/ai_services.py b/src/backend/app/services/ai_services.py index d315477..2987bd0 100644 --- a/src/backend/app/services/ai_services.py +++ b/src/backend/app/services/ai_services.py @@ -1,7 +1,5 @@ -from typing import Dict, Any, Coroutine - -import httpx -from openai import OpenAI +from typing import Any +from openai import AsyncOpenAI import json from src.backend.app.utils.exceptions import SiliconFlowAPIException from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam @@ -12,28 +10,44 @@ class AIService: def __init__(self, api_key: str, api_url: str): self.api_key = api_key self.api_url = api_url - self.client = OpenAI( + self.client = AsyncOpenAI( api_key=self.api_key, base_url=self.api_url, # timeout=httpx.Timeout(30.0) ) - async def parse_command(self, command: str) -> Any | None: + async def parse_command(self, command: str, vendor: str = "huawei") -> Any | None: """ 调用硅基流动API解析中文命令 """ - prompt = """ - 你是一个网络设备配置专家,精通各种类型的路由器的配置,请将以下用户的中文命令转换为网络设备配置JSON - 但是请注意,由于贪婪的人们追求极高的效率,所以你必须严格按照 JSON 格式返回数据,不要包含任何额外文本或 Markdown 代码块 + vendor_prompts = { + "huawei": "华为交换机配置命令", + "cisco": "思科交换机配置命令", + "h3c": "H3C交换机配置命令", + "ruijie": "锐捷交换机配置命令", + "zte": "中兴交换机配置命令" + } + + prompt = f""" + 你是一个网络设备配置专家,精通各种类型的路由器的配置,请将以下用户的中文命令转换为{vendor_prompts.get(vendor, '网络设备')}配置JSON。 + 但是请注意,由于贪婪的人们追求极高的效率,所以你必须严格按照 JSON 格式返回数据,不要包含任何额外文本或 Markdown 代码块。 返回格式要求: 1. 必须包含'type'字段指明配置类型(vlan/interface/acl/route等) 2. 必须包含'commands'字段,包含可直接执行的命令列表 3. 其他参数根据配置类型动态添加 4. 不要包含解释性文本、步骤说明或注释 - 5.要包含使用ssh连接交换机后的完整命令包括但不完全包括system-view,退出,保存等完整操作,注意保存还需要输入Y + 5. 要包含使用ssh连接交换机后的完整命令包括但不完全包括system-view,退出,保存等完整操作,注意保存还需要输入Y + + 根据厂商{vendor}的不同,命令格式如下: + - 华为: system-view → quit → save Y + - 思科: enable → configure terminal → exit → write memory + - H3C: system-view → quit → save + - 锐捷: enable → configure terminal → exit → write + - 中兴: enable → configure terminal → exit → write memory + 示例命令:'创建VLAN 100,名称为TEST' - 示例返回:{"type": "vlan", "vlan_id": 100, "name": "TEST", "commands": ["system-view","vlan 100", "name TEST","quit","quit","save","Y"]} - 注意:这里生成的commands中需包含登录交换机和保存等所有操作命令,我们使ssh连接交换机,你不需要给出连接ssh的命令,你只需要给出使用ssh连接到交换机后所输入的全部命令,并且注意在system-view状态下是不能save的,需要再quit到用户视图 + 华为示例返回:{{"type": "vlan", "vlan_id": 100, "name": "TEST", "commands": ["system-view","vlan 100", "name TEST","quit","quit","save","Y"]}} + 思科示例返回:{{"type": "vlan", "vlan_id": 100, "name": "TEST", "commands": ["enable","configure terminal","vlan 100", "name TEST","exit","exit","write memory"]}} """ messages = [ @@ -42,7 +56,7 @@ class AIService: ] try: - response = self.client.chat.completions.create( + response = await self.client.chat.completions.create( model="deepseek-ai/DeepSeek-V3", messages=messages, temperature=0.3, diff --git a/src/backend/app/services/network_scanner.py b/src/backend/app/services/network_scanner.py index 77ac5ff..bdb7f0b 100644 --- a/src/backend/app/services/network_scanner.py +++ b/src/backend/app/services/network_scanner.py @@ -9,13 +9,13 @@ class NetworkScanner: self.cache_path = Path(cache_path) self.nm = nmap.PortScanner() - async def scan_subnet(self, subnet: str = "192.168.1.0/24", ports: List[int] = [22, 23, 80]) -> List[Dict]: + def scan_subnet(self, subnet: str = "192.168.1.0/24", ports: List[int] = [22, 23, 80]) -> List[Dict]: """扫描指定子网的设备,获取设备信息和开放端口""" logger.info(f"Scanning subnet: {subnet}") devices = [] try: - await self.nm.scan(hosts=subnet, arguments=f'-p {",".join(map(str, ports))}') + self.nm.scan(hosts=subnet, arguments=f'-p {",".join(map(str, ports))}') for host in self.nm.all_hosts(): ip = host mac = self.nm[host]['addresses'].get('mac', 'N/A') @@ -33,19 +33,19 @@ class NetworkScanner: except Exception as e: logger.error(f"Error while scanning subnet: {e}") - await self._save_to_cache(devices) + self._save_to_cache(devices) return devices - async def _save_to_cache(self, devices: List[Dict]): + def _save_to_cache(self, devices: List[Dict]): """保存扫描结果到本地文件""" with open(self.cache_path, "w") as f: json.dump(devices, f, indent=2) logger.info(f"Saved {len(devices)} devices to cache") - async def load_cached_devices(self) -> List[Dict]: + def load_cached_devices(self) -> List[Dict]: """从缓存加载设备列表""" if not self.cache_path.exists(): return [] with open(self.cache_path) as f: - return json.load(f) + return json.load(f) \ No newline at end of file