Skip to content
87 changes: 58 additions & 29 deletions willa/chatbot/graph_manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Manages the shared state and workflow for Willa chatbots."""
from typing import Any, Optional, Annotated, NotRequired
from typing import Optional, Annotated, NotRequired
from typing_extensions import TypedDict

from langchain_core.documents import Document
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import ChatMessage, HumanMessage, AIMessage
from langchain_core.vectorstores.base import VectorStore
Expand All @@ -19,10 +20,10 @@ class WillaChatbotState(TypedDict):
messages: Annotated[list[AnyMessage], add_messages]
filtered_messages: NotRequired[list[AnyMessage]]
summarized_messages: NotRequired[list[AnyMessage]]
docs_context: NotRequired[str]
messages_for_generation: NotRequired[list[AnyMessage]]
search_query: NotRequired[str]
tind_metadata: NotRequired[str]
context: NotRequired[dict[str, Any]]
documents: NotRequired[list[dict[str, str]]]


class GraphManager: # pylint: disable=too-few-public-methods
Expand Down Expand Up @@ -51,13 +52,15 @@ def _create_workflow(self) -> CompiledStateGraph:
workflow.add_node("summarize", summarization_node)
workflow.add_node("prepare_search", self._prepare_search_query)
workflow.add_node("retrieve_context", self._retrieve_context)
workflow.add_node("prepare_for_generation", self._prepare_for_generation)
workflow.add_node("generate_response", self._generate_response)

# Define edges
workflow.add_edge("filter_messages", "summarize")
workflow.add_edge("summarize", "prepare_search")
workflow.add_edge("prepare_search", "retrieve_context")
workflow.add_edge("retrieve_context", "generate_response")
workflow.add_edge("retrieve_context", "prepare_for_generation")
workflow.add_edge("prepare_for_generation", "generate_response")

workflow.set_entry_point("filter_messages")
workflow.set_finish_point("generate_response")
Expand All @@ -68,7 +71,10 @@ def _filter_messages(self, state: WillaChatbotState) -> dict[str, list[AnyMessag
"""Filter out TIND messages from the conversation history."""
messages = state["messages"]

filtered = [msg for msg in messages if 'tind' not in msg.response_metadata]
filtered: list[AnyMessage] = [
msg for msg in messages
if "tind" not in getattr(msg, "response_metadata", {}) and msg.type != "system"
]
return {"filtered_messages": filtered}

def _prepare_search_query(self, state: WillaChatbotState) -> dict[str, str]:
Expand All @@ -79,60 +85,83 @@ def _prepare_search_query(self, state: WillaChatbotState) -> dict[str, str]:

# summarization may include a system message as well as any human or ai messages
search_query = '\n'.join(str(msg.content) for msg in messages if hasattr(msg, 'content'))

# if summarization fails or some other issue, truncate to the last 2048 characters
if len(search_query) > 2048:
search_query = search_query[-2048:]

return {"search_query": search_query}

def _retrieve_context(self, state: WillaChatbotState) -> dict[str, str]:
def _format_retrieved_documents(self, matching_docs: "list[Document]") -> list[dict[str, str]]:
    """Format documents from vector store into a list of dictionaries.

    Each entry gets an ``id`` combining its 1-based rank with the TIND
    record id, the chunk text, and selected TIND metadata fields plus a
    link for downstream citation rendering.

    Args:
        matching_docs: Documents returned by the retriever; TIND fields
            live under ``doc.metadata['tind_metadata']`` as lists of
            strings (MARC-style multi-valued fields).

    Returns:
        One dict per input document, in retrieval order.
    """
    def first(values: list) -> str:
        # Metadata values are lists; a key that is missing OR present but
        # empty must both yield '' (the original `...get(k, [''])[0]`
        # raised IndexError on a present-but-empty list).
        return values[0] if values else ''

    formatted_documents: list[dict[str, str]] = []
    for rank, doc in enumerate(matching_docs, 1):
        tind_metadata = doc.metadata.get('tind_metadata', {})
        tind_id = first(tind_metadata.get('tind_id', []))
        formatted_documents.append({
            "id": f"{rank}_{tind_id}",
            "page_content": doc.page_content,
            "title": first(tind_metadata.get('title', [])),
            "project": first(tind_metadata.get('isPartOf', [])),
            "tind_link": format_tind_context.get_tind_url(tind_id),
        })
    return formatted_documents

def _retrieve_context(self, state: WillaChatbotState) -> dict[str, str | list[dict[str, str]]]:
    """Look up vector-store passages relevant to the prepared search query.

    Returns TIND citation metadata plus the formatted matching documents;
    both come back empty when there is no query or no store is configured.
    """
    query = state.get("search_query", "")
    store = self._vector_store

    # Without both a query and a configured store there is nothing to do.
    if not (query and store):
        return {"tind_metadata": "", "documents": []}

    # Pull the top-K matches for the query.
    retriever = store.as_retriever(search_kwargs={"k": int(CONFIG['K_VALUE'])})
    hits = retriever.invoke(query)

    documents = self._format_retrieved_documents(hits)
    citations = format_tind_context.get_tind_context(hits)
    return {"tind_metadata": citations, "documents": documents}

# This should be refactored probably. Very bulky
def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
"""Generate response using the model."""
def _prepare_for_generation(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
    """Assemble the message list the model will be invoked with.

    Uses the summarized conversation when available (falling back to the
    raw history), appends the rendered Langfuse prompt messages, and
    short-circuits with a canned AI reply when no model is configured or
    no human question is present.
    """
    history = state["messages"]
    conversation = state.get("summarized_messages", history)

    if not self._model:
        return {"messages": [AIMessage(content="Model not available.")]}

    question_present = any(isinstance(msg, HumanMessage) for msg in history)
    if not question_present:
        return {"messages": [AIMessage(content="I'm sorry, I didn't receive a question.")]}

    rendered = get_langfuse_prompt().invoke({})

    # The rendered prompt may expose a .messages sequence or be a single
    # message value; normalize either shape onto the conversation.
    if hasattr(rendered, "messages"):
        extra = rendered.messages
    else:
        extra = [rendered]

    return {"messages_for_generation": conversation + extra}

def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
"""Generate response using the model."""
tind_metadata = state.get("tind_metadata", "")
model = self._model
documents = state.get("documents", [])
messages = state.get("messages_for_generation") or state.get("messages", [])

if not model:
return {"messages": [AIMessage(content="Model not available.")]}

# Get response from model
response = model.invoke(all_messages)
response = model.invoke(
messages,
additional_model_request_fields={"documents": documents}
)

# Create clean response content
response_content = str(response.content) if hasattr(response, 'content') else str(response)

response_messages: list[AnyMessage] = [AIMessage(content=response_content),
ChatMessage(content=tind_metadata, role='TIND',
response_metadata={'tind': True})]
Expand Down