6969d0c62e
VulnCheck - Open Source Vulnerability Management for Wazuh Features: - Vulnerability management with Wazuh integration - AI-powered CVE analysis (OpenAI, Anthropic, Google, DeepSeek, Ollama, Infomaniak) - SLA policy enforcement with automated email alerts - Automated patch verification via Wazuh Syscollector - Role-based access control (Admin, Editor, Readonly) - PDF/CSV reporting for compliance workflows - Full audit trail https://gitea.isuit.ch/vulncheck/vulncheck
226 lines
8.7 KiB
Python
226 lines
8.7 KiB
Python
"""
|
|
Universal AI API Client
|
|
Supports: OpenAI, Anthropic (Claude), Gemini, DeepSeek, Ollama, and Infomaniak.
|
|
"""
|
|
import os
|
|
import logging
|
|
import json
|
|
from typing import Optional, Dict, Any, List
|
|
from datetime import datetime, timedelta
|
|
import httpx
|
|
from fastapi import Depends
|
|
from sqlalchemy.orm import Session
|
|
from app.database import get_db
|
|
from app.models.setting import Setting
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class AIError(Exception):
    """Base exception for AI Client errors.

    Raised by ``UniversalAIClient.analyze_cve`` when the LLM request or the
    handling of its response fails.
    """
|
|
|
|
class UniversalAIClient:
    """
    Universal AI Client supporting multiple providers.

    Anthropic is called through its native Messages endpoint; every other
    provider is called through an OpenAI-style chat-completions endpoint.
    The instance owns an ``httpx.Client``; release it with ``close()`` or by
    using the instance as a context manager.
    """

    # Providers whose base URL may be configured without the
    # "/chat/completions" suffix, mapped to their default base URL.
    _OPENAI_STYLE_DEFAULTS: Dict[str, str] = {
        "deepseek": "https://api.deepseek.com",
        "ollama": "http://localhost:11434/v1",
        "openai": "https://api.openai.com/v1",
    }

    def __init__(
        self,
        provider: str = "openai",
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
        product_id: Optional[str] = None  # Specific for Infomaniak
    ):
        """Store the provider settings, resolve the endpoint and open the HTTP client.

        Args:
            provider: Provider name (case-insensitive). ``None`` or an empty
                string falls back to ``"openai"``.
            api_key: API token; a warning is logged when it is missing for
                any provider other than ``ollama``.
            base_url: Optional endpoint override. Ignored for ``infomaniak``,
                whose URL is derived from ``product_id``.
            model: Model identifier sent with every request.
            product_id: Infomaniak product id used to build the endpoint URL.
        """
        # A stored config may carry provider=None; do not crash on .lower().
        self.provider = (provider or "openai").lower()
        self.api_key = api_key
        self.model = model
        self.product_id = product_id

        # Determine Base URL
        self.base_url = self._resolve_base_url(base_url)

        if not self.api_key and self.provider != "ollama":
            logger.warning("AI API Key missing for provider %s", self.provider)

        self.client = httpx.Client(
            timeout=httpx.Timeout(300.0, connect=10.0),
            limits=httpx.Limits(max_connections=5, max_keepalive_connections=2)
        )

        logger.info(
            "Universal AI Client initialized: Provider=%s, Model=%s",
            self.provider,
            self.model,
        )

    def __enter__(self) -> "UniversalAIClient":
        """Allow ``with UniversalAIClient(...) as client:`` usage."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the underlying HTTP client when leaving the ``with`` block."""
        self.close()

    def _resolve_base_url(self, base_url: Optional[str]) -> Optional[str]:
        """Return the request URL for the configured provider.

        For an unrecognised provider the given ``base_url`` is returned
        unchanged, which may be ``None``.
        """
        if self.provider == "infomaniak":
            if self.product_id:
                # v1 API for chat completions (v2 is only for models listing)
                return f"https://api.infomaniak.com/1/ai/{self.product_id}/openai/chat/completions"
            logger.warning("Infomaniak provider selected without product_id; using placeholder URL")
            return "https://api.infomaniak.com/1/ai/<PRODUCT_ID>/openai/chat/completions"

        if self.provider == "anthropic":
            return base_url or "https://api.anthropic.com/v1/messages"

        if self.provider == "gemini":
            return base_url or "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions"

        default_base = self._OPENAI_STYLE_DEFAULTS.get(self.provider)
        if default_base is not None:
            url = (base_url or default_base).rstrip("/")
            if not url.endswith("/chat/completions"):
                url += "/chat/completions"
            return url

        return base_url

    @classmethod
    def from_config(cls, config_dict: Dict[str, Any]) -> "UniversalAIClient":
        """Creates client from a config dictionary.

        Reads the keys ``provider``, ``api_token`` (or ``api_key``),
        ``base_url``, ``model`` and ``product_id``. A missing, ``None`` or
        empty ``provider`` falls back to ``"openai"``.
        """
        return cls(
            provider=config_dict.get("provider") or "openai",
            api_key=config_dict.get("api_token") or config_dict.get("api_key"),
            base_url=config_dict.get("base_url"),
            model=config_dict.get("model"),
            product_id=config_dict.get("product_id")
        )

    def analyze_cve(
        self,
        cve_id: str,
        package_name: Optional[str] = None,
        package_version: Optional[str] = None,
        cvss_score: Optional[float] = None
    ) -> Dict[str, Any]:
        """Analyzes a CVE using the configured LLM.

        Returns:
            The dictionary produced by ``_parse_analysis_response``.

        Raises:
            AIError: If the request or the response handling raises; the
                original exception is attached as ``__cause__``.
        """
        prompt = self._build_analysis_prompt(cve_id, package_name, package_version, cvss_score)

        try:
            response = self._call_llm(prompt, enable_web_search=(self.provider == "infomaniak"))
            return self._parse_analysis_response(response)
        except Exception as e:
            logger.error("CVE analysis failed for %s: %s", cve_id, e)
            raise AIError(f"Analysis failed: {e}") from e

    def _build_analysis_prompt(self, cve_id, package_name, package_version, cvss_score) -> str:
        """Build the analysis prompt, including package and CVSS details when given."""
        context = f"CVE: {cve_id}"
        if package_name:
            context += f"\nPackage: {package_name} {package_version or ''}"
        # Compare against None so that a legitimate score of 0.0 is kept.
        if cvss_score is not None:
            context += f"\nCVSS Score: {cvss_score}"

        return f"""You are a Cybersecurity Expert. Analyze the following vulnerability:

{context}

Perform a comprehensive analysis and answer:
1. Threat Level (low/medium/high/critical)
2. Exploits availability
3. In-the-wild exploitation
4. Workarounds
5. Remediation steps
6. Threat Intel sources

Structure your response as JSON ONLY:
{{
    "threat_level": "high",
    "analysis_summary": "Summary...",
    "exploits_found": [{{ "url": "...", "type": "...", "description": "..." }}],
    "in_the_wild": true,
    "workarounds": ["..."],
    "remediation_steps": ["..."],
    "threat_intel_sources": [{{ "name": "...", "url": "..." }}]
}}
"""

    def _call_llm(self, prompt: str, enable_web_search: bool = False, temperature: float = 0.3) -> str:
        """Send the prompt to the provider-specific endpoint and return the reply text."""
        if self.provider == "anthropic":
            return self._call_anthropic(prompt, temperature)
        return self._call_openai_compatible(prompt, enable_web_search, temperature)

    def _build_openai_headers(self) -> Dict[str, str]:
        """Return headers for OpenAI-style requests.

        The Authorization header is only added when an API key is set, so a
        key-less client does not send the literal text "Bearer None".
        """
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    def _post_json(self, headers: Dict[str, Any], payload: Dict[str, Any]) -> Any:
        """POST the payload to ``self.base_url`` and return the decoded JSON body.

        A non-200 response body is logged before ``raise_for_status()`` is
        called, so the provider's error message is not lost.
        """
        logger.info("Calling %s API with model: %s", self.provider, self.model)

        response = self.client.post(self.base_url, headers=headers, json=payload)

        if response.status_code != 200:
            logger.error("API Error %s: %s", response.status_code, response.text)

        response.raise_for_status()
        return response.json()

    def _call_openai_compatible(self, prompt: str, enable_web_search: bool, temperature: float) -> str:
        """Call an OpenAI-style chat-completions endpoint and return the first choice's content."""
        payload: Dict[str, Any] = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": "You are a Cybersecurity Analyst."},
                {"role": "user", "content": prompt}
            ],
            "temperature": temperature,
            "max_tokens": 2000
        }

        # The web_search tool is only attached for the Infomaniak provider.
        if enable_web_search and self.provider == "infomaniak":
            payload["tools"] = [{"type": "web_search"}]

        data = self._post_json(self._build_openai_headers(), payload)
        return data["choices"][0]["message"]["content"]

    def _call_anthropic(self, prompt: str, temperature: float) -> str:
        """Call the Anthropic Messages endpoint and return the first content block's text."""
        headers = {
            "x-api-key": self.api_key,
            "anthropic-version": "2023-06-01",
            "content-type": "application/json"
        }
        payload = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 2000,
            "temperature": temperature
        }
        data = self._post_json(headers, payload)
        return data["content"][0]["text"]

    def _parse_analysis_response(self, response: str) -> Dict[str, Any]:
        """Convert the raw LLM reply into the analysis result dictionary.

        The JSON object is taken from a fenced code block when one is
        present; otherwise the whole reply is parsed. If that fails, the text
        between the first ``{`` and the last ``}`` is tried, which covers
        replies that wrap the JSON in prose.

        Returns:
            On success, a dict with ``analysis_text``, ``threat_level``,
            ``in_the_wild``, ``confidence_score``, ``model_version``,
            ``analysis_timestamp``, ``cache_expires_at`` and four
            JSON-encoded string fields (``exploits_found``, ``workarounds``,
            ``remediation_steps``, ``threat_intel_sources``).
            When no JSON object can be extracted, only
            ``{"analysis_text": <text>, "threat_level": "unknown"}``.
        """
        if not isinstance(response, str):
            return {"analysis_text": response, "threat_level": "unknown"}

        # Simple extractor for JSON blocks
        candidate = response
        if "```json" in candidate:
            candidate = candidate.split("```json")[1].split("```")[0].strip()
        elif "```" in candidate:
            candidate = candidate.split("```")[1].split("```")[0].strip()

        parsed: Any = None
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            # Second chance: the JSON object may be surrounded by prose.
            start, end = candidate.find("{"), candidate.rfind("}")
            if 0 <= start < end:
                try:
                    parsed = json.loads(candidate[start:end + 1])
                except json.JSONDecodeError:
                    parsed = None

        # Valid JSON that is not an object (list, number, ...) is unusable.
        if not isinstance(parsed, dict):
            return {"analysis_text": candidate, "threat_level": "unknown"}

        # Read the clock once so the expiry is exactly 24 hours after the
        # recorded timestamp.
        now = datetime.now()
        return {
            "analysis_text": parsed.get("analysis_summary", ""),
            "threat_level": parsed.get("threat_level", "medium"),
            "exploits_found": json.dumps(parsed.get("exploits_found", [])),
            "workarounds": json.dumps(parsed.get("workarounds", [])),
            "remediation_steps": json.dumps(parsed.get("remediation_steps", [])),
            "threat_intel_sources": json.dumps(parsed.get("threat_intel_sources", [])),
            "in_the_wild": parsed.get("in_the_wild", False),
            "confidence_score": "high",
            "model_version": self.model,
            "analysis_timestamp": now,
            "cache_expires_at": now + timedelta(hours=24)
        }

    def close(self):
        """Close the underlying HTTP client."""
        self.client.close()
|
|
|
|
def get_ai_client(db: Session = Depends(get_db)) -> UniversalAIClient:
    """Build a ``UniversalAIClient`` from the ``ai_config`` setting row.

    The setting value is expected to be a JSON document that
    ``UniversalAIClient.from_config`` can consume. When the row is missing,
    empty, or cannot be turned into a client, a client with default
    arguments is returned instead.
    """
    config_setting = db.query(Setting).filter(Setting.key == "ai_config").first()
    if config_setting and config_setting.value:
        try:
            config = json.loads(config_setting.value)
            return UniversalAIClient.from_config(config)
        except Exception:
            # Deliberate best-effort boundary: fall through to the default
            # client, but record the traceback so the bad config can be found.
            logger.exception("Invalid AI config in DB")
    return UniversalAIClient()
|