# NOTE(review): removed stray page-scrape metadata ("220 lines", "8.2 KiB",
# "Python") that preceded the imports; it was not valid Python.
import hashlib
import json
import logging
import os
from pathlib import Path
from typing import Dict, List, Any

import chromadb
import requests
from chromadb.utils import embedding_functions
from dotenv import load_dotenv
from elasticsearch import Elasticsearch
|
|
|
|
# Load environment variables
|
|
load_dotenv()
|
|
|
|
# Configure logging
|
|
logging.basicConfig(level=logging.INFO)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class ElasticSearchAI:
    """Natural-language search assistant over Elasticsearch.

    Pipeline (see :meth:`process_query`):
      1. gather metadata about the ``havoc-*`` indices,
      2. ask an Ollama model to translate the user's prompt into ES DSL,
      3. execute those queries,
      4. cache the raw hits in ChromaDB,
      5. ask Ollama to summarise the hits for the user.
    """

    def __init__(self):
        """Build the Elasticsearch, ChromaDB and Ollama clients from env vars."""
        # --- Elasticsearch -------------------------------------------------
        es_config = {
            'hosts': [f"{os.getenv('ES_HOST', 'localhost')}:{os.getenv('ES_PORT', '9200')}"],
            'basic_auth': (
                os.getenv('ES_USERNAME', 'elastic'),
                os.getenv('ES_PASSWORD', '')
            )
        }

        # Optional SSL configuration.
        if os.getenv('ES_USE_SSL', 'false').lower() == 'true':
            es_config['use_ssl'] = True
            es_config['verify_certs'] = os.getenv('ES_VERIFY_CERTS', 'true').lower() == 'true'
            # Only forward ca_certs when a path is actually configured;
            # passing an explicit None can override the client's default
            # CA handling.
            ca_cert_path = os.getenv('ES_CA_CERT_PATH')
            if ca_cert_path:
                es_config['ca_certs'] = ca_cert_path

        self.es = Elasticsearch(**es_config)

        # --- ChromaDB ------------------------------------------------------
        # NOTE(review): the original code read CHROMADB_HOST / CHROMADB_PORT /
        # CHROMADB_PERSISTENCE_DIR into a dict but never passed it anywhere --
        # chromadb.Client() is an in-memory client that ignores them. The dead
        # config has been dropped to preserve behaviour; switch to
        # chromadb.HttpClient / PersistentClient if a remote or persistent
        # store is actually wanted.
        self.chroma_client = chromadb.Client()
        # get_or_create avoids the exception create_collection raises when
        # the collection already exists (e.g. on a second instantiation).
        self.collection = self.chroma_client.get_or_create_collection(
            name="search_results",
            embedding_function=embedding_functions.DefaultEmbeddingFunction()
        )

        # Lazily-populated cache of index metadata (see get_index_information).
        self.index_info = None

        # --- Ollama --------------------------------------------------------
        self.ollama_url = f"http://{os.getenv('OLLAMA_HOST', 'localhost')}:{os.getenv('OLLAMA_PORT', '11434')}"
        self.ollama_model = os.getenv('OLLAMA_MODEL', 'llama2')
        # Request timeout (seconds) for Ollama calls; LLM generation is slow,
        # but an unbounded requests.post can hang forever.
        self.ollama_timeout = float(os.getenv('OLLAMA_TIMEOUT', '300'))

    def get_index_information(self) -> Dict:
        """Fetch mappings, descriptions and stats for every ``havoc-*`` index.

        The result is cached on ``self.index_info`` and also returned.

        Returns:
            Mapping of index name to ``{"mappings", "description", "stats"}``.

        Raises:
            Exception: re-raises any Elasticsearch client error after logging.
        """
        try:
            # Mappings and settings for all havoc-* indices.
            indices_info = self.es.indices.get("havoc-*")
            # Document counts and on-disk sizes.
            indices_stats = self.es.indices.stats(index="havoc-*")

            index_data = {}
            for index_name, info in indices_info.items():
                # Human-readable descriptions live in custom index metadata
                # under settings.index.metadata (project convention —
                # TODO confirm against the index templates).
                metadata = info["settings"]["index"].get("metadata", {})
                stats_total = indices_stats["indices"][index_name]["total"]
                index_data[index_name] = {
                    "mappings": info["mappings"],
                    "description": {
                        "short": metadata.get("desc_short", "No short description available"),
                        "long": metadata.get("desc_long", "No long description available")
                    },
                    "stats": {
                        "doc_count": stats_total["docs"]["count"],
                        "size_bytes": stats_total["store"]["size_in_bytes"]
                    }
                }

            self.index_info = index_data
            return index_data
        except Exception as e:
            # Lazy %-formatting keeps log construction off the happy path.
            logger.error("Error gathering index information: %s", e)
            raise

    def query_ollama(self, user_prompt: str) -> List[Dict]:
        """Ask Ollama to translate *user_prompt* into Elasticsearch DSL.

        Returns:
            The parsed JSON array of ``{"index": ..., "queries": [...]}``
            objects produced by the model.

        Raises:
            requests.HTTPError: if the Ollama API returns an error status.
            json.JSONDecodeError: if the model's output is not valid JSON.
        """
        # Ensure we have index information to describe in the prompt.
        if not self.index_info:
            self.get_index_information()

        # Describe each index (descriptions, stats, mapping) for the model.
        index_descriptions = []
        for index_name, info in self.index_info.items():
            index_descriptions.append(f"""
Index: {index_name}
Short description: {info['description']['short']}
Long description: {info['description']['long']}
Document count: {info['stats']['doc_count']:,}
Size: {info['stats']['size_bytes'] / (1024*1024*1024):.2f} GB
Mapping: {json.dumps(info['mappings'], indent=2)}
""")

        system_prompt = f"""
You are an Elasticsearch query generator. Convert the following natural language query into Elasticsearch DSL queries.

Available indices and their information:
{''.join(index_descriptions)}

Rules:
1. Return a JSON array where each object represents queries for a single index
2. Each object should have format:
   {{"index": "index-name", "queries": [list of query objects for this index]}}
3. Consider the mapping types when generating queries:
   - For text fields, use match/match_phrase for full-text search
   - For keyword fields, use term/terms for exact matches
   - For IP fields, use proper IP query syntax
   - For date fields, use date range queries when appropriate
   - For nested fields, use nested queries
4. Each query should be a valid Elasticsearch DSL query object
5. Only query indices that are relevant to the user's request

User query: {user_prompt}
"""

        response = requests.post(
            f"{self.ollama_url}/api/generate",
            json={
                "model": self.ollama_model,
                "prompt": system_prompt,
                "format": "json",
                # Without stream=False the API returns newline-delimited JSON
                # chunks, which response.json() cannot parse.
                "stream": False
            },
            timeout=self.ollama_timeout
        )
        response.raise_for_status()
        return json.loads(response.json()["response"])

    def execute_search(self, query_data: List[Dict]) -> List[Dict[str, Any]]:
        """Run every generated query against its target index.

        Args:
            query_data: list of ``{"index": name, "queries": [dsl, ...]}``
                objects, as produced by :meth:`query_ollama`.

        Returns:
            The concatenated raw hit objects from all successful searches.
            A failure on one index is logged and skipped so that other
            indices still contribute results (best-effort semantics).
        """
        results = []

        for index_queries in query_data:
            index = index_queries["index"]
            queries = index_queries["queries"]

            try:
                for query in queries:
                    response = self.es.search(
                        index=index,
                        body={"query": query}
                    )
                    results.extend(response["hits"]["hits"])
            except Exception as e:
                # A malformed query for one index must not abort the rest.
                logger.error("Error searching index %s: %s", index, e)

        return results

    def store_in_chroma(self, results: List[Dict[str, Any]], query_id: str):
        """Cache raw search hits in the ChromaDB collection.

        Args:
            results: raw Elasticsearch hit objects (``_source``, ``_index``,
                ``_score``).
            query_id: prefix used to build per-hit document ids.
        """
        if not results:
            # collection.add raises on empty input lists; nothing to store.
            return

        documents = []
        metadatas = []
        ids = []

        for i, result in enumerate(results):
            documents.append(json.dumps(result["_source"]))
            metadatas.append({
                "index": result["_index"],
                # _score can be None (e.g. for sorted queries); Chroma
                # metadata values must be primitives, so fall back to 0.0.
                "score": result["_score"] if result["_score"] is not None else 0.0
            })
            ids.append(f"{query_id}_{i}")

        self.collection.add(
            documents=documents,
            metadatas=metadatas,
            ids=ids
        )

    def generate_response(self, user_prompt: str, results: List[Dict[str, Any]]) -> str:
        """Ask Ollama to summarise *results* as an answer to *user_prompt*.

        Raises:
            requests.HTTPError: if the Ollama API returns an error status.
        """
        response_prompt = f"""
Generate a response for the user's query: {user_prompt}

Based on these search results: {json.dumps(results, indent=2)}

If the user requested JSON output, return valid JSON.
Otherwise, provide a natural language summary.
"""

        response = requests.post(
            f"{self.ollama_url}/api/generate",
            json={
                "model": self.ollama_model,
                "prompt": response_prompt,
                # Request a single JSON reply instead of an NDJSON stream,
                # which response.json() cannot parse.
                "stream": False
            },
            timeout=self.ollama_timeout
        )
        response.raise_for_status()
        return response.json()["response"]

    def process_query(self, user_prompt: str) -> str:
        """Process a natural-language query end-to-end and return the answer."""
        # Ensure we have index information before generating queries.
        if not self.index_info:
            self.get_index_information()

        # Natural language -> ES DSL via the LLM.
        query_data = self.query_ollama(user_prompt)

        # Execute the generated queries.
        results = self.execute_search(query_data)

        # Stable, collision-resistant id for this query's cached results.
        # (builtin hash() is salted per-process and not reproducible.)
        query_id = hashlib.sha256(user_prompt.encode("utf-8")).hexdigest()
        self.store_in_chroma(results, query_id)

        # Summarise the hits for the user.
        return self.generate_response(user_prompt, results)