From ac35453ed19f8a98d7a6e6a76fda69603c60af78 Mon Sep 17 00:00:00 2001 From: Adhika Setya Pramudita Date: Wed, 1 Jan 2025 16:50:38 +0700 Subject: [PATCH 1/2] refactor(agent): implement own ReAct agent architecture - Introduced a new `agent.py` file containing the `ReActAgent` class, which utilizes the ReAct architecture to manage user interactions and responses. - Added the `AgentState` class to maintain conversation context, including messages, last step flag, current date, and user memories. - Updated `cli.py` to instantiate the `ReActAgent`, replacing the previous agent creation method, and streamlined the conversation handling process. - Enhanced the graph structure for managing conversation flow, allowing for conditional transitions between agent and tool nodes. --- src/mcp_client_cli/agent.py | 121 ++++++++++++++++++++++++++++++++++++ src/mcp_client_cli/cli.py | 48 +++++--------- 2 files changed, 138 insertions(+), 31 deletions(-) create mode 100644 src/mcp_client_cli/agent.py diff --git a/src/mcp_client_cli/agent.py b/src/mcp_client_cli/agent.py new file mode 100644 index 0000000..294de9b --- /dev/null +++ b/src/mcp_client_cli/agent.py @@ -0,0 +1,121 @@ +from typing import Annotated, TypedDict, Sequence +import json + +from langchain_core.messages import BaseMessage +from langgraph.graph.message import add_messages +from langgraph.managed import IsLastStep +from langchain_core.tools import Tool +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.messages import ToolMessage +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.runnables import RunnableConfig +from langgraph.graph import StateGraph, END +from langgraph.checkpoint.base import BaseCheckpointSaver +from langgraph.store.base import BaseStore + + +# The AgentState class is used to maintain the state of the agent during a conversation. +class AgentState(TypedDict): + # A list of messages exchanged in the conversation. + messages: Annotated[Sequence[BaseMessage], add_messages] + # A flag indicating whether the current step is the last step in the conversation. + is_last_step: IsLastStep + # The current date and time, used for context in the conversation. + today_datetime: str + # The user's memories. + memories: str = "no memories" + + +class ReActAgent: + """ + An agent that uses the ReAct architecture to respond to user queries. + + This mainly taken and modified from https://langchain-ai.github.io/langgraph/how-tos/react-agent-from-scratch + """ + + def __init__(self, model: BaseChatModel, tools: list[Tool], system_prompt: str, checkpointer: BaseCheckpointSaver, store: BaseStore): + self.prompt = ChatPromptTemplate.from_messages([ + ("system", system_prompt), + ("placeholder", "{messages}") + ]) + self.model = model.bind_tools(tools) + self.chain = self.prompt | self.model + self.tools = tools + self.tools_by_name = {tool.name: tool for tool in tools} + self.checkpointer = checkpointer + self.store = store + self.create_graph() + + async def astream(self, input: AgentState, thread_id: str): + async for chunk in self.graph.astream(input, stream_mode=["messages", "values"], config={"configurable": {"thread_id": thread_id}}): + yield chunk + + def create_graph(self): + # Define a new graph + workflow = StateGraph(AgentState) + + # Define the two nodes we will cycle between + workflow.add_node("agent", self.call_model) + workflow.add_node("tools", self.tool_node) + + # Set the entrypoint as `agent` + # This means that this node is the first one called + workflow.set_entry_point("agent") + + # We now add a conditional edge + workflow.add_conditional_edges( + # First, we define the start node. We use `agent`. + # This means these are the edges taken after the `agent` node is called. + "agent", + # Next, we pass in the function that will determine which node is called next. + self.should_continue, + # Finally we pass in a mapping. + # The keys are strings, and the values are other nodes. + # END is a special node marking that the graph should finish. + # What will happen is we will call `should_continue`, and then the output of that + # will be matched against the keys in this mapping. + # Based on which one it matches, that node will then be called. + { + # If `tools`, then we call the tool node. + "continue": "tools", + # Otherwise we finish. + "end": END, + }, + ) + + # We now add a normal edge from `tools` to `agent`. + # This means that after `tools` is called, `agent` node is called next. + workflow.add_edge("tools", "agent") + + # Now we can compile and visualize our graph + self.graph = workflow.compile(checkpointer=self.checkpointer, store=self.store) + + async def tool_node(self, state: AgentState): + outputs = [] + for tool_call in state["messages"][-1].tool_calls: + tool_result = await self.tools_by_name[tool_call["name"]].ainvoke(tool_call["args"]) + outputs.append( + ToolMessage( + content=json.dumps(tool_result), + name=tool_call["name"], + tool_call_id=tool_call["id"], + ) + ) + return {"messages": outputs} + + async def call_model( + self, + state: AgentState, + config: RunnableConfig, + ): + response = await self.chain.ainvoke(state, config) + return {"messages": [response]} + + + def should_continue(self, state: AgentState): + messages = state["messages"] + last_message = messages[-1] + if not last_message.tool_calls: + return "end" + else: + return "continue" diff --git a/src/mcp_client_cli/cli.py b/src/mcp_client_cli/cli.py index 5af6593..0400ea7 100755 --- a/src/mcp_client_cli/cli.py +++ b/src/mcp_client_cli/cli.py @@ -8,25 +8,22 @@ import argparse import asyncio import os -from typing import Annotated, TypedDict import uuid import sys import re import anyio -from langchain_core.messages import BaseMessage, HumanMessage -from langchain_core.prompts import ChatPromptTemplate +import base64 +import imghdr +import mimetypes + +from langchain_core.messages import HumanMessage from langchain_core.language_models.chat_models import BaseChatModel -from langgraph.prebuilt import create_react_agent -from langgraph.managed import IsLastStep -from langgraph.graph.message import add_messages from langchain.chat_models import init_chat_model from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver from rich.console import Console from rich.table import Table -import base64 -import imghdr -import mimetypes +from .agent import * from .input import * from .const import * from .output import * @@ -36,17 +33,6 @@ from .memory import * from .config import AppConfig -# The AgentState class is used to maintain the state of the agent during a conversation. -class AgentState(TypedDict): - # A list of messages exchanged in the conversation. - messages: Annotated[list[BaseMessage], add_messages] - # A flag indicating whether the current step is the last step in the conversation. - is_last_step: IsLastStep - # The current date and time, used for context in the conversation. - today_datetime: str - # The user's memories. - memories: str = "no memories" - async def run() -> None: """Run the LLM agent.""" args = setup_argument_parser() @@ -211,20 +197,18 @@ async def handle_conversation(args: argparse.Namespace, query: HumanMessage, extra_body=extra_body ) - prompt = ChatPromptTemplate.from_messages([ - ("system", app_config.system_prompt), - ("placeholder", "{messages}") - ]) - conversation_manager = ConversationManager(SQLITE_DB) async with AsyncSqliteSaver.from_conn_string(SQLITE_DB) as checkpointer: store = SqliteStore(SQLITE_DB) memories = await get_memories(store) formatted_memories = "\n".join(f"- {memory}" for memory in memories) - agent_executor = create_react_agent( - model, tools, state_schema=AgentState, - state_modifier=prompt, checkpointer=checkpointer, store=store + agent_executor = ReActAgent( + model=model, + system_prompt=app_config.system_prompt, + tools=tools, + checkpointer=checkpointer, + store=store ) thread_id = (await conversation_manager.get_last_id() if is_conversation_continuation @@ -241,10 +225,12 @@ async def handle_conversation(args: argparse.Namespace, query: HumanMessage, try: async for chunk in agent_executor.astream( input_messages, - stream_mode=["messages", "values"], - config={"configurable": {"thread_id": thread_id, "user_id": "myself"}, - "recursion_limit": 100} + thread_id=thread_id, + # stream_mode=["messages", "values"], + # config={"configurable": {"thread_id": thread_id, "user_id": "myself"}, + # "recursion_limit": 100} ): + # print(chunk) output.update(chunk) if not args.no_confirmations: if not output.confirm_tool_call(app_config.__dict__, chunk): From d754ad97e6e3d97986eed74aa9bb68cefd8f0c28 Mon Sep 17 00:00:00 2001 From: Adhika Setya Pramudita Date: Wed, 1 Jan 2025 17:17:35 +0700 Subject: [PATCH 2/2] refactor(agent, cli, memory): improve agent initialization and memory handling - Updated `ReActAgent` in `agent.py` to conditionally bind tools and initialize `tools_by_name` based on the presence of tools, enhancing flexibility. - Modified the `astream` method to include a recursion limit in the configuration, improving stability during streaming. - Adjusted memory saving functionality in `memory.py` to provide a default user ID of "myself" when not specified, ensuring consistent behavior. - Cleaned up commented-out code in `cli.py` related to conversation handling, streamlining the codebase. --- src/mcp_client_cli/agent.py | 14 +++++++++++--- src/mcp_client_cli/cli.py | 4 ---- src/mcp_client_cli/memory.py | 2 +- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/mcp_client_cli/agent.py b/src/mcp_client_cli/agent.py index 294de9b..d5e3999 100644 --- a/src/mcp_client_cli/agent.py +++ b/src/mcp_client_cli/agent.py @@ -38,16 +38,24 @@ def __init__(self, model: BaseChatModel, tools: list[Tool], system_prompt: str, ("system", system_prompt), ("placeholder", "{messages}") ]) - self.model = model.bind_tools(tools) + if tools: + self.model = model.bind_tools(tools) + self.tools_by_name = {tool.name: tool for tool in tools} + else: + self.model = model + self.tools_by_name = {} self.chain = self.prompt | self.model self.tools = tools - self.tools_by_name = {tool.name: tool for tool in tools} self.checkpointer = checkpointer self.store = store self.create_graph() async def astream(self, input: AgentState, thread_id: str): - async for chunk in self.graph.astream(input, stream_mode=["messages", "values"], config={"configurable": {"thread_id": thread_id}}): + async for chunk in self.graph.astream( + input, + stream_mode=["messages", "values"], + config={"configurable": {"thread_id": thread_id}, "recursion_limit": 100}, + ): yield chunk def create_graph(self): diff --git a/src/mcp_client_cli/cli.py b/src/mcp_client_cli/cli.py index 0400ea7..bff2266 100755 --- a/src/mcp_client_cli/cli.py +++ b/src/mcp_client_cli/cli.py @@ -226,11 +226,7 @@ async def handle_conversation(args: argparse.Namespace, query: HumanMessage, async for chunk in agent_executor.astream( input_messages, thread_id=thread_id, - # stream_mode=["messages", "values"], - # config={"configurable": {"thread_id": thread_id, "user_id": "myself"}, - # "recursion_limit": 100} ): - # print(chunk) output.update(chunk) if not args.no_confirmations: if not output.confirm_tool_call(app_config.__dict__, chunk): diff --git a/src/mcp_client_cli/memory.py b/src/mcp_client_cli/memory.py index 2e2c960..be8c3ef 100644 --- a/src/mcp_client_cli/memory.py +++ b/src/mcp_client_cli/memory.py @@ -42,7 +42,7 @@ @tool async def save_memory(memories: List[str], *, config: RunnableConfig, store: Annotated[BaseStore, InjectedStore()]) -> str: '''Save the given memory for the current user. Do not save duplicate memories.''' - user_id = config.get("configurable", {}).get("user_id") + user_id = config.get("configurable", {}).get("user_id", "myself") namespace = ("memories", user_id) for memory in memories: id = uuid.uuid4().hex