mirror of
https://github.com/Jerryplusy/AI-powered-switches.git
synced 2025-10-14 09:49:19 +00:00
Merge remote-tracking branch 'origin/main'
This commit is contained in:
commit
29dd4ec839
@ -1,3 +1,5 @@
|
||||
# File: D:\Python work\AI-powered-switches\src\backend\app\api\endpoints.py
|
||||
|
||||
import socket
|
||||
from datetime import datetime, timedelta
|
||||
from fastapi import (APIRouter, HTTPException, Response, WebSocket, WebSocketDisconnect)
|
||||
@ -10,8 +12,8 @@ import io
|
||||
import base64
|
||||
import psutil
|
||||
import ipaddress
|
||||
import json
|
||||
|
||||
from ..models.requests import CLICommandRequest, ConfigRequest
|
||||
from ..services.switch_traffic_monitor import get_switch_monitor
|
||||
from ..utils import logger
|
||||
from ...app.services.ai_services import AIService
|
||||
@ -21,6 +23,9 @@ from ..services.network_scanner import NetworkScanner
|
||||
from ...app.services.traffic_monitor import traffic_monitor
|
||||
from ...app.models.traffic_models import TrafficRecord, SwitchTrafficRecord
|
||||
from src.backend.app.api.database import SessionLocal
|
||||
from ..services.network_visualizer import NetworkVisualizer
|
||||
from ..services.config_validator import ConfigValidator
|
||||
from ..services.report_generator import ReportGenerator
|
||||
|
||||
router = APIRouter(prefix="", tags=["API"])
|
||||
scanner = NetworkScanner()
|
||||
@ -56,6 +61,22 @@ class BatchConfigRequest(BaseModel):
|
||||
password: str = None
|
||||
timeout: int = None
|
||||
|
||||
|
||||
@router.post("/batch_apply_config")
|
||||
async def batch_apply_config(request: BatchConfigRequest):
|
||||
results = {}
|
||||
for ip in request.switch_ips:
|
||||
try:
|
||||
configurator = SwitchConfigurator(
|
||||
username=request.username,
|
||||
password=request.password,
|
||||
timeout=request.timeout)
|
||||
results[ip] = await configurator.apply_config(ip, request.config)
|
||||
except Exception as e:
|
||||
results[ip] = str(e)
|
||||
return {"results": results}
|
||||
|
||||
|
||||
@router.get("/test")
|
||||
async def test_endpoint():
|
||||
return {"message": "Hello World"}
|
||||
@ -81,29 +102,27 @@ async def list_devices():
|
||||
}
|
||||
|
||||
|
||||
class DeviceItem(BaseModel):
|
||||
name: str
|
||||
ip: str
|
||||
vendor: str
|
||||
|
||||
class CommandRequest(BaseModel):
|
||||
command: str
|
||||
devices: List[DeviceItem]
|
||||
vendor: str = "huawei"
|
||||
|
||||
|
||||
class ConfigRequest(BaseModel):
|
||||
config: dict
|
||||
switch_ip: str
|
||||
username: str = None
|
||||
password: str = None
|
||||
timeout: int = None
|
||||
vendor: str = "huawei"
|
||||
|
||||
|
||||
@router.post("/parse_command", response_model=dict)
|
||||
async def parse_command(request: CommandRequest):
|
||||
"""解析中文命令并返回每台设备的配置 JSON"""
|
||||
missing_vendor = [d for d in request.devices if not d.vendor or d.vendor.strip() == ""]
|
||||
if missing_vendor:
|
||||
names = ", ".join([d.name for d in missing_vendor])
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"以下设备未配置厂商: {names}"
|
||||
)
|
||||
"""解析中文命令并返回JSON配置"""
|
||||
try:
|
||||
ai_service = AIService(settings.SILICONFLOW_API_KEY, settings.SILICONFLOW_API_URL)
|
||||
config = await ai_service.parse_command(request.command, [d.dict() for d in request.devices])
|
||||
return {"success": True, "config": config.get("results", [])}
|
||||
config = await ai_service.parse_command(request.command, request.vendor)
|
||||
return {"success": True, "config": config}
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@ -128,16 +147,44 @@ async def apply_config(request: ConfigRequest):
|
||||
status_code=500,
|
||||
detail=f"Failed to apply config: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
class CLICommandRequest(BaseModel):
|
||||
switch_ip: str
|
||||
commands: List[str]
|
||||
is_ensp: bool = False
|
||||
|
||||
def extract_credentials(self) -> tuple:
|
||||
"""从commands中提取用户名和密码"""
|
||||
username = None
|
||||
password = None
|
||||
|
||||
for cmd in self.commands:
|
||||
if cmd.startswith("!username="):
|
||||
username = cmd.split("=")[1]
|
||||
elif cmd.startswith("!password="):
|
||||
password = cmd.split("=")[1]
|
||||
|
||||
return username, password
|
||||
|
||||
def get_clean_commands(self) -> List[str]:
|
||||
"""获取去除凭据后的实际命令"""
|
||||
return [cmd for cmd in self.commands
|
||||
if not (cmd.startswith("!username=") or cmd.startswith("!password="))]
|
||||
|
||||
|
||||
@router.post("/execute_cli_commands", response_model=dict)
|
||||
async def execute_cli_commands(request: CLICommandRequest):
|
||||
"""执行前端生成的CLI命令"""
|
||||
try:
|
||||
username, password = request.extract_credentials()
|
||||
clean_commands = request.get_clean_commands()
|
||||
|
||||
configurator = SwitchConfigurator(
|
||||
username=username,
|
||||
password=password,
|
||||
timeout=settings.SWITCH_TIMEOUT,
|
||||
ensp_mode=request.is_ensp
|
||||
)
|
||||
|
||||
result = await configurator.execute_raw_commands(
|
||||
@ -147,6 +194,7 @@ async def execute_cli_commands(request: CLICommandRequest):
|
||||
return {
|
||||
"success": True,
|
||||
"output": result,
|
||||
"mode": "eNSP" if request.is_ensp else "SSH"
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(500, detail=str(e))
|
||||
@ -403,3 +451,69 @@ async def get_network_adapters():
|
||||
return {"networks": networks}
|
||||
except Exception as e:
|
||||
return {"error": f"获取网络适配器信息失败: {str(e)}"}
|
||||
|
||||
|
||||
visualizer = NetworkVisualizer()
|
||||
report_gen = ReportGenerator()
|
||||
|
||||
|
||||
@router.get("/topology/visualize", response_class=HTMLResponse)
|
||||
async def visualize_topology():
|
||||
"""获取网络拓扑可视化图"""
|
||||
try:
|
||||
devices = await list_devices()
|
||||
await asyncio.to_thread(visualizer.update_topology, devices["devices"])
|
||||
image_data = await asyncio.to_thread(visualizer.generate_topology_image)
|
||||
return f"""
|
||||
<html>
|
||||
<head><title>Network Topology</title></head>
|
||||
<body>
|
||||
<h1>Network Topology</h1>
|
||||
<img src="data:image/png;base64,{image_data}" alt="Network Topology">
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
except Exception as e:
|
||||
raise HTTPException(500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/config/validate")
|
||||
async def validate_config(config: dict):
|
||||
"""验证配置有效性"""
|
||||
is_valid, errors = await asyncio.to_thread(ConfigValidator.validate_full_config, config)
|
||||
return {
|
||||
"valid": is_valid,
|
||||
"errors": errors,
|
||||
"has_security_risks": len(
|
||||
await asyncio.to_thread(ConfigValidator.check_security_risks, config.get("commands", []))) > 0
|
||||
}
|
||||
|
||||
|
||||
@router.get("/reports/traffic/{ip}")
|
||||
async def get_traffic_report(ip: str, days: int = 1):
|
||||
"""获取流量分析报告"""
|
||||
try:
|
||||
report = await asyncio.to_thread(report_gen.generate_traffic_report, ip, days)
|
||||
return JSONResponse(content=report)
|
||||
except Exception as e:
|
||||
raise HTTPException(500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/reports/traffic")
|
||||
async def get_local_traffic_report(days: int = 1):
|
||||
"""获取本地网络流量报告"""
|
||||
try:
|
||||
report = await asyncio.to_thread(report_gen.generate_traffic_report, days=days)
|
||||
return JSONResponse(content=report)
|
||||
except Exception as e:
|
||||
raise HTTPException(500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/topology/traffic_heatmap")
|
||||
async def get_traffic_heatmap(minutes: int = 10):
|
||||
"""获取流量热力图数据"""
|
||||
try:
|
||||
heatmap = await asyncio.to_thread(visualizer.get_traffic_heatmap, minutes)
|
||||
return {"heatmap": heatmap}
|
||||
except Exception as e:
|
||||
raise HTTPException(500, detail=str(e))
|
@ -1,46 +1,55 @@
|
||||
from typing import Any, List, Dict
|
||||
from typing import Dict, Any, Coroutine
|
||||
|
||||
import httpx
|
||||
from openai import AsyncOpenAI
|
||||
import json
|
||||
from src.backend.app.utils.exceptions import SiliconFlowAPIException
|
||||
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
|
||||
from src.backend.app.utils.logger import logger
|
||||
|
||||
|
||||
class AIService:
|
||||
def __init__(self, api_key: str, api_url: str):
|
||||
self.client = AsyncOpenAI(api_key=api_key, base_url=api_url)
|
||||
self.api_key = api_key
|
||||
self.api_url = api_url
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.api_url,
|
||||
timeout=httpx.Timeout(30.0)
|
||||
)
|
||||
|
||||
async def parse_command(self, command: str, devices: List[Dict]) -> Dict[str, Any]:
|
||||
async def parse_command(self, command: str, vendor: str = "huawei") -> Any | None:
|
||||
"""
|
||||
针对一组设备和一条自然语言命令,生成每台设备的配置 JSON
|
||||
调用硅基流动API解析中文命令
|
||||
"""
|
||||
devices_str = json.dumps(devices, ensure_ascii=False, indent=2)
|
||||
|
||||
example = """[{"device": {"name": "sw1","ip": "192.168.1.10","vendor": "huawei","username": "NONE", "password": "Huawei"},"config": {"type": "vlan","vlan_id": 300,"name": "Sales","commands": ["system-view","vlan 300","name Sales","quit","quit","save","Y"]}}]"""
|
||||
vendor_prompts = {
|
||||
"huawei": "华为交换机配置命令",
|
||||
"cisco": "思科交换机配置命令",
|
||||
"h3c": "H3C交换机配置命令",
|
||||
"ruijie": "锐捷交换机配置命令",
|
||||
"zte": "中兴交换机配置命令"
|
||||
}
|
||||
|
||||
prompt = f"""
|
||||
你是一个网络设备配置专家。现在有以下设备:
|
||||
{devices_str}
|
||||
你是一个网络设备配置专家,精通各种类型的路由器的配置,请将以下用户的中文命令转换为{vendor_prompts.get(vendor, '网络设备')}配置JSON。
|
||||
但是请注意,由于贪婪的人们追求极高的效率,所以你必须严格按照 JSON 格式返回数据,不要包含任何额外文本或 Markdown 代码块。
|
||||
返回格式要求:
|
||||
1. 必须包含'type'字段指明配置类型(vlan/interface/acl/route等)
|
||||
2. 必须包含'commands'字段,包含可直接执行的命令列表
|
||||
3. 其他参数根据配置类型动态添加
|
||||
4. 不要包含解释性文本、步骤说明或注释
|
||||
5. 要包含使用ssh连接交换机后的完整命令包括但不完全包括system-view,退出,保存等完整操作,注意保存还需要输入Y
|
||||
|
||||
用户输入了一条命令:{command}
|
||||
|
||||
你的任务:
|
||||
- 为每台设备分别生成配置
|
||||
- 输出一个 JSON 数组,每个元素对应一台设备
|
||||
- 每个对象必须包含:
|
||||
- device: 原始设备信息 (name, ip, vendor,username,password)
|
||||
- config: 配置详情
|
||||
- type: 配置类型 (如 vlan/interface/acl/route)
|
||||
- commands: 可直接执行的命令数组 (必须包含进入配置、退出、保存命令)
|
||||
- 其他字段: 根据配置类型动态添加
|
||||
- 严格返回 JSON,不要包含解释说明或 markdown
|
||||
|
||||
各厂商保存命令规则:
|
||||
根据厂商{vendor}的不同,命令格式如下:
|
||||
- 华为: system-view → quit → save Y
|
||||
- 思科: enable → configure terminal → exit → write memory
|
||||
- H3C: system-view → quit → save
|
||||
- 锐捷: enable → configure terminal → exit → write
|
||||
- 中兴: enable → configure terminal → exit → write memory
|
||||
|
||||
返回示例(仅作为格式参考,不要照抄 VLAN ID 和命令内容,请根据实际命令生成):{example}
|
||||
示例命令:'创建VLAN 100,名称为TEST'
|
||||
华为示例返回:{{"type": "vlan", "vlan_id": 100, "name": "TEST", "commands": ["system-view","vlan 100", "name TEST","quit","quit","save","Y"]}}
|
||||
思科示例返回:{{"type": "vlan", "vlan_id": 100, "name": "TEST", "commands": ["enable","configure terminal","vlan 100", "name TEST","exit","exit","write memory"]}}
|
||||
"""
|
||||
|
||||
messages = [
|
||||
@ -52,18 +61,29 @@ class AIService:
|
||||
response = await self.client.chat.completions.create(
|
||||
model="deepseek-ai/DeepSeek-V3",
|
||||
messages=messages,
|
||||
temperature=0.2,
|
||||
max_tokens=1500,
|
||||
temperature=0.3,
|
||||
max_tokens=1000,
|
||||
response_format={"type": "json_object"}
|
||||
)
|
||||
|
||||
config_str = response.choices[0].message.content.strip()
|
||||
configs = json.loads(config_str)
|
||||
logger.debug(response)
|
||||
|
||||
return {"success": True, "results": configs}
|
||||
config_str = response.choices[0].message.content.strip()
|
||||
|
||||
try:
|
||||
config = json.loads(config_str)
|
||||
return config
|
||||
except json.JSONDecodeError:
|
||||
if config_str.startswith("```json"):
|
||||
config_str = config_str[7:-3].strip()
|
||||
return json.loads(config_str)
|
||||
raise SiliconFlowAPIException("Invalid JSON format returned from AI")
|
||||
except KeyError:
|
||||
logger.error(KeyError)
|
||||
raise SiliconFlowAPIException("errrrrrrro")
|
||||
|
||||
except Exception as e:
|
||||
raise SiliconFlowAPIException(
|
||||
detail=f"AI 解析配置失败: {str(e)}",
|
||||
detail=f"API请求失败: {str(e)}",
|
||||
status_code=getattr(e, "status_code", 500)
|
||||
)
|
||||
|
@ -1,50 +1,24 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
import sys
|
||||
|
||||
ENV_FILE = ".env"
|
||||
|
||||
if not os.path.exists(ENV_FILE):
|
||||
default_env_content = """
|
||||
APP_NAME=AI Network Configurator
|
||||
DEBUG=True
|
||||
API_PREFIX=/api
|
||||
|
||||
SILICONFLOW_API_KEY=your-api-key-here
|
||||
SILICONFLOW_API_URL=https://api.siliconflow.cn/v1
|
||||
|
||||
SWITCH_USERNAME=admin
|
||||
SWITCH_PASSWORD=admin
|
||||
SWITCH_TIMEOUT=10
|
||||
|
||||
ENSP_DEFAULT_IP=172.17.99.201
|
||||
ENSP_DEFAULT_PORT=2000
|
||||
"""
|
||||
with open(ENV_FILE, "w", encoding="utf-8") as f:
|
||||
f.write(default_env_content)
|
||||
|
||||
print(f"已生成默认配置文件 {ENV_FILE} ,请修改后重新运行程序。")
|
||||
sys.exit(1)
|
||||
|
||||
# 加载 .env 文件
|
||||
load_dotenv(ENV_FILE)
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
APP_NAME: str
|
||||
DEBUG: bool
|
||||
API_PREFIX: str
|
||||
APP_NAME: str = "AI Network Configurator"
|
||||
DEBUG: bool = True
|
||||
API_PREFIX: str = "/api"
|
||||
|
||||
SILICONFLOW_API_KEY: str
|
||||
SILICONFLOW_API_URL: str
|
||||
SILICONFLOW_API_KEY: str = os.getenv("SILICONFLOW_API_KEY", "sk-oftmyihyxitocscgjdicafzgezprwqpzzgkzsvoxrakkagmd")
|
||||
SILICONFLOW_API_URL: str = os.getenv("SILICONFLOW_API_URL", "https://api.siliconflow.cn/v1")
|
||||
|
||||
SWITCH_USERNAME: str
|
||||
SWITCH_PASSWORD: str
|
||||
SWITCH_TIMEOUT: int
|
||||
SWITCH_USERNAME: str = os.getenv("SWITCH_USERNAME", "admin")
|
||||
SWITCH_PASSWORD: str = os.getenv("SWITCH_PASSWORD", "admin")
|
||||
SWITCH_TIMEOUT: int = os.getenv("SWITCH_TIMEOUT", 10)
|
||||
|
||||
ENSP_DEFAULT_IP: str
|
||||
ENSP_DEFAULT_PORT: int
|
||||
ENSP_DEFAULT_IP: str = "172.17.99.201"
|
||||
ENSP_DEFAULT_PORT: int = 2000
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
Loading…
x
Reference in New Issue
Block a user