mirror of
https://github.com/Jerryplusy/AI-powered-switches.git
synced 2025-07-04 13:19:20 +00:00
修改(报错的是api问题)
This commit is contained in:
parent
3f745301fb
commit
8ad26d6fff
22
src/backend/Dockerfile
Normal file
22
src/backend/Dockerfile
Normal file
@ -0,0 +1,22 @@
|
||||
version: '3.13.2'
|
||||
|
||||
services:
|
||||
app:
|
||||
build: .
|
||||
ports:
|
||||
- "8000:8000"
|
||||
depends_on:
|
||||
- redis
|
||||
environment:
|
||||
- REDIS_URL=redis://redis:6379
|
||||
|
||||
redis:
|
||||
image: redis:alpine
|
||||
ports:
|
||||
- "6379:6379"
|
||||
|
||||
worker:
|
||||
build: .
|
||||
command: celery -A app.services.task_service worker --loglevel=info
|
||||
depends_on:
|
||||
- ONBUILD
|
@ -0,0 +1,29 @@
|
||||
from .base import BaseAdapter
|
||||
from .cisco import CiscoAdapter
|
||||
from .huawei import HuaweiAdapter
|
||||
from .factory import AdapterFactory
|
||||
|
||||
# 自动注册所有适配器类
|
||||
__all_adapters__ = {
|
||||
'cisco': CiscoAdapter,
|
||||
'huawei': HuaweiAdapter
|
||||
}
|
||||
|
||||
def get_supported_vendors() -> list:
|
||||
"""获取当前支持的设备厂商列表"""
|
||||
return list(__all_adapters__.keys())
|
||||
|
||||
def init_adapters():
|
||||
"""初始化适配器工厂"""
|
||||
AdapterFactory.register_adapters(__all_adapters__)
|
||||
|
||||
# 应用启动时自动初始化
|
||||
init_adapters()
|
||||
|
||||
__all__ = [
|
||||
'BaseAdapter',
|
||||
'CiscoAdapter',
|
||||
'HuaweiAdapter',
|
||||
'AdapterFactory',
|
||||
'get_supported_vendors'
|
||||
]
|
16
src/backend/app/adapters/base.py
Normal file
16
src/backend/app/adapters/base.py
Normal file
@ -0,0 +1,16 @@
|
||||
# /backend/app/adapters/base.py
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any
|
||||
|
||||
class BaseAdapter(ABC):
|
||||
@abstractmethod
|
||||
async def connect(self, ip: str, credentials: Dict[str, str]):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def deploy_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_status(self) -> Dict[str, Any]:
|
||||
pass
|
@ -1 +0,0 @@
|
||||
#抽象基类
|
32
src/backend/app/adapters/cisco.py
Normal file
32
src/backend/app/adapters/cisco.py
Normal file
@ -0,0 +1,32 @@
|
||||
# /backend/app/adapters/cisco.py
|
||||
from netmiko import ConnectHandler
|
||||
from .base import BaseAdapter
|
||||
|
||||
class CiscoAdapter(BaseAdapter):
|
||||
def __init__(self):
|
||||
self.connection = None
|
||||
|
||||
async def connect(self, ip: str, credentials: Dict[str, str]):
|
||||
self.connection = ConnectHandler(
|
||||
device_type='cisco_ios',
|
||||
host=ip,
|
||||
username=credentials['username'],
|
||||
password=credentials['password'],
|
||||
timeout=10
|
||||
)
|
||||
|
||||
async def deploy_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
commands = self._generate_commands(config)
|
||||
output = self.connection.send_config_set(commands)
|
||||
return {'success': True, 'output': output}
|
||||
|
||||
def _generate_commands(self, config: Dict[str, Any]) -> list:
|
||||
# 实际生产中应使用Jinja2模板
|
||||
commands = []
|
||||
if 'vlans' in config:
|
||||
for vlan in config['vlans']:
|
||||
commands.extend([
|
||||
f"vlan {vlan['id']}",
|
||||
f"name {vlan['name']}"
|
||||
])
|
||||
return commands
|
21
src/backend/app/adapters/factory.py
Normal file
21
src/backend/app/adapters/factory.py
Normal file
@ -0,0 +1,21 @@
|
||||
from . import BaseAdapter
|
||||
from .cisco import CiscoAdapter
|
||||
from .huawei import HuaweiAdapter
|
||||
|
||||
class AdapterFactory:
|
||||
_adapters = {}
|
||||
|
||||
@classmethod
|
||||
def register_adapters(cls, adapters: dict):
|
||||
"""注册适配器字典"""
|
||||
cls._adapters.update(adapters)
|
||||
|
||||
@classmethod
|
||||
def get_adapter(vendor: str)->BaseAdapter:
|
||||
adapters = {
|
||||
'cisco': CiscoAdapter,
|
||||
'huawei': HuaweiAdapter
|
||||
}
|
||||
if vendor not in cls._adapters:
|
||||
raise ValueError(f"Unsupported vendor: {vendor}")
|
||||
return cls._adapters[vendor]()
|
26
src/backend/app/adapters/huawei.py
Normal file
26
src/backend/app/adapters/huawei.py
Normal file
@ -0,0 +1,26 @@
|
||||
import httpx
|
||||
from .base import BaseAdapter
|
||||
|
||||
class HuaweiAdapter(BaseAdapter):
|
||||
def __init__(self):
|
||||
self.client = None
|
||||
self.base_url = None
|
||||
|
||||
async def connect(self, ip: str, credentials: dict):
|
||||
self.base_url = f"https://{ip}/restconf"
|
||||
self.client = httpx.AsyncClient(
|
||||
auth=(credentials['username'], credentials['password']),
|
||||
verify=False,
|
||||
timeout=30.0
|
||||
)
|
||||
|
||||
async def deploy_config(self, config: dict):
|
||||
headers = {"Content-Type": "application/yang-data+json"}
|
||||
url = f"{self.base_url}/data/ietf-restconf:operations/network-topology:deploy"
|
||||
response = await self.client.post(url, json=config, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def disconnect(self):
|
||||
if self.client:
|
||||
await self.client.aclose()
|
40
src/backend/app/api/bulk.py
Normal file
40
src/backend/app/api/bulk.py
Normal file
@ -0,0 +1,40 @@
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks
|
||||
from pydantic import BaseModel
|
||||
from typing import List
|
||||
from app.services.batch import BatchService
|
||||
from app.utils.decorators import async_retry
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
class BulkDeviceConfig(BaseModel):
|
||||
device_ips: List[str]
|
||||
config: dict
|
||||
credentials: dict
|
||||
vendor: str = "cisco"
|
||||
timeout: int = 30
|
||||
|
||||
@router.post("/config")
|
||||
@async_retry(max_attempts=3, delay=1)
|
||||
async def bulk_apply_config(request: BulkDeviceConfig, bg_tasks: BackgroundTasks):
|
||||
"""
|
||||
批量配置设备接口
|
||||
示例请求体:
|
||||
{
|
||||
"device_ips": ["192.168.1.1", "192.168.1.2"],
|
||||
"config": {"vlans": [{"id": 100, "name": "test"}]},
|
||||
"credentials": {"username": "admin", "password": "secret"},
|
||||
"vendor": "cisco"
|
||||
}
|
||||
"""
|
||||
devices = [{
|
||||
"ip": ip,
|
||||
"credentials": request.credentials,
|
||||
"vendor": request.vendor
|
||||
} for ip in request.device_ips]
|
||||
|
||||
try:
|
||||
batch = BatchService()
|
||||
bg_tasks.add_task(batch.deploy_batch, devices, request.config)
|
||||
return {"message": "Batch job started", "device_count": len(devices)}
|
||||
except Exception as e:
|
||||
raise HTTPException(500, detail=str(e))
|
17
src/backend/app/api/health.py
Normal file
17
src/backend/app/api/health.py
Normal file
@ -0,0 +1,17 @@
|
||||
from fastapi import APIRouter
|
||||
from ...monitoring.healthcheck import check_redis, check_ai_service
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/live")
|
||||
async def liveness_check():
|
||||
return {"status": "alive"}
|
||||
|
||||
@router.get("/ready")
|
||||
async def readiness_check():
|
||||
redis_ok = await check_redis()
|
||||
ai_ok = await check_ai_service()
|
||||
return {
|
||||
"redis": redis_ok,
|
||||
"ai_service": ai_ok
|
||||
}
|
20
src/backend/app/api/topology.py
Normal file
20
src/backend/app/api/topology.py
Normal file
@ -0,0 +1,20 @@
|
||||
from fastapi import APIRouter, BackgroundTasks
|
||||
from pydantic import BaseModel
|
||||
from ...services.task_service import deploy_to_device
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
class TopologyRequest(BaseModel):
|
||||
devices: list
|
||||
config: dict
|
||||
|
||||
@router.post("/deploy")
|
||||
async def deploy_topology(
|
||||
request: TopologyRequest,
|
||||
bg_tasks: BackgroundTasks
|
||||
):
|
||||
task_ids = []
|
||||
for device in request.devices:
|
||||
task = deploy_to_device.delay(device, request.config)
|
||||
task_ids.append(task.id)
|
||||
return {"task_ids": task_ids}
|
10
src/backend/app/models/__init__.py
Normal file
10
src/backend/app/models/__init__.py
Normal file
@ -0,0 +1,10 @@
|
||||
from .device import DeviceCredentials, DeviceInfo
|
||||
from .topology import TopologyType, DeviceRole, NetworkTopology
|
||||
|
||||
__all__ = [
|
||||
'DeviceCredentials',
|
||||
'DeviceInfo',
|
||||
'TopologyType',
|
||||
'DeviceRole',
|
||||
'NetworkTopology'
|
||||
]
|
14
src/backend/app/models/devices.py
Normal file
14
src/backend/app/models/devices.py
Normal file
@ -0,0 +1,14 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
class DeviceCredentials(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
enable_password: Optional[str] = None
|
||||
|
||||
class DeviceInfo(BaseModel):
|
||||
ip: str
|
||||
vendor: str
|
||||
model: Optional[str] = None
|
||||
os_version: Optional[str] = None
|
||||
credentials: DeviceCredentials
|
20
src/backend/app/models/topology.py
Normal file
20
src/backend/app/models/topology.py
Normal file
@ -0,0 +1,20 @@
|
||||
#拓补数据结构
|
||||
from enum import Enum
|
||||
from typing import Dict, List
|
||||
from pydantic import BaseModel
|
||||
|
||||
class TopologyType(str, Enum):
|
||||
SPINE_LEAF = "spine-leaf"
|
||||
CORE_ACCESS = "core-access"
|
||||
RING = "ring"
|
||||
|
||||
class DeviceRole(str, Enum):
|
||||
CORE = "core"
|
||||
SPINE = "spine"
|
||||
LEAF = "leaf"
|
||||
ACCESS = "access"
|
||||
|
||||
class NetworkTopology(BaseModel):
|
||||
type: TopologyType
|
||||
devices: Dict[DeviceRole, List[str]]
|
||||
links: Dict[str, List[str]]
|
@ -1 +0,0 @@
|
||||
#拓补数据结构
|
42
src/backend/app/monitoring/metrics.py
Normal file
42
src/backend/app/monitoring/metrics.py
Normal file
@ -0,0 +1,42 @@
|
||||
from prometheus_client import (
|
||||
Counter,
|
||||
Gauge,
|
||||
Histogram,
|
||||
Summary
|
||||
)
|
||||
|
||||
# API Metrics
|
||||
API_REQUESTS = Counter(
|
||||
'api_requests_total',
|
||||
'Total API requests',
|
||||
['method', 'endpoint', 'status']
|
||||
)
|
||||
|
||||
API_LATENCY = Histogram(
|
||||
'api_request_latency_seconds',
|
||||
'API request latency',
|
||||
['endpoint']
|
||||
)
|
||||
|
||||
# Device Metrics
|
||||
DEVICE_CONNECTIONS = Gauge(
|
||||
'network_device_connections',
|
||||
'Active device connections',
|
||||
['vendor']
|
||||
)
|
||||
|
||||
CONFIG_APPLY_TIME = Summary(
|
||||
'config_apply_seconds',
|
||||
'Time spent applying configurations'
|
||||
)
|
||||
|
||||
# Error Metrics
|
||||
CONFIG_ERRORS = Counter(
|
||||
'config_errors_total',
|
||||
'Configuration errors',
|
||||
['error_type']
|
||||
)
|
||||
|
||||
def observe_api_request(method: str, endpoint: str, status: int, duration: float):
|
||||
API_REQUESTS.labels(method, endpoint, status).inc()
|
||||
API_LATENCY.labels(endpoint).observe(duration)
|
31
src/backend/app/monitoring/middleware.py
Normal file
31
src/backend/app/monitoring/middleware.py
Normal file
@ -0,0 +1,31 @@
|
||||
from prometheus_client import Counter, Histogram
|
||||
from fastapi import Request
|
||||
|
||||
REQUESTS = Counter(
|
||||
'api_requests_total',
|
||||
'Total API Requests',
|
||||
['method', 'endpoint']
|
||||
)
|
||||
|
||||
LATENCY = Histogram(
|
||||
'api_request_latency_seconds',
|
||||
'API Request Latency',
|
||||
['endpoint']
|
||||
)
|
||||
|
||||
|
||||
async def monitor_requests(request: Request, call_next):
|
||||
start_time = time.time()
|
||||
response = await call_next(request)
|
||||
latency = time.time() - start_time
|
||||
|
||||
REQUESTS.labels(
|
||||
method=request.method,
|
||||
endpoint=request.url.path
|
||||
).inc()
|
||||
|
||||
LATENCY.labels(
|
||||
endpoint=request.url.path
|
||||
).observe(latency)
|
||||
|
||||
return response
|
16
src/backend/app/services/__init__.py
Normal file
16
src/backend/app/services/__init__.py
Normal file
@ -0,0 +1,16 @@
|
||||
from .task_service import celery_app
|
||||
from .ai_service import AIService
|
||||
from .topology import TopologyService
|
||||
from .batch import BatchService
|
||||
|
||||
# 单例服务实例
|
||||
ai_service = AIService()
|
||||
topology_service = TopologyService()
|
||||
batch_service = BatchService()
|
||||
|
||||
__all__ = [
|
||||
'celery_app',
|
||||
'ai_service',
|
||||
'topology_service',
|
||||
'batch_service'
|
||||
]
|
@ -1 +0,0 @@
|
||||
#异步SSH连接器
|
48
src/backend/app/services/batch.py
Normal file
48
src/backend/app/services/batch.py
Normal file
@ -0,0 +1,48 @@
|
||||
import asyncio
|
||||
from typing import List, Dict, Any
|
||||
from app.adapters.factory import AdapterFactory
|
||||
from app.utils.connection_pool import ConnectionPool
|
||||
from app.monitoring.metrics import (
|
||||
DEVICE_CONNECTIONS,
|
||||
CONFIG_APPLY_TIME,
|
||||
CONFIG_ERRORS
|
||||
)
|
||||
|
||||
|
||||
class BatchService:
|
||||
def __init__(self, max_workers: int = 10):
|
||||
self.semaphore = asyncio.Semaphore(max_workers)
|
||||
self.pool = ConnectionPool()
|
||||
|
||||
@CONFIG_APPLY_TIME.time()
|
||||
async def deploy_batch(self, devices: List[Dict], config: Dict[str, Any]):
|
||||
async def _deploy(device):
|
||||
vendor = device.get('vendor', 'cisco')
|
||||
async with self.semaphore:
|
||||
try:
|
||||
adapter = AdapterFactory.get_adapter(vendor)
|
||||
await adapter.connect(device['ip'], device['credentials'])
|
||||
DEVICE_CONNECTIONS.labels(vendor).inc()
|
||||
|
||||
result = await adapter.deploy_config(config)
|
||||
return {
|
||||
"device": device['ip'],
|
||||
"status": "success",
|
||||
"result": result
|
||||
}
|
||||
except ConnectionError as e:
|
||||
CONFIG_ERRORS.labels("connection").inc()
|
||||
return {
|
||||
"device": device['ip'],
|
||||
"status": "failed",
|
||||
"error": str(e)
|
||||
}
|
||||
finally:
|
||||
if adapter:
|
||||
await adapter.disconnect()
|
||||
DEVICE_CONNECTIONS.labels(vendor).dec()
|
||||
|
||||
return await asyncio.gather(
|
||||
*[_deploy(device) for device in devices],
|
||||
return_exceptions=True
|
||||
)
|
@ -1 +1,18 @@
|
||||
#Celery任务定义
|
||||
#Celery任务定义
|
||||
from celery import Celery
|
||||
from src.backend.app.utils.connection_pool import ConnectionPool
|
||||
from src.backend.config import settings
|
||||
|
||||
celery = Celery(__name__, broker=settings.REDIS_URL)
|
||||
pool = ConnectionPool(max_size=settings.MAX_CONNECTIONS)
|
||||
|
||||
@celery.task
|
||||
async def deploy_to_device(device_info: dict, config: dict):
|
||||
adapter = await pool.get(device_info['vendor'])
|
||||
try:
|
||||
await adapter.connect(device_info['ip'], device_info['credentials'])
|
||||
result = await adapter.deploy_config(config)
|
||||
await pool.release(adapter)
|
||||
return {'device': device_info['ip'], 'result': result}
|
||||
except Exception as e:
|
||||
return {'device': device_info['ip'], 'error': str(e)}
|
@ -1 +1,23 @@
|
||||
#拓补处理逻辑
|
||||
#拓补处理逻辑
|
||||
def generate_multi_device_config(topology):
|
||||
"""
|
||||
topology示例:
|
||||
{
|
||||
"core_switches": [sw1, sw2],
|
||||
"access_switches": {
|
||||
"sw1": [sw3, sw4],
|
||||
"sw2": [sw5, sw6]
|
||||
}
|
||||
}
|
||||
"""
|
||||
configs = {}
|
||||
# 生成核心层配置(如MSTP根桥选举)
|
||||
for sw in topology['core_switches']:
|
||||
configs[sw] = generate_core_config(sw)
|
||||
|
||||
# 生成接入层配置(如端口绑定)
|
||||
for core_sw, access_sws in topology['access_switches'].items():
|
||||
for sw in access_sws:
|
||||
configs[sw] = generate_access_config(sw, uplink=core_sw)
|
||||
|
||||
return configs
|
@ -1 +1,22 @@
|
||||
#连接池
|
||||
# /backend/app/utils/connection_pool.py
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from ..adapters import cisco, huawei
|
||||
|
||||
class ConnectionPool:
|
||||
def __init__(self, max_size=10):
|
||||
self.max_size = max_size
|
||||
self.pool = deque(maxlen=max_size)
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
async def get(self, vendor: str):
|
||||
async with self.lock:
|
||||
if self.pool:
|
||||
return self.pool.pop()
|
||||
return CiscoAdapter() if vendor == 'cisco' else HuaweiAdapter()
|
||||
|
||||
async def release(self, adapter):
|
||||
async with self.lock:
|
||||
if len(self.pool) < self.max_size:
|
||||
self.pool.append(adapter)
|
9
src/backend/celeryconfig.py
Normal file
9
src/backend/celeryconfig.py
Normal file
@ -0,0 +1,9 @@
|
||||
broker_url = 'redis://redis:6379/0'
|
||||
result_backend = 'redis://redis:6379/1'
|
||||
task_serializer = 'json'
|
||||
result_serializer = 'json'
|
||||
accept_content = ['json']
|
||||
timezone = 'UTC'
|
||||
enable_utc = True
|
||||
task_track_started = True
|
||||
task_time_limit = 300
|
@ -1,14 +1,17 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
app_name: str = "Network Config API"
|
||||
ai_api_key: str = "your-silicon-mobility-api-key"
|
||||
ai_api_url: str = "https://api.silicon-mobility.com/v1/parse"
|
||||
debug: bool = False
|
||||
app_name: str = "Network Automation API"
|
||||
redis_url: str = Field("redis://localhost:6379", env="REDIS_URL")
|
||||
ai_api_key: str = Field(..., env="AI_API_KEY")
|
||||
max_connections: int = Field(50, env="MAX_CONNECTIONS")
|
||||
default_timeout: int = Field(30, env="DEFAULT_TIMEOUT")
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
extra = "ignore"
|
||||
|
||||
|
||||
settings = Settings()
|
@ -1,6 +1,8 @@
|
||||
fastapi==0.109.1
|
||||
uvicorn==0.27.0
|
||||
python-dotenv==1.0.0
|
||||
requests==2.31.0
|
||||
pydantic==2.6.1
|
||||
pydantic-settings==2.1.0
|
||||
celery==5.3.6
|
||||
redis==4.6.0
|
||||
netmiko==4.2.0
|
||||
asyncssh==2.14.10.0
|
||||
prometheus-client==0.2
|
Loading…
x
Reference in New Issue
Block a user