204 changes: 103 additions & 101 deletions src/codebase_rag/services/knowledge/neo4j_knowledge_service.py
@@ -1,16 +1,14 @@

"""
Modern knowledge graph service built on Neo4j's native vector index.
Uses LlamaIndex's KnowledgeGraphIndex together with Neo4j's native vector search,
and supports multiple LLM and embedding model providers.
"""

from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional, Union
from pathlib import Path
import asyncio
from loguru import logger
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from llama_index.core import (
KnowledgeGraphIndex,
@@ -19,26 +17,28 @@
VectorStoreIndex,
)
from llama_index.core.indices.knowledge_graph import KnowledgeGraphRAGRetriever
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.response_synthesizers import get_response_synthesizer
from llama_index.core.schema import QueryBundle, NodeWithScore
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import NodeWithScore, QueryBundle

# LLM Providers
from llama_index.llms.ollama import Ollama
from llama_index.llms.openai import OpenAI
from llama_index.llms.gemini import Gemini
from llama_index.llms.openrouter import OpenRouter
# Tools / workflow
from llama_index.core.tools import FunctionTool

# Embedding Providers
from llama_index.embeddings.gemini import GeminiEmbedding
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.embeddings.gemini import GeminiEmbedding

# Graph Store
from llama_index.graph_stores.neo4j import Neo4jGraphStore

# Tools / workflow
from llama_index.core.tools import FunctionTool
# LLM Providers
from llama_index.llms.gemini import Gemini
from llama_index.llms.ollama import Ollama
from llama_index.llms.openai import OpenAI
from llama_index.llms.openrouter import OpenRouter
from loguru import logger

try: # Optional dependency for advanced workflow integration
from llama_index.core.workflow.tool_node import ToolNode
except Exception: # pragma: no cover - optional runtime dependency
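The except body falls outside this hunk; for reference, a minimal sketch of the full optional-import guard, assuming the fallback simply nulls the symbol so callers can feature-detect:

# Sketch of the complete guard; the fallback assignment is an assumption,
# since the except body is not shown in this hunk.
try:
    from llama_index.core.workflow.tool_node import ToolNode
except Exception:  # pragma: no cover - optional runtime dependency
    ToolNode = None  # callers check `if ToolNode is not None` before use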
@@ -51,11 +51,11 @@
merge_pipeline_configs,
)


# =========================
# Retrieval Pipeline Config
# =========================


@dataclass
class PipelineConfig:
"""Configuration for running a retrieval pipeline."""
@@ -104,9 +104,7 @@ def _merge_nodes(
"""Merge retrieved nodes by keeping the highest scoring entry per node id."""
for node in nodes:
node_id = node.node.node_id if getattr(node, "node", None) else node.node_id
if node_id not in aggregated or (
(aggregated[node_id].score or 0) < (node.score or 0)
):
if node_id not in aggregated or ((aggregated[node_id].score or 0) < (node.score or 0)):
aggregated[node_id] = node

@staticmethod
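As a quick illustration of the merge rule in _merge_nodes (keep the highest-scoring entry per node id, treating a missing score as 0), a self-contained sketch:

# Minimal, runnable illustration of the merge rule above.
from llama_index.core.schema import NodeWithScore, TextNode

aggregated = {}
candidates = [
    NodeWithScore(node=TextNode(id_="a", text="first"), score=0.4),
    NodeWithScore(node=TextNode(id_="a", text="first"), score=0.9),  # higher, wins
    NodeWithScore(node=TextNode(id_="b", text="second"), score=None),  # None counts as 0
]
for node in candidates:
    node_id = node.node.node_id
    if node_id not in aggregated or (aggregated[node_id].score or 0) < (node.score or 0):
        aggregated[node_id] = node
assert aggregated["a"].score == 0.9 and "b" in aggregated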
@@ -201,9 +199,7 @@ def run(self, question: str, config: PipelineConfig) -> Dict[str, Any]:
)
except Exception as exc: # pragma: no cover - defensive logging
logger.warning(f"Tool node execution failed: {exc}")
tool_outputs.append(
{"tool": "tool_node", "error": str(exc), "is_error": True}
)
tool_outputs.append({"tool": "tool_node", "error": str(exc), "is_error": True})
elif self.function_tools:
for tool in self.function_tools:
try:
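For context on the function-tool path: self.function_tools holds LlamaIndex FunctionTool instances, and a failed call is recorded as an error dict rather than raised. A sketch of how such a tool might be constructed (the describe_graph helper is hypothetical):

# Hypothetical tool for the loop above; FunctionTool.from_defaults is the
# standard LlamaIndex constructor, but this specific tool is illustrative.
from llama_index.core.tools import FunctionTool

def describe_graph(question: str) -> str:
    """Toy body; a real tool would consult the graph store."""
    return f"received: {question}"

graph_tool = FunctionTool.from_defaults(fn=describe_graph, name="describe_graph")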
@@ -234,6 +230,7 @@ def run(self, question: str, config: PipelineConfig) -> Dict[str, Any]:
# Knowledge Service Main
# ======================


class Neo4jKnowledgeService:
"""knowledge graph service based on Neo4j's native vector index"""

@@ -321,9 +318,7 @@ def _create_embedding_model(self):
)
elif provider == "openai":
if not settings.openai_api_key:
raise ValueError(
"OpenAI API key is required for OpenAI embedding provider"
)
raise ValueError("OpenAI API key is required for OpenAI embedding provider")
return OpenAIEmbedding(
model=settings.openai_embedding_model,
api_key=settings.openai_api_key,
@@ -332,18 +327,14 @@ def _create_embedding_model(self):
)
elif provider == "gemini":
if not settings.google_api_key:
raise ValueError(
"Google API key is required for Gemini embedding provider"
)
raise ValueError("Google API key is required for Gemini embedding provider")
return GeminiEmbedding(
model_name=settings.gemini_embedding_model,
api_key=settings.google_api_key,
)
elif provider == "openrouter":
if not settings.openrouter_api_key:
raise ValueError(
"OpenRouter API key is required for OpenRouter embedding provider"
)
raise ValueError("OpenRouter API key is required for OpenRouter embedding provider")
return OpenAIEmbedding(
model=settings.openrouter_embedding_model,
api_key=settings.openrouter_api_key,
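Note that the OpenRouter branch reuses OpenAIEmbedding: OpenRouter exposes an OpenAI-compatible API, so the client only needs a different base URL. A minimal sketch under that assumption (model name and key are placeholders; the remaining constructor arguments continue past this hunk):

# Sketch: OpenRouter through the OpenAI-compatible embedding client.
# The endpoint and model name here are illustrative placeholders.
from llama_index.embeddings.openai import OpenAIEmbedding

embed_model = OpenAIEmbedding(
    model="text-embedding-3-small",
    api_key="sk-or-...",
    api_base="https://openrouter.ai/api/v1",
)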
@@ -535,9 +526,7 @@ async def initialize(self) -> bool:
)

# create storage context
self.storage_context = StorageContext.from_defaults(
graph_store=self.graph_store
)
self.storage_context = StorageContext.from_defaults(graph_store=self.graph_store)

# try to load existing index, if not exists, create new one
try:
@@ -606,8 +595,14 @@ def _resolve_pipeline_config(
) -> PipelineConfig:
"""Translate user configuration into a pipeline configuration."""
mode = (mode or "hybrid").lower()
run_graph = use_graph if use_graph is not None else mode in ("hybrid", "graph_only", "graph_first")
run_vector = use_vector if use_vector is not None else mode in ("hybrid", "vector_only", "vector_first")
run_graph = (
use_graph if use_graph is not None else mode in ("hybrid", "graph_only", "graph_first")
)
run_vector = (
use_vector
if use_vector is not None
else mode in ("hybrid", "vector_only", "vector_first")
)

if not run_graph and not run_vector:
raise ValueError("At least one of graph or vector retrieval must be enabled")
@@ -786,9 +781,7 @@ def _process_pipeline() -> Dict[str, Any]:
asyncio.to_thread(_process_pipeline),
timeout=timeout,
)
logger.info(
f"Pipeline '{pipeline_name}' completed with {result['nodes_count']} nodes"
)
logger.info(f"Pipeline '{pipeline_name}' completed with {result['nodes_count']} nodes")
return {
"success": True,
"pipeline": pipeline_name,
@@ -820,9 +813,7 @@ async def add_document(
metadata.setdefault("timestamp", metadata.get("timestamp", time.time()))

content_size = len(content)
timeout = (
self.operation_timeout if content_size < 10000 else self.large_document_timeout
)
timeout = self.operation_timeout if content_size < 10000 else self.large_document_timeout

result = await self._run_ingestion_pipeline(
"manual_input",
@@ -835,10 +826,12 @@
)

if result.get("success"):
result.update({
"message": f"Document '{metadata['title']}' added to knowledge graph",
"content_size": content_size,
})
result.update(
{
"message": f"Document '{metadata['title']}' added to knowledge graph",
"content_size": content_size,
}
)
return result

async def add_file(self, file_path: str) -> Dict[str, Any]:
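A usage sketch for add_document; the full signature is abbreviated in this hunk, so the positional content argument and the metadata keyword are assumptions based on how metadata is used above:

# Hedged usage sketch; parameter names are assumed from context.
async def ingest_note(service) -> None:
    result = await service.add_document(
        "Neo4j stores data as nodes and relationships.",
        metadata={"title": "neo4j-basics"},
    )
    if result.get("success"):
        print(result["message"], "-", result["content_size"], "chars")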
@@ -950,9 +943,13 @@ async def search_similar_nodes(
text = text[:200] + "..."
results.append(
{
"node_id": node.node.node_id if getattr(node, "node", None) else node.node_id,
"node_id": (
node.node.node_id if getattr(node, "node", None) else node.node_id
),
"text": text,
"metadata": dict(getattr(getattr(node, "node", None), "metadata", {}) or {}),
"metadata": dict(
getattr(getattr(node, "node", None), "metadata", {}) or {}
),
"score": node.score,
}
)
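Each appended result dict carries node_id, text (truncated to 200 characters), metadata, and score. A consumer sketch, assuming the success payload exposes these under a results key:

# Hedged consumer sketch; the `results` key on success is an assumption.
async def show_matches(service, query: str) -> None:
    response = await service.search_similar_nodes(query)
    if not response.get("success"):
        print("search failed:", response.get("error"))
        return
    for item in response.get("results", []):
        print(f"{item['score']}: {item['node_id']} -> {item['text']}")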
@@ -967,43 +964,48 @@ async def search_similar_nodes(
logger.error(f"Failed to search similar nodes: {e}")
return {"success": False, "error": str(e)}

async def get_statistics(self) -> Dict[str, Any]:
"""Return lightweight service statistics for legacy API compatibility."""
if not self._initialized:
raise Exception("Service not initialized")
def _collect_statistics(self) -> Dict[str, Any]:
"""Collect statistics from the knowledge base in a synchronous manner."""
base_stats: Dict[str, Any] = {
"initialized": self._initialized,
"graph_store_type": type(self.graph_store).__name__ if self.graph_store else None,
"vector_index_type": type(self.vector_index).__name__ if self.vector_index else None,
"pipeline": {
"default_top_k": getattr(self.query_pipeline, "default_top_k", None),
"default_graph_depth": getattr(self.query_pipeline, "default_graph_depth", None),
"supports_tools": bool(self.function_tools),
},
}

def _collect_statistics() -> Dict[str, Any]:
base_stats: Dict[str, Any] = {
"initialized": self._initialized,
"graph_store_type": type(self.graph_store).__name__ if self.graph_store else None,
"vector_index_type": type(self.vector_index).__name__ if self.vector_index else None,
"pipeline": {
"default_top_k": getattr(self.query_pipeline, "default_top_k", None),
"default_graph_depth": getattr(self.query_pipeline, "default_graph_depth", None),
"supports_tools": bool(self.function_tools),
},
}
if self.graph_store is None:
return base_stats

if self.graph_store is None:
return base_stats
try:
node_result = self.graph_store.query("MATCH (n) RETURN count(n) AS node_count")
base_stats["node_count"] = node_result[0].get("node_count", 0) if node_result else 0
except Exception as exc:
base_stats["node_count_error"] = str(exc)

try:
node_result = self.graph_store.query("MATCH (n) RETURN count(n) AS node_count")
base_stats["node_count"] = node_result[0].get("node_count", 0) if node_result else 0
except Exception as exc:
base_stats["node_count_error"] = str(exc)
try:
rel_result = self.graph_store.query(
"MATCH ()-[r]->() RETURN count(r) AS relationship_count"
)
base_stats["relationship_count"] = (
rel_result[0].get("relationship_count", 0) if rel_result else 0
)
except Exception as exc:
base_stats["relationship_count_error"] = str(exc)

try:
rel_result = self.graph_store.query("MATCH ()-[r]->() RETURN count(r) AS relationship_count")
base_stats["relationship_count"] = rel_result[0].get("relationship_count", 0) if rel_result else 0
except Exception as exc:
base_stats["relationship_count_error"] = str(exc)
return base_stats

return base_stats
async def get_statistics(self) -> Dict[str, Any]:
"""Return lightweight service statistics for legacy API compatibility."""
if not self._initialized:
raise Exception("Service not initialized")

try:
statistics = await asyncio.wait_for(
asyncio.to_thread(_collect_statistics),
asyncio.to_thread(self._collect_statistics),
timeout=self.operation_timeout,
)
return {"success": True, "statistics": statistics}
@@ -1015,35 +1017,36 @@ def _collect_statistics() -> Dict[str, Any]:
logger.error(f"Failed to collect statistics: {exc}")
return {"success": False, "error": str(exc)}

def _clear_sync(self) -> None:
"""Clear graph and vector stores synchronously."""
if self.graph_store is None:
raise RuntimeError("Graph store is not available")

# Remove all nodes/relationships
self.graph_store.query("MATCH (n) DETACH DELETE n")

# Best-effort vector store reset (depends on backend capabilities)
vector_store = getattr(self.storage_context, "vector_store", None)
if vector_store is not None:
delete_method = getattr(vector_store, "delete", None)
if callable(delete_method):
try:
delete_method(delete_all=True)
except TypeError:
delete_method()
except Exception as exc: # pragma: no cover - defensive logging
logger.warning(f"Vector store clear failed: {exc}")

async def clear_knowledge_base(self) -> Dict[str, Any]:
"""Clear Neo4j data and rebuild service indices for legacy API compatibility."""
if not self._initialized:
raise Exception("Service not initialized")

async def _clear_graph() -> None:
def _clear_sync() -> None:
if self.graph_store is None:
raise RuntimeError("Graph store is not available")

# Remove all nodes/relationships
self.graph_store.query("MATCH (n) DETACH DELETE n")

# Best-effort vector store reset (depends on backend capabilities)
vector_store = getattr(self.storage_context, "vector_store", None)
if vector_store is not None:
delete_method = getattr(vector_store, "delete", None)
if callable(delete_method):
try:
delete_method(delete_all=True)
except TypeError:
delete_method()
except Exception as exc: # pragma: no cover - defensive logging
logger.warning(f"Vector store clear failed: {exc}")

await asyncio.to_thread(_clear_sync)

try:
await asyncio.wait_for(_clear_graph(), timeout=self.operation_timeout)
await asyncio.wait_for(
asyncio.to_thread(self._clear_sync),
timeout=self.operation_timeout,
)

# Recreate storage context and indexes to reflect cleared state
self.storage_context = StorageContext.from_defaults(graph_store=self.graph_store)
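The TypeError fallback in _clear_sync is a probe-then-degrade pattern: prefer delete(delete_all=True), and retry with a bare call for vector stores whose delete() signature lacks that keyword. The same pattern in isolation:

# The probe-then-degrade call pattern from _clear_sync, extracted for clarity.
def best_effort_clear(vector_store) -> None:
    delete_method = getattr(vector_store, "delete", None)
    if not callable(delete_method):
        return  # backend exposes no delete API; nothing to do
    try:
        delete_method(delete_all=True)
    except TypeError:  # signature without delete_all
        delete_method()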
@@ -1128,4 +1131,3 @@ async def close(self) -> None:
await asyncio.to_thread(self.graph_store.close)
self._initialized = False
logger.info("Neo4j Knowledge Service closed")
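End to end, the lifecycle this file implies is initialize, ingest, query, close. A compact sketch; the constructor arguments are not shown in this diff and are therefore elided:

# Lifecycle sketch; Neo4jKnowledgeService() constructor arguments are assumed/elided.
import asyncio

async def main() -> None:
    service = Neo4jKnowledgeService()
    if not await service.initialize():
        raise SystemExit("initialization failed")
    try:
        await service.add_document("Neo4j is a graph database.", metadata={"title": "intro"})
        print(await service.search_similar_nodes("graph database"))
    finally:
        await service.close()

asyncio.run(main())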