diff --git a/requirements.txt b/requirements.txt
index 60e7d95..e01bf9e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -163,6 +163,7 @@ llama-index-llms-ollama==0.9.0
     # via code-graph (pyproject.toml)
 llama-index-workflows==2.11.0
     # via llama-index-core
+    # via code-graph (pyproject.toml)
 loguru==0.7.3
     # via code-graph (pyproject.toml)
 markupsafe==3.0.3
diff --git a/src/codebase_rag/api/agent_routes.py b/src/codebase_rag/api/agent_routes.py
new file mode 100644
index 0000000..a68f620
--- /dev/null
+++ b/src/codebase_rag/api/agent_routes.py
@@ -0,0 +1,111 @@
+"""FastAPI routes exposing the unified LlamaIndex agent workflow."""
+
+from __future__ import annotations
+
+from typing import Any, Dict, List, Optional
+
+from fastapi import APIRouter, HTTPException, Path
+from pydantic import BaseModel, Field
+
+from codebase_rag.services.agents import agent_session_manager
+
+
+router = APIRouter(prefix="/agent", tags=["Agent Orchestration"])
+
+
+class SessionSummary(BaseModel):
+    session_id: str
+    project_id: str
+    metadata: Dict[str, Any] = Field(default_factory=dict)
+    turns: int = 0
+    tool_events: int = 0
+
+
+class CreateSessionRequest(BaseModel):
+    project_id: str = Field(..., description="Project identifier used for retrieval and memory scoping.")
+    metadata: Optional[Dict[str, Any]] = Field(default=None, description="Optional metadata stored alongside the session.")
+
+
+class CreateSessionResponse(SessionSummary):
+    pass
+
+
+class AgentMessageRequest(BaseModel):
+    message: str = Field(..., description="User message to send to the orchestrator agent.")
+    auto_save_memories: bool = Field(
+        default=False,
+        description="If True, the memory extraction tool will persist high-confidence memories automatically.",
+    )
+
+
+class AgentMessageResponse(BaseModel):
+    session_id: str
+    reply: str
+    tool_events: List[Dict[str, Any]] = Field(default_factory=list)
+    task: Dict[str, Any] = Field(default_factory=dict)
+    chat_history: List[Dict[str, str]] = Field(default_factory=list)
+
+
+class SessionStateResponse(BaseModel):
+    session_id: str
+    project_id: str
+    metadata: Dict[str, Any]
+    chat_history: List[Dict[str, str]]
+    tool_events: List[Dict[str, Any]]
+    task_trace: List[Dict[str, Any]]
+
+
+@router.post("/sessions", response_model=CreateSessionResponse)
+async def create_agent_session(payload: CreateSessionRequest) -> Dict[str, Any]:
+    """Create a new agent session scoped to the provided project."""
+
+    return await agent_session_manager.create_session(
+        project_id=payload.project_id,
+        metadata=payload.metadata,
+    )
+
+
+@router.get("/sessions", response_model=Dict[str, List[SessionSummary]])
+async def list_agent_sessions() -> Dict[str, List[SessionSummary]]:
+    """List all active agent sessions."""
+
+    sessions = await agent_session_manager.list_sessions()
+    return {"sessions": sessions}
+
+
+@router.get("/sessions/{session_id}", response_model=SessionStateResponse)
+async def get_agent_session(session_id: str = Path(..., description="Session identifier")) -> Dict[str, Any]:
+    """Fetch detailed state for a specific session."""
+
+    try:
+        return await agent_session_manager.get_session_state(session_id)
+    except KeyError as exc:  # pragma: no cover - defensive
+        raise HTTPException(status_code=404, detail=str(exc)) from exc
+
+
+@router.delete("/sessions/{session_id}")
+async def close_agent_session(session_id: str = Path(..., description="Session identifier")) -> Dict[str, str]:
+    """Terminate an existing agent session."""
+
+    await agent_session_manager.close_session(session_id)
+    return {"status": "closed", "session_id": session_id}
+
+
+@router.post("/sessions/{session_id}/messages", response_model=AgentMessageResponse)
+async def send_agent_message(
+    payload: AgentMessageRequest,
+    session_id: str = Path(..., description="Session identifier"),
+) -> Dict[str, Any]:
+    """Send a message to the orchestrator agent and obtain the response."""
+
+    try:
+        return await agent_session_manager.process_message(
+            session_id=session_id,
+            message=payload.message,
+            auto_save_memories=payload.auto_save_memories,
+        )
+    except KeyError as exc:
+        raise HTTPException(status_code=404, detail=str(exc)) from exc
+    except AttributeError as exc:  # pragma: no cover - unexpected agent shape
+        raise HTTPException(status_code=500, detail=str(exc)) from exc
+
diff --git a/src/codebase_rag/core/routes.py b/src/codebase_rag/core/routes.py
index 373c3f0..5289a4b 100644
--- a/src/codebase_rag/core/routes.py
+++ b/src/codebase_rag/core/routes.py
@@ -10,6 +10,7 @@
 from codebase_rag.api.websocket_routes import router as ws_router
 from codebase_rag.api.sse_routes import router as sse_router
 from codebase_rag.api.memory_routes import router as memory_router
+from codebase_rag.api.agent_routes import router as agent_router
 
 
 def setup_routes(app: FastAPI) -> None:
@@ -21,4 +22,5 @@ def setup_routes(app: FastAPI) -> None:
     app.include_router(task_router, prefix="/api/v1", tags=["Task Management"])
     app.include_router(sse_router, prefix="/api/v1", tags=["Real-time Updates"])
     app.include_router(memory_router, tags=["Memory Management"])
+    app.include_router(agent_router, prefix="/api/v1", tags=["Agent Orchestration"])
\ No newline at end of file
diff --git a/src/codebase_rag/services/__init__.py b/src/codebase_rag/services/__init__.py
index 297bcf6..69c27e0 100644
--- a/src/codebase_rag/services/__init__.py
+++ b/src/codebase_rag/services/__init__.py
@@ -28,4 +28,5 @@
     "utils",
     "pipeline",
     "graph",
+    "agents",
 ]
diff --git a/src/codebase_rag/services/agents/__init__.py b/src/codebase_rag/services/agents/__init__.py
new file mode 100644
index 0000000..459f43d
--- /dev/null
+++ b/src/codebase_rag/services/agents/__init__.py
@@ -0,0 +1,17 @@
+"""Agent orchestration services built on top of LlamaIndex workflows."""
+
+from .base import create_default_agent
+from .session_manager import AgentSessionManager
+from .tools import AGENT_TOOLS, KNOWLEDGE_TOOLS, MEMORY_TOOLS
+
+__all__ = [
+    "create_default_agent",
+    "AgentSessionManager",
+    "agent_session_manager",
+    "AGENT_TOOLS",
+    "KNOWLEDGE_TOOLS",
+    "MEMORY_TOOLS",
+]
+
+
+agent_session_manager = AgentSessionManager()
diff --git a/src/codebase_rag/services/agents/base.py b/src/codebase_rag/services/agents/base.py
new file mode 100644
index 0000000..3241901
--- /dev/null
+++ b/src/codebase_rag/services/agents/base.py
@@ -0,0 +1,44 @@
+"""Factories for constructing LlamaIndex workflow agents."""
+
+from typing import Sequence
+
+from llama_index.core import Settings
+from llama_index.core.agent.workflow import FunctionAgent
+
+from codebase_rag.config import settings
+
+from .tools import AGENT_TOOLS
+
+
+def create_default_agent(*, tools: Sequence = AGENT_TOOLS) -> FunctionAgent:
+    """Create a FunctionAgent wired with the default toolset.
+
+    The agent uses the globally configured LlamaIndex LLM settings and provides
+    instructions aimed at orchestrating knowledge retrieval, memory extraction and
+    lightweight task tracking across a project-oriented workflow.
+    """
+
+    if Settings.llm is None:
+        raise ValueError(
+            "Settings.llm is not configured. Initialize the Neo4j knowledge service "
+            "or configure Settings.llm before creating agents."
+        )
+
+    description = (
+        "Project knowledge orchestrator capable of looking up graph knowledge, "
+        "searching vector similarities, extracting new memories and persisting them."
+    )
+
+    system_prompt = (
+        "You are the CodebaseRAG coordinator. Always inspect the available tools to "
+        "answer user questions, retrieve supporting context from Neo4j, and store new "
+        "memories when relevant. Make sure responses explain which tools were used."
+    )
+
+    return FunctionAgent(
+        name=settings.app_name or "codebase-rag-agent",
+        description=description,
+        system_prompt=system_prompt,
+        tools=list(tools),
+        llm=Settings.llm,
+    )
diff --git a/src/codebase_rag/services/agents/session_manager.py b/src/codebase_rag/services/agents/session_manager.py
new file mode 100644
index 0000000..2589b43
--- /dev/null
+++ b/src/codebase_rag/services/agents/session_manager.py
@@ -0,0 +1,213 @@
+"""Conversation orchestration built around LlamaIndex workflow agents."""
+
+from __future__ import annotations
+
+import asyncio
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional
+from uuid import uuid4
+
+from llama_index.core.agent.workflow import AgentOutput, ToolCall, ToolCallResult
+from llama_index.core.agent.workflow import FunctionAgent
+from llama_index.core.base.llms.types import ChatMessage, MessageRole
+
+from .base import create_default_agent
+
+
+@dataclass
+class AgentSession:
+    """In-memory record of a running agent session."""
+
+    session_id: str
+    project_id: str
+    agent: FunctionAgent
+    chat_history: List[ChatMessage] = field(default_factory=list)
+    tool_events: List[Dict[str, Any]] = field(default_factory=list)
+    task_trace: List[Dict[str, Any]] = field(default_factory=list)
+    metadata: Dict[str, Any] = field(default_factory=dict)
+
+    def as_dict(self) -> Dict[str, Any]:
+        """Serialize a lightweight view of the session."""
+
+        return {
+            "session_id": self.session_id,
+            "project_id": self.project_id,
+            "metadata": self.metadata,
+            "turns": len(self.chat_history) // 2,
+            "tool_events": len(self.tool_events),
+        }
+
+
+def _to_chat_message(role: MessageRole, content: str) -> ChatMessage:
+    return ChatMessage(role=role, content=content)
+
+
+def _serialize_chat_history(chat_history: List[ChatMessage]) -> List[Dict[str, str]]:
+    return [
+        {
+            "role": getattr(msg.role, "value", msg.role),
+            "content": msg.content,
+        }
+        for msg in chat_history
+    ]
+
+
+def _extract_response_text(response: Any) -> str:
+    if response is None:
+        return ""
+
+    if isinstance(response, AgentOutput):
+        return response.response.content or ""
+
+    message = getattr(response, "message", None)
+    if message is not None and hasattr(message, "content"):
+        return message.content or ""
+
+    reply = getattr(response, "response", None)
+    if reply is not None and hasattr(reply, "content"):
+        return reply.content or ""
+
+    if hasattr(response, "response") and isinstance(response.response, str):
+        return response.response
+
+    return str(response)
+
+
+async def _collect_tool_activity(handler: Any) -> List[Dict[str, Any]]:
+    """Capture tool call activity emitted by the workflow handler."""
+
+    collected: List[Dict[str, Any]] = []
+    call_index: Dict[str, int] = {}
+
+    if not hasattr(handler, "stream_events"):
+        return collected
+
+    async for event in handler.stream_events():
+        if isinstance(event, ToolCall):
+            call_index[event.tool_id] = len(collected)
+            collected.append(
+                {
+                    "tool": event.tool_name,
+                    "input": event.tool_kwargs,
+                }
+            )
+        elif isinstance(event, ToolCallResult):
+            payload = {
+                "tool": event.tool_name,
+                "input": event.tool_kwargs,
+                "output": getattr(event.tool_output, "content", event.tool_output),
+            }
+            existing_idx = call_index.get(event.tool_id)
+            if existing_idx is not None:
+                collected[existing_idx].update({k: v for k, v in payload.items() if v is not None})
+            else:
+                collected.append({k: v for k, v in payload.items() if v is not None})
+
+    return collected
+
+
+class AgentSessionManager:
+    """Manage long-lived workflow agent chat sessions and tool orchestration."""
+
+    def __init__(self, agent_factory=create_default_agent):
+        self._agent_factory = agent_factory
+        self._sessions: Dict[str, AgentSession] = {}
+        self._lock = asyncio.Lock()
+
+    async def create_session(
+        self,
+        project_id: str,
+        *,
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> Dict[str, Any]:
+        agent = self._agent_factory()
+        session_id = str(uuid4())
+        session = AgentSession(
+            session_id=session_id,
+            project_id=project_id,
+            agent=agent,
+            metadata=metadata or {},
+        )
+        async with self._lock:
+            self._sessions[session_id] = session
+
+        return session.as_dict()
+
+    async def close_session(self, session_id: str) -> None:
+        async with self._lock:
+            self._sessions.pop(session_id, None)
+
+    async def list_sessions(self) -> List[Dict[str, Any]]:
+        async with self._lock:
+            return [session.as_dict() for session in self._sessions.values()]
+
+    async def get_session_state(self, session_id: str) -> Dict[str, Any]:
+        async with self._lock:
+            session = self._sessions.get(session_id)
+
+        if session is None:
+            raise KeyError(f"Session '{session_id}' not found")
+
+        return {
+            "session_id": session.session_id,
+            "project_id": session.project_id,
+            "metadata": session.metadata,
+            "chat_history": _serialize_chat_history(session.chat_history),
+            "tool_events": session.tool_events,
+            "task_trace": session.task_trace,
+        }
+
+    async def process_message(
+        self,
+        session_id: str,
+        message: str,
+        *,
+        auto_save_memories: bool = False,
+    ) -> Dict[str, Any]:
+        async with self._lock:
+            session = self._sessions.get(session_id)
+
+        if session is None:
+            raise KeyError(f"Session '{session_id}' not found")
+
+        session.chat_history.append(_to_chat_message(MessageRole.USER, message))
+
+        conversation_payload = _serialize_chat_history(session.chat_history)
+
+        call_kwargs = {
+            "chat_history": session.chat_history,
+            "metadata": {
+                "project_id": session.project_id,
+                "auto_save_memories": auto_save_memories,
+                "conversation": conversation_payload,
+            },
+        }
+
+        handler = session.agent.run(message, **call_kwargs)
+
+        tool_events = await _collect_tool_activity(handler)
+        response = await handler
+
+        reply_text = _extract_response_text(response)
+        if isinstance(response, AgentOutput):
+            session.chat_history.append(response.response)
+        else:
+            session.chat_history.append(_to_chat_message(MessageRole.ASSISTANT, reply_text))
+
+        session.tool_events.extend(tool_events)
+
+        task_record = {
+            "user_message": message,
+            "assistant_reply": reply_text,
+            "tools_used": tool_events,
+        }
+        session.task_trace.append(task_record)
+
+        return {
+            "session_id": session_id,
+            "reply": reply_text,
+            "tool_events": tool_events,
+            "task": task_record,
+            "chat_history": _serialize_chat_history(session.chat_history),
+        }
+
diff --git a/src/codebase_rag/services/agents/tools.py b/src/codebase_rag/services/agents/tools.py
new file mode 100644
index 0000000..fbd8a63
--- /dev/null
+++ b/src/codebase_rag/services/agents/tools.py
@@ -0,0 +1,146 @@
+"""Declarative tool definitions exposed to the workflow agent."""
+
+from __future__ import annotations
+
+import asyncio
+from typing import Any, Dict, Iterable, List
+
+from llama_index.core.tools import FunctionTool
+
+from codebase_rag.services.knowledge.neo4j_knowledge_service import (
+    neo4j_knowledge_service,
+)
+from codebase_rag.services.memory import memory_extractor, memory_store
+
+_knowledge_lock = asyncio.Lock()
+_memory_lock = asyncio.Lock()
+
+
+async def _ensure_knowledge_ready() -> None:
+    """Initialize the Neo4j knowledge service if required."""
+
+    if not neo4j_knowledge_service._initialized:  # type: ignore[attr-defined]
+        async with _knowledge_lock:
+            if not neo4j_knowledge_service._initialized:  # type: ignore[attr-defined]
+                await neo4j_knowledge_service.initialize()
+
+
+async def _ensure_memory_ready() -> None:
+    """Initialize the memory store when first accessed."""
+
+    if not memory_store._initialized:  # type: ignore[attr-defined]
+        async with _memory_lock:
+            if not memory_store._initialized:  # type: ignore[attr-defined]
+                await memory_store.initialize()
+
+
+async def _agent_query_knowledge(question: str, mode: str = "hybrid") -> Dict[str, Any]:
+    """Run a knowledge graph query through the Neo4j service."""
+
+    await _ensure_knowledge_ready()
+    return await neo4j_knowledge_service.query(question=question, mode=mode)
+
+
+async def _agent_similar_search(query: str, top_k: int = 5) -> Dict[str, Any]:
+    """Retrieve nodes similar to the provided query text."""
+
+    await _ensure_knowledge_ready()
+    return await neo4j_knowledge_service.search_similar_nodes(query=query, top_k=top_k)
+
+
+async def _agent_graph_summary() -> Dict[str, Any]:
+    """Expose a lightweight schema/statistics view of the knowledge graph."""
+
+    await _ensure_knowledge_ready()
+    schema = await neo4j_knowledge_service.get_graph_schema()
+    stats = await neo4j_knowledge_service.get_statistics()
+    return {"schema": schema, "statistics": stats}
+
+
+async def _agent_extract_memories(
+    project_id: str,
+    conversation: List[Dict[str, str]],
+    auto_save: bool = False,
+) -> Dict[str, Any]:
+    """Use the MemoryExtractor to analyse a conversation."""
+
+    await _ensure_memory_ready()
+    return await memory_extractor.extract_from_conversation(
+        project_id=project_id,
+        conversation=conversation,
+        auto_save=auto_save,
+    )
+
+
+async def _agent_save_memory(
+    project_id: str,
+    memory_type: str,
+    title: str,
+    content: str,
+    reason: str | None = None,
+    tags: Iterable[str] | None = None,
+    importance: float = 0.5,
+    metadata: Dict[str, Any] | None = None,
+) -> Dict[str, Any]:
+    """Persist a memory entry directly through the MemoryStore."""
+
+    await _ensure_memory_ready()
+    return await memory_store.add_memory(
+        project_id=project_id,
+        memory_type=memory_type,  # type: ignore[arg-type]
+        title=title,
+        content=content,
+        reason=reason,
+        tags=list(tags) if tags is not None else None,
+        importance=importance,
+        metadata=metadata,
+    )
+
+
+KNOWLEDGE_TOOLS = [
+    FunctionTool.from_defaults(
+        async_fn=_agent_query_knowledge,
+        name="query_knowledge_graph",
+        description=(
+            "Query the Neo4j knowledge graph using hybrid retrieval. Use this when "
+            "you need long-form answers backed by stored documents."
+        ),
+    ),
+    FunctionTool.from_defaults(
+        async_fn=_agent_similar_search,
+        name="search_similar_nodes",
+        description=(
+            "Retrieve top related nodes using semantic similarity in the knowledge graph."
+        ),
+    ),
+    FunctionTool.from_defaults(
+        async_fn=_agent_graph_summary,
+        name="describe_graph_state",
+        description=(
+            "Get schema and health information about the Neo4j knowledge graph "
+            "to support planning or diagnostics."
+        ),
+    ),
+]
+
+MEMORY_TOOLS = [
+    FunctionTool.from_defaults(
+        async_fn=_agent_extract_memories,
+        name="extract_conversation_memories",
+        description=(
+            "Analyse the current conversation and suggest project memories. "
+            "Set auto_save to true to persist high-confidence results automatically."
+        ),
+    ),
+    FunctionTool.from_defaults(
+        async_fn=_agent_save_memory,
+        name="save_project_memory",
+        description=(
+            "Persist an explicit memory entry for the current project into the "
+            "long-term Neo4j store."
+        ),
+    ),
+]
+
+AGENT_TOOLS = [*KNOWLEDGE_TOOLS, *MEMORY_TOOLS]
+
diff --git a/src/codebase_rag/services/pipeline/__init__.py b/src/codebase_rag/services/pipeline/__init__.py
index 3312ae4..6ba5914 100644
--- a/src/codebase_rag/services/pipeline/__init__.py
+++ b/src/codebase_rag/services/pipeline/__init__.py
@@ -1 +1,5 @@
-# Knowledge Pipeline module initialization
\ No newline at end of file
+"""Pipeline service package exports."""
+
+from codebase_rag.services.code import PackBuilder, pack_builder
+
+__all__ = ["PackBuilder", "pack_builder"]
\ No newline at end of file
diff --git a/src/codebase_rag/services/utils/metrics.py b/src/codebase_rag/services/utils/metrics.py
index 798cd04..184bd32 100644
--- a/src/codebase_rag/services/utils/metrics.py
+++ b/src/codebase_rag/services/utils/metrics.py
@@ -154,7 +154,7 @@
 class MetricsService:
-    """Service for managing Prometheus metrics"""
+    """Service for managing Prometheus metrics."""
 
     def __init__(self):
         self.registry = registry
@@ -289,8 +289,12 @@ async def update_neo4j_metrics(self, graph_service):
         self.update_neo4j_status(False)
 
 
+class MetricsCollector(MetricsService):
+    """Backward compatible alias for the previous metrics collector API."""
+
+
 # Create singleton instance
-metrics_service = MetricsService()
+metrics_service = MetricsCollector()
 
 
 def track_duration(operation: str, metric_type: str = "graph"):
diff --git a/tests/conftest.py b/tests/conftest.py
index ad97d98..d31db97 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,8 +7,15 @@
 from pathlib import Path
 from unittest.mock import AsyncMock, Mock
 
-# Add parent directory to path for imports
-sys.path.insert(0, str(Path(__file__).parent.parent))
+# Ensure the project root and `src/` directory are available for imports.
+# pytest executes from the repository root, but our package lives in `src/`.
+ROOT_DIR = Path(__file__).parent.parent
+SRC_DIR = ROOT_DIR / "src"
+
+for path in (ROOT_DIR, SRC_DIR):
+    path_str = str(path)
+    if path_str not in sys.path:
+        sys.path.insert(0, path_str)
 
 from fastapi.testclient import TestClient
 from src.codebase_rag.services.code import Neo4jGraphService
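
A few usage sketches follow; none of this code is part of the diff itself. First, create_default_agent raises ValueError until Settings.llm is configured, so a direct driver has to set the global LLM before building the agent. A minimal sketch, assuming a reachable local Ollama server; the model name is illustrative only:

    # Sketch: drive the agent factory directly. Assumes an Ollama server is
    # running; the model name is illustrative and not part of this changeset.
    import asyncio

    from llama_index.core import Settings
    from llama_index.llms.ollama import Ollama

    from codebase_rag.services.agents import create_default_agent


    async def main() -> None:
        Settings.llm = Ollama(model="llama3.1", request_timeout=120.0)

        agent = create_default_agent()
        handler = agent.run("Summarise what you know about this project.")
        result = await handler  # AgentOutput once the workflow completes
        print(result.response.content)


    asyncio.run(main())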
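The same flow is available in-process through the module-level agent_session_manager singleton that the routes delegate to. A sketch, again assuming Settings.llm has been configured as above:

    # Sketch: exercise AgentSessionManager without going through FastAPI.
    import asyncio

    from codebase_rag.services.agents import agent_session_manager


    async def demo() -> None:
        summary = await agent_session_manager.create_session(
            project_id="demo-project",
            metadata={"source": "local-script"},
        )
        session_id = summary["session_id"]

        result = await agent_session_manager.process_message(
            session_id=session_id,
            message="Which services initialise the Neo4j graph?",
            auto_save_memories=True,
        )
        print(result["reply"])
        print(result["tool_events"])

        state = await agent_session_manager.get_session_state(session_id)
        print(len(state["chat_history"]), "messages recorded")

        await agent_session_manager.close_session(session_id)


    asyncio.run(demo())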
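Over HTTP, the new router nests under the /api/v1 prefix applied in core/routes.py, so sessions live at /api/v1/agent/sessions. A client sketch, assuming the service runs at http://localhost:8000 and that httpx is installed (neither is established by this diff):

    # Sketch: hypothetical HTTP client for the new agent endpoints.
    import httpx

    BASE = "http://localhost:8000/api/v1/agent"

    with httpx.Client(base_url=BASE, timeout=120.0) as client:
        session = client.post("/sessions", json={"project_id": "demo-project"}).json()
        session_id = session["session_id"]

        answer = client.post(
            f"/sessions/{session_id}/messages",
            json={"message": "What does the ingestion pipeline do?"},
        ).json()
        print(answer["reply"])

        client.delete(f"/sessions/{session_id}")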
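Finally, the pattern tools.py relies on: FunctionTool.from_defaults accepts a coroutine via async_fn, and the workflow awaits it through acall. A standalone sketch with a toy coroutine (lookup_term is hypothetical, standing in for the Neo4j-backed tools above):

    # Sketch: wrap an async function as an agent tool, mirroring tools.py.
    import asyncio

    from llama_index.core.tools import FunctionTool


    async def lookup(term: str) -> str:
        """Toy coroutine standing in for a real Neo4j query."""
        await asyncio.sleep(0)  # placeholder for actual I/O
        return f"definition of {term}"


    tool = FunctionTool.from_defaults(
        async_fn=lookup,
        name="lookup_term",
        description="Look up a term in the knowledge base.",
    )


    async def demo() -> None:
        output = await tool.acall(term="workflow")
        print(output.content)  # -> definition of workflow


    asyncio.run(demo())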