diff --git a/src/mcp_client_cli/agent.py b/src/mcp_client_cli/agent.py new file mode 100644 index 0000000..d5e3999 --- /dev/null +++ b/src/mcp_client_cli/agent.py @@ -0,0 +1,129 @@ +from typing import Annotated, TypedDict, Sequence +import json + +from langchain_core.messages import BaseMessage +from langgraph.graph.message import add_messages +from langgraph.managed import IsLastStep +from langchain_core.tools import Tool +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.messages import ToolMessage +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.runnables import RunnableConfig +from langgraph.graph import StateGraph, END +from langgraph.checkpoint.base import BaseCheckpointSaver +from langgraph.store.base import BaseStore + + +# The AgentState class is used to maintain the state of the agent during a conversation. +class AgentState(TypedDict): + # A list of messages exchanged in the conversation. + messages: Annotated[Sequence[BaseMessage], add_messages] + # A flag indicating whether the current step is the last step in the conversation. + is_last_step: IsLastStep + # The current date and time, used for context in the conversation. + today_datetime: str + # The user's memories. + memories: str = "no memories" + + +class ReActAgent: + """ + An agent that uses the ReAct architecture to respond to user queries. 
+
+    This is mainly taken and modified from https://langchain-ai.github.io/langgraph/how-tos/react-agent-from-scratch
+    """
+
+    def __init__(self, model: BaseChatModel, tools: list[Tool], system_prompt: str, checkpointer: BaseCheckpointSaver, store: BaseStore):
+        self.prompt = ChatPromptTemplate.from_messages([
+            ("system", system_prompt),
+            ("placeholder", "{messages}")
+        ])
+        if tools:
+            self.model = model.bind_tools(tools)
+            self.tools_by_name = {tool.name: tool for tool in tools}
+        else:
+            self.model = model
+            self.tools_by_name = {}
+        self.chain = self.prompt | self.model
+        self.tools = tools
+        self.checkpointer = checkpointer
+        self.store = store
+        self.create_graph()
+
+    async def astream(self, input: AgentState, thread_id: str):
+        async for chunk in self.graph.astream(
+            input,
+            stream_mode=["messages", "values"],
+            config={"configurable": {"thread_id": thread_id}, "recursion_limit": 100},
+        ):
+            yield chunk
+
+    def create_graph(self):
+        # Define a new graph
+        workflow = StateGraph(AgentState)
+
+        # Define the two nodes we will cycle between
+        workflow.add_node("agent", self.call_model)
+        workflow.add_node("tools", self.tool_node)
+
+        # Set the entrypoint as `agent`
+        # This means that this node is the first one called
+        workflow.set_entry_point("agent")
+
+        # We now add a conditional edge
+        workflow.add_conditional_edges(
+            # First, we define the start node. We use `agent`.
+            # This means these are the edges taken after the `agent` node is called.
+            "agent",
+            # Next, we pass in the function that will determine which node is called next.
+            self.should_continue,
+            # Finally we pass in a mapping.
+            # The keys are strings, and the values are other nodes.
+            # END is a special node marking that the graph should finish.
+            # What will happen is we will call `should_continue`, and then the output of that
+            # will be matched against the keys in this mapping.
+            # Based on which one it matches, that node will then be called. 
+ { + # If `tools`, then we call the tool node. + "continue": "tools", + # Otherwise we finish. + "end": END, + }, + ) + + # We now add a normal edge from `tools` to `agent`. + # This means that after `tools` is called, `agent` node is called next. + workflow.add_edge("tools", "agent") + + # Now we can compile and visualize our graph + self.graph = workflow.compile(checkpointer=self.checkpointer, store=self.store) + + async def tool_node(self, state: AgentState): + outputs = [] + for tool_call in state["messages"][-1].tool_calls: + tool_result = await self.tools_by_name[tool_call["name"]].ainvoke(tool_call["args"]) + outputs.append( + ToolMessage( + content=json.dumps(tool_result), + name=tool_call["name"], + tool_call_id=tool_call["id"], + ) + ) + return {"messages": outputs} + + async def call_model( + self, + state: AgentState, + config: RunnableConfig, + ): + response = await self.chain.ainvoke(state, config) + return {"messages": [response]} + + + def should_continue(self, state: AgentState): + messages = state["messages"] + last_message = messages[-1] + if not last_message.tool_calls: + return "end" + else: + return "continue" diff --git a/src/mcp_client_cli/cli.py b/src/mcp_client_cli/cli.py index 5af6593..bff2266 100755 --- a/src/mcp_client_cli/cli.py +++ b/src/mcp_client_cli/cli.py @@ -8,25 +8,22 @@ import argparse import asyncio import os -from typing import Annotated, TypedDict import uuid import sys import re import anyio -from langchain_core.messages import BaseMessage, HumanMessage -from langchain_core.prompts import ChatPromptTemplate +import base64 +import imghdr +import mimetypes + +from langchain_core.messages import HumanMessage from langchain_core.language_models.chat_models import BaseChatModel -from langgraph.prebuilt import create_react_agent -from langgraph.managed import IsLastStep -from langgraph.graph.message import add_messages from langchain.chat_models import init_chat_model from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver 
from rich.console import Console from rich.table import Table -import base64 -import imghdr -import mimetypes +from .agent import * from .input import * from .const import * from .output import * @@ -36,17 +33,6 @@ from .memory import * from .config import AppConfig -# The AgentState class is used to maintain the state of the agent during a conversation. -class AgentState(TypedDict): - # A list of messages exchanged in the conversation. - messages: Annotated[list[BaseMessage], add_messages] - # A flag indicating whether the current step is the last step in the conversation. - is_last_step: IsLastStep - # The current date and time, used for context in the conversation. - today_datetime: str - # The user's memories. - memories: str = "no memories" - async def run() -> None: """Run the LLM agent.""" args = setup_argument_parser() @@ -211,20 +197,18 @@ async def handle_conversation(args: argparse.Namespace, query: HumanMessage, extra_body=extra_body ) - prompt = ChatPromptTemplate.from_messages([ - ("system", app_config.system_prompt), - ("placeholder", "{messages}") - ]) - conversation_manager = ConversationManager(SQLITE_DB) async with AsyncSqliteSaver.from_conn_string(SQLITE_DB) as checkpointer: store = SqliteStore(SQLITE_DB) memories = await get_memories(store) formatted_memories = "\n".join(f"- {memory}" for memory in memories) - agent_executor = create_react_agent( - model, tools, state_schema=AgentState, - state_modifier=prompt, checkpointer=checkpointer, store=store + agent_executor = ReActAgent( + model=model, + system_prompt=app_config.system_prompt, + tools=tools, + checkpointer=checkpointer, + store=store ) thread_id = (await conversation_manager.get_last_id() if is_conversation_continuation @@ -241,9 +225,7 @@ async def handle_conversation(args: argparse.Namespace, query: HumanMessage, try: async for chunk in agent_executor.astream( input_messages, - stream_mode=["messages", "values"], - config={"configurable": {"thread_id": thread_id, "user_id": 
"myself"}, - "recursion_limit": 100} + thread_id=thread_id, ): output.update(chunk) if not args.no_confirmations: diff --git a/src/mcp_client_cli/memory.py b/src/mcp_client_cli/memory.py index 2e2c960..be8c3ef 100644 --- a/src/mcp_client_cli/memory.py +++ b/src/mcp_client_cli/memory.py @@ -42,7 +42,7 @@ @tool async def save_memory(memories: List[str], *, config: RunnableConfig, store: Annotated[BaseStore, InjectedStore()]) -> str: '''Save the given memory for the current user. Do not save duplicate memories.''' - user_id = config.get("configurable", {}).get("user_id") + user_id = config.get("configurable", {}).get("user_id", "myself") namespace = ("memories", user_id) for memory in memories: id = uuid.uuid4().hex