diff --git a/src/codebase_rag/services/knowledge/neo4j_knowledge_service.py b/src/codebase_rag/services/knowledge/neo4j_knowledge_service.py
index f4a6b40..f6c034b 100644
--- a/src/codebase_rag/services/knowledge/neo4j_knowledge_service.py
+++ b/src/codebase_rag/services/knowledge/neo4j_knowledge_service.py
@@ -1,16 +1,14 @@
-
 """
 modern knowledge graph service based on Neo4j's native vector index
 uses LlamaIndex's KnowledgeGraphIndex and Neo4j's native vector search functionality
 supports multiple LLM and embedding model providers
 """
 
-from dataclasses import dataclass, field
-from typing import List, Dict, Any, Optional, Union
-from pathlib import Path
 import asyncio
-from loguru import logger
 import time
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
 
 from llama_index.core import (
     KnowledgeGraphIndex,
@@ -19,26 +17,28 @@
     VectorStoreIndex,
 )
 from llama_index.core.indices.knowledge_graph import KnowledgeGraphRAGRetriever
-from llama_index.core.retrievers import VectorIndexRetriever
 from llama_index.core.response_synthesizers import get_response_synthesizer
-from llama_index.core.schema import QueryBundle, NodeWithScore
+from llama_index.core.retrievers import VectorIndexRetriever
+from llama_index.core.schema import NodeWithScore, QueryBundle
 
-# LLM Providers
-from llama_index.llms.ollama import Ollama
-from llama_index.llms.openai import OpenAI
-from llama_index.llms.gemini import Gemini
-from llama_index.llms.openrouter import OpenRouter
+# Tools / workflow
+from llama_index.core.tools import FunctionTool
 
 # Embedding Providers
+from llama_index.embeddings.gemini import GeminiEmbedding
 from llama_index.embeddings.ollama import OllamaEmbedding
 from llama_index.embeddings.openai import OpenAIEmbedding
-from llama_index.embeddings.gemini import GeminiEmbedding
 
 # Graph Store
 from llama_index.graph_stores.neo4j import Neo4jGraphStore
 
-# Tools / workflow
-from llama_index.core.tools import FunctionTool
+# LLM Providers
+from llama_index.llms.gemini import Gemini
+from llama_index.llms.ollama import Ollama
+from llama_index.llms.openai import OpenAI
+from llama_index.llms.openrouter import OpenRouter
+from loguru import logger
+
 try:
     # Optional dependency for advanced workflow integration
     from llama_index.core.workflow.tool_node import ToolNode
 except Exception:  # pragma: no cover - optional runtime dependency
@@ -51,11 +51,11 @@
     merge_pipeline_configs,
 )
 
-
 # =========================
 # Retrieval Pipeline Config
 # =========================
+
 
 @dataclass
 class PipelineConfig:
     """Configuration for running a retrieval pipeline."""
@@ -104,9 +104,7 @@ def _merge_nodes(
         """Merge retrieved nodes by keeping the highest scoring entry per node id."""
         for node in nodes:
            node_id = node.node.node_id if getattr(node, "node", None) else node.node_id
-            if node_id not in aggregated or (
-                (aggregated[node_id].score or 0) < (node.score or 0)
-            ):
+            if node_id not in aggregated or ((aggregated[node_id].score or 0) < (node.score or 0)):
                aggregated[node_id] = node
 
     @staticmethod
@@ -201,9 +199,7 @@ def run(self, question: str, config: PipelineConfig) -> Dict[str, Any]:
                     )
                 except Exception as exc:  # pragma: no cover - defensive logging
                     logger.warning(f"Tool node execution failed: {exc}")
-                    tool_outputs.append(
-                        {"tool": "tool_node", "error": str(exc), "is_error": True}
-                    )
+                    tool_outputs.append({"tool": "tool_node", "error": str(exc), "is_error": True})
         elif self.function_tools:
             for tool in self.function_tools:
                 try:
@@ -234,6 +230,7 @@ def run(self, question: str, config: PipelineConfig) -> Dict[str, Any]:
 # Knowledge Service Main
 # ======================
+
 
 class Neo4jKnowledgeService:
     """knowledge graph service based on Neo4j's native vector index"""
 
@@ -321,9 +318,7 @@ def _create_embedding_model(self):
             )
         elif provider == "openai":
             if not settings.openai_api_key:
-                raise ValueError(
-                    "OpenAI API key is required for OpenAI embedding provider"
-                )
+                raise ValueError("OpenAI API key is required for OpenAI embedding provider")
             return OpenAIEmbedding(
                 model=settings.openai_embedding_model,
                 api_key=settings.openai_api_key,
@@ -332,18 +327,14 @@ def _create_embedding_model(self):
             )
         elif provider == "gemini":
             if not settings.google_api_key:
-                raise ValueError(
-                    "Google API key is required for Gemini embedding provider"
-                )
+                raise ValueError("Google API key is required for Gemini embedding provider")
             return GeminiEmbedding(
                 model_name=settings.gemini_embedding_model,
                 api_key=settings.google_api_key,
             )
         elif provider == "openrouter":
             if not settings.openrouter_api_key:
-                raise ValueError(
-                    "OpenRouter API key is required for OpenRouter embedding provider"
-                )
+                raise ValueError("OpenRouter API key is required for OpenRouter embedding provider")
             return OpenAIEmbedding(
                 model=settings.openrouter_embedding_model,
                 api_key=settings.openrouter_api_key,
@@ -535,9 +526,7 @@ async def initialize(self) -> bool:
             )
 
             # create storage context
-            self.storage_context = StorageContext.from_defaults(
-                graph_store=self.graph_store
-            )
+            self.storage_context = StorageContext.from_defaults(graph_store=self.graph_store)
 
             # try to load existing index, if not exists, create new one
             try:
@@ -606,8 +595,14 @@ def _resolve_pipeline_config(
     ) -> PipelineConfig:
         """Translate user configuration into a pipeline configuration."""
         mode = (mode or "hybrid").lower()
-        run_graph = use_graph if use_graph is not None else mode in ("hybrid", "graph_only", "graph_first")
-        run_vector = use_vector if use_vector is not None else mode in ("hybrid", "vector_only", "vector_first")
+        run_graph = (
+            use_graph if use_graph is not None else mode in ("hybrid", "graph_only", "graph_first")
+        )
+        run_vector = (
+            use_vector
+            if use_vector is not None
+            else mode in ("hybrid", "vector_only", "vector_first")
+        )
 
         if not run_graph and not run_vector:
             raise ValueError("At least one of graph or vector retrieval must be enabled")
@@ -786,9 +781,7 @@ def _process_pipeline() -> Dict[str, Any]:
                 asyncio.to_thread(_process_pipeline),
                 timeout=timeout,
             )
-            logger.info(
-                f"Pipeline '{pipeline_name}' completed with {result['nodes_count']} nodes"
-            )
+            logger.info(f"Pipeline '{pipeline_name}' completed with {result['nodes_count']} nodes")
             return {
                 "success": True,
                 "pipeline": pipeline_name,
@@ -820,9 +813,7 @@ async def add_document(
         metadata.setdefault("timestamp", metadata.get("timestamp", time.time()))
 
         content_size = len(content)
-        timeout = (
-            self.operation_timeout if content_size < 10000 else self.large_document_timeout
-        )
+        timeout = self.operation_timeout if content_size < 10000 else self.large_document_timeout
 
         result = await self._run_ingestion_pipeline(
             "manual_input",
@@ -835,10 +826,12 @@
         )
 
         if result.get("success"):
-            result.update({
-                "message": f"Document '{metadata['title']}' added to knowledge graph",
-                "content_size": content_size,
-            })
+            result.update(
+                {
+                    "message": f"Document '{metadata['title']}' added to knowledge graph",
+                    "content_size": content_size,
+                }
+            )
         return result
 
     async def add_file(self, file_path: str) -> Dict[str, Any]:
@@ -950,9 +943,13 @@ async def search_similar_nodes(
                     text = text[:200] + "..."
                 results.append(
                     {
-                        "node_id": node.node.node_id if getattr(node, "node", None) else node.node_id,
+                        "node_id": (
+                            node.node.node_id if getattr(node, "node", None) else node.node_id
+                        ),
                         "text": text,
-                        "metadata": dict(getattr(getattr(node, "node", None), "metadata", {}) or {}),
+                        "metadata": dict(
+                            getattr(getattr(node, "node", None), "metadata", {}) or {}
+                        ),
                         "score": node.score,
                     }
                 )
@@ -967,43 +964,48 @@ async def search_similar_nodes(
             logger.error(f"Failed to search similar nodes: {e}")
             return {"success": False, "error": str(e)}
 
-    async def get_statistics(self) -> Dict[str, Any]:
-        """Return lightweight service statistics for legacy API compatibility."""
-        if not self._initialized:
-            raise Exception("Service not initialized")
+    def _collect_statistics(self) -> Dict[str, Any]:
+        """Collect statistics from the knowledge base in a synchronous manner."""
+        base_stats: Dict[str, Any] = {
+            "initialized": self._initialized,
+            "graph_store_type": type(self.graph_store).__name__ if self.graph_store else None,
+            "vector_index_type": type(self.vector_index).__name__ if self.vector_index else None,
+            "pipeline": {
+                "default_top_k": getattr(self.query_pipeline, "default_top_k", None),
+                "default_graph_depth": getattr(self.query_pipeline, "default_graph_depth", None),
+                "supports_tools": bool(self.function_tools),
+            },
+        }
 
-        def _collect_statistics() -> Dict[str, Any]:
-            base_stats: Dict[str, Any] = {
-                "initialized": self._initialized,
-                "graph_store_type": type(self.graph_store).__name__ if self.graph_store else None,
-                "vector_index_type": type(self.vector_index).__name__ if self.vector_index else None,
-                "pipeline": {
-                    "default_top_k": getattr(self.query_pipeline, "default_top_k", None),
-                    "default_graph_depth": getattr(self.query_pipeline, "default_graph_depth", None),
-                    "supports_tools": bool(self.function_tools),
-                },
-            }
+        if self.graph_store is None:
+            return base_stats
 
-            if self.graph_store is None:
-                return base_stats
+        try:
+            node_result = self.graph_store.query("MATCH (n) RETURN count(n) AS node_count")
+            base_stats["node_count"] = node_result[0].get("node_count", 0) if node_result else 0
+        except Exception as exc:
+            base_stats["node_count_error"] = str(exc)
 
-            try:
-                node_result = self.graph_store.query("MATCH (n) RETURN count(n) AS node_count")
-                base_stats["node_count"] = node_result[0].get("node_count", 0) if node_result else 0
-            except Exception as exc:
-                base_stats["node_count_error"] = str(exc)
+        try:
+            rel_result = self.graph_store.query(
+                "MATCH ()-[r]->() RETURN count(r) AS relationship_count"
+            )
+            base_stats["relationship_count"] = (
+                rel_result[0].get("relationship_count", 0) if rel_result else 0
+            )
+        except Exception as exc:
+            base_stats["relationship_count_error"] = str(exc)
 
-            try:
-                rel_result = self.graph_store.query("MATCH ()-[r]->() RETURN count(r) AS relationship_count")
-                base_stats["relationship_count"] = rel_result[0].get("relationship_count", 0) if rel_result else 0
-            except Exception as exc:
-                base_stats["relationship_count_error"] = str(exc)
+        return base_stats
 
-            return base_stats
+    async def get_statistics(self) -> Dict[str, Any]:
+        """Return lightweight service statistics for legacy API compatibility."""
+        if not self._initialized:
+            raise Exception("Service not initialized")
 
         try:
             statistics = await asyncio.wait_for(
-                asyncio.to_thread(_collect_statistics),
+                asyncio.to_thread(self._collect_statistics),
                 timeout=self.operation_timeout,
            )
            return {"success": True, "statistics": statistics}
@@ -1015,35 +1017,36 @@ def _collect_statistics() -> Dict[str, Any]:
             logger.error(f"Failed to collect statistics: {exc}")
             return {"success": False, "error": str(exc)}
 
+    def _clear_sync(self) -> None:
+        """Clear graph and vector stores synchronously."""
+        if self.graph_store is None:
+            raise RuntimeError("Graph store is not available")
+
+        # Remove all nodes/relationships
+        self.graph_store.query("MATCH (n) DETACH DELETE n")
+
+        # Best-effort vector store reset (depends on backend capabilities)
+        vector_store = getattr(self.storage_context, "vector_store", None)
+        if vector_store is not None:
+            delete_method = getattr(vector_store, "delete", None)
+            if callable(delete_method):
+                try:
+                    delete_method(delete_all=True)
+                except TypeError:
+                    delete_method()
+                except Exception as exc:  # pragma: no cover - defensive logging
+                    logger.warning(f"Vector store clear failed: {exc}")
+
     async def clear_knowledge_base(self) -> Dict[str, Any]:
         """Clear Neo4j data and rebuild service indices for legacy API compatibility."""
         if not self._initialized:
             raise Exception("Service not initialized")
 
-        async def _clear_graph() -> None:
-            def _clear_sync() -> None:
-                if self.graph_store is None:
-                    raise RuntimeError("Graph store is not available")
-
-                # Remove all nodes/relationships
-                self.graph_store.query("MATCH (n) DETACH DELETE n")
-
-                # Best-effort vector store reset (depends on backend capabilities)
-                vector_store = getattr(self.storage_context, "vector_store", None)
-                if vector_store is not None:
-                    delete_method = getattr(vector_store, "delete", None)
-                    if callable(delete_method):
-                        try:
-                            delete_method(delete_all=True)
-                        except TypeError:
-                            delete_method()
-                        except Exception as exc:  # pragma: no cover - defensive logging
-                            logger.warning(f"Vector store clear failed: {exc}")
-
-            await asyncio.to_thread(_clear_sync)
-
         try:
-            await asyncio.wait_for(_clear_graph(), timeout=self.operation_timeout)
+            await asyncio.wait_for(
+                asyncio.to_thread(self._clear_sync),
+                timeout=self.operation_timeout,
+            )
 
             # Recreate storage context and indexes to reflect cleared state
             self.storage_context = StorageContext.from_defaults(graph_store=self.graph_store)
@@ -1128,4 +1131,3 @@ async def close(self) -> None:
             await asyncio.to_thread(self.graph_store.close)
         self._initialized = False
         logger.info("Neo4j Knowledge Service closed")
-
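
The reordered imports keep the `ToolNode` guard: the workflow integration is an optional extra, so the import is wrapped in try/except and the service degrades when it is absent. The diff cuts off before the except body; a common shape for this pattern (the `None` sentinel below is an assumption, not taken from the source):

    try:
        # Optional dependency for advanced workflow integration
        from llama_index.core.workflow.tool_node import ToolNode
    except Exception:  # pragma: no cover - optional runtime dependency
        ToolNode = None  # hypothetical sentinel; call sites would check before use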
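The `_merge_nodes` hunk only reflows the `if`, but the rule it implements is easy to miss in diff form: when graph and vector retrievers return the same node, the higher-scoring entry wins, and a missing score counts as 0. A minimal sketch of that rule in isolation (`NodeWithScore` comes from the imports in this diff; the function name here is illustrative):

    from typing import Dict, List

    from llama_index.core.schema import NodeWithScore

    def merge_by_max_score(aggregated: Dict[str, NodeWithScore], nodes: List[NodeWithScore]) -> None:
        """Keep the highest-scoring entry per node id, mirroring _merge_nodes."""
        for node in nodes:
            # Retrieval results may wrap the node (.node) or be the node itself.
            node_id = node.node.node_id if getattr(node, "node", None) else node.node_id
            # Treat a missing score as 0 so any scored duplicate replaces it.
            if node_id not in aggregated or (aggregated[node_id].score or 0) < (node.score or 0):
                aggregated[node_id] = node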
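`_resolve_pipeline_config` is likewise only re-wrapped, but its resolution rule is the heart of the hybrid mode: explicit `use_graph`/`use_vector` arguments override the `mode` string, and at least one retriever must stay enabled. The same rule as a standalone sketch (the mode names are taken from the diff; everything else is illustrative):

    from typing import Optional, Tuple

    GRAPH_MODES = ("hybrid", "graph_only", "graph_first")
    VECTOR_MODES = ("hybrid", "vector_only", "vector_first")

    def resolve_retrieval_flags(
        mode: Optional[str],
        use_graph: Optional[bool] = None,
        use_vector: Optional[bool] = None,
    ) -> Tuple[bool, bool]:
        """Explicit flags win; otherwise the mode string decides."""
        mode = (mode or "hybrid").lower()
        run_graph = use_graph if use_graph is not None else mode in GRAPH_MODES
        run_vector = use_vector if use_vector is not None else mode in VECTOR_MODES
        if not run_graph and not run_vector:
            raise ValueError("At least one of graph or vector retrieval must be enabled")
        return run_graph, run_vector

    assert resolve_retrieval_flags("vector_only") == (False, True)
    assert resolve_retrieval_flags("graph_only", use_vector=True) == (True, True)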
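The refactor that lifts `_collect_statistics` and `_clear_sync` out of their enclosing coroutines lets both run through one concurrency idiom: execute the blocking Neo4j call in a worker thread and bound the wait with a timeout. A self-contained sketch of that idiom using only the standard library (the class and numbers are stand-ins):

    import asyncio

    class FakeBackend:
        """Stand-in for the synchronous Neo4j driver calls."""

        def count_nodes(self) -> int:
            return 42  # a real implementation would block on the network

    async def count_with_timeout(backend: FakeBackend, timeout: float) -> int:
        # to_thread keeps the event loop responsive while the call blocks;
        # wait_for cancels the await (not the thread) once the timeout expires.
        return await asyncio.wait_for(asyncio.to_thread(backend.count_nodes), timeout=timeout)

    print(asyncio.run(count_with_timeout(FakeBackend(), timeout=5.0)))

One caveat the comment hints at: `wait_for` abandons the worker thread rather than killing it, which is presumably why the service sizes the timeout to the work (`operation_timeout` for small documents, `large_document_timeout` past 10,000 characters) instead of relying on cancellation alone.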
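The extracted `_clear_sync` also carries a backend-agnostic reset pattern worth calling out: it feature-detects a `delete` method on the vector store and falls back to a no-argument call when the backend's signature rejects `delete_all=True`. In isolation (the argument is a stand-in for any vector store object):

    from loguru import logger

    def best_effort_clear(vector_store: object) -> None:
        """Probe for delete() and degrade gracefully across store backends."""
        delete_method = getattr(vector_store, "delete", None)
        if not callable(delete_method):
            return  # backend exposes no delete capability; nothing to do
        try:
            delete_method(delete_all=True)  # bulk flag, where supported
        except TypeError:
            delete_method()  # older signatures accept no arguments
        except Exception as exc:  # pragma: no cover - defensive logging
            logger.warning(f"Vector store clear failed: {exc}")

Note the ordering: the `TypeError` branch must precede the catch-all, exactly as in the hunk, or the signature fallback would never run.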