From 83cc9a3d194e8c02073bd3b2019c1f77b54aafbe Mon Sep 17 00:00:00 2001 From: Rony Byalsky Date: Mon, 27 Jan 2025 16:13:03 +0200 Subject: [PATCH] REST Agent - WIP2 --- src/agent/agentq.py | 10 +++++--- src/agent/rest_agent/prompt.py | 10 +++++++- src/agent/rest_agent/rest_agent.py | 7 ++---- src/agent/rest_agent/rest_tools.py | 39 +++++++++++++++++++++++++++--- tests/examples/rest_agent_tests.py | 13 +++++++--- 5 files changed, 64 insertions(+), 15 deletions(-) diff --git a/src/agent/agentq.py b/src/agent/agentq.py index 6eb0b5f..13017d8 100644 --- a/src/agent/agentq.py +++ b/src/agent/agentq.py @@ -48,11 +48,15 @@ class AgentQ(ABC): Abstract class for the AI agent. Extend the class to create a new agent with specific tools and prompts. """ - def __init__(self, model: Model, agent_prompt: str, tools=None): + def __init__(self, model: Model, agent_prompt: str, tools=None, **kwargs): if tools is None: tools = [] self.agent: CodeAgent = None self.model = model + + for key, value in kwargs.items(): + agent_prompt = agent_prompt.replace("{{" + key + "}}", value) + self.system_prompt = system_prompt + "\n" + agent_prompt self.tools = tools @@ -70,7 +74,7 @@ def init_agent(self): tools=self.tools, model=self.model, add_base_tools=False, - additional_authorized_imports=["pytest", "time"] + additional_authorized_imports=["pytest", "time", "json"] ) def do(self, task: str, force_regenerate: bool = False) -> any: @@ -84,7 +88,7 @@ def do(self, task: str, force_regenerate: bool = False) -> any: if not force_regenerate and task_code_exists(task): python_code_snippet = get_task_code(task) return self.__perform_code(python_code_snippet) - result = self.agent.run(task=self.system_prompt + "\n" + task, single_step=False) + result = self.agent.run(task=self.system_prompt + "\n\n" + task, single_step=False) code_identifier = save_code() if code_identifier: add_to_store(code_identifier, task) diff --git a/src/agent/rest_agent/prompt.py b/src/agent/rest_agent/prompt.py index 8be3a24..a9ab8cf 100644 --- a/src/agent/rest_agent/prompt.py +++ b/src/agent/rest_agent/prompt.py @@ -34,10 +34,18 @@ - If the API requires authentication (e.g., API keys, Bearer tokens), extract the authentication method from the Swagger JSON and ensure the appropriate headers or parameters are included in the request. ### 8. Returned Result: -- Always return a result that contains the response body as a text string. +- Always return a result that contains the response body as a text string. +- After receiving the response to the request, only return the response body as text string. Do not continue performing any additional actions. ### 9. What to do in case of an error in the response: - Even if the response status code is other than 200 - you only need to return the original response body as text. - Even if the response body contains the word 'error' - you only need to return the original response body as text. - Do not attempt to modify the request in order to get a different response. Just return the original response bpdy as text. + + +Use this Base URL: {{base_url}} + +Use this Swagger JSON: + +{{swagger_json}} """ \ No newline at end of file diff --git a/src/agent/rest_agent/rest_agent.py b/src/agent/rest_agent/rest_agent.py index b552914..46e35da 100644 --- a/src/agent/rest_agent/rest_agent.py +++ b/src/agent/rest_agent/rest_agent.py @@ -11,15 +11,12 @@ class RestAgent(AgentQ): - def __init__(self, base_url: str, swagger_json_file: str, model: Model): - super().__init__(model=model, agent_prompt=agent_prompt, tools=tools) - self.base_url = base_url - self.swagger_json = Path(swagger_json_file).read_text() + def __init__(self, model: Model, base_url: str, swagger_json: str): + super().__init__(model=model, agent_prompt=agent_prompt, tools=tools, base_url=base_url, swagger_json=swagger_json) def get_code_imports(self): return "from src.agent.rest_agent.rest_tools import *" def do(self, task: str, force_regenerate: bool = False) -> any: - task = f"{task}\n Use the following Base URL: {self.base_url}\n Use the following Swagger JSON:\n {self.swagger_json}" log.debug(f"Running task: {task}") return super().do(task=task, force_regenerate=force_regenerate) diff --git a/src/agent/rest_agent/rest_tools.py b/src/agent/rest_agent/rest_tools.py index fa04ad0..868ad10 100644 --- a/src/agent/rest_agent/rest_tools.py +++ b/src/agent/rest_agent/rest_tools.py @@ -1,7 +1,6 @@ -from typing import Tuple +import json import requests - from smolagents import tool @@ -37,4 +36,38 @@ def post_request(url: str, request_body: dict, headers: dict) -> str: return response.text -tools = [get_request, post_request] \ No newline at end of file +@tool +def find_json_key(json_str: str, key: str) -> any: + """ + Searches for a given key in a JSON string and returns its value. + Supports nested dictionaries and lists. + + Args: + json_str: A string representation of a JSON object. + key: The key to search for. + + :return: The value of the key if found, otherwise None. + """ + def search(data, key): + if isinstance(data, dict): + if key in data: + return data[key] + for value in data.values(): + result = search(value, key) + if result is not None: + return result + elif isinstance(data, list): + for item in data: + result = search(item, key) + if result is not None: + return result + return None + + try: + json_data = json.loads(json_str) + return search(json_data, key) + except json.JSONDecodeError: + raise ValueError("Invalid JSON string") + + +tools = [get_request, post_request, find_json_key] \ No newline at end of file diff --git a/tests/examples/rest_agent_tests.py b/tests/examples/rest_agent_tests.py index a89447a..254bc98 100644 --- a/tests/examples/rest_agent_tests.py +++ b/tests/examples/rest_agent_tests.py @@ -1,4 +1,6 @@ import os +from pathlib import Path + import pytest from dotenv import load_dotenv from smolagents import LiteLLMModel @@ -10,18 +12,23 @@ @pytest.fixture def q(): load_dotenv() - model = LiteLLMModel(model_id="gpt-4o-mini", api_key=os.getenv("OPEN_AI_API_KEY")) + model = LiteLLMModel(model_id="gpt-4o-mini", api_key=os.getenv("OPENAI_API_KEY")) base_url = "https://petstore.swagger.io/v2" - agent = RestAgent(base_url=base_url, swagger_json_file="../../swagger/swagger_petstore.json", model=model) + swagger_json = Path("../../swagger/swagger_petstore.json").read_text() + agent = RestAgent(model=model, base_url=base_url, swagger_json=swagger_json) agent.init_agent() return agent def test_get_pet_by_id(q: AgentQ): - result = q.do("Get pet with ID=2") + result = q.do("Get pet with ID=3") print(result) def test_add_new_pet(q: AgentQ): result = q.do("Add a new pet with the name 'Yoshi'") + print(result) + +def test_add_new_pet_and_get_it_by_id(q: AgentQ): + result = q.do("Add a new pet with a random pet name, then find in the response JSON the ID of the newly created pet and get the pet details from the API") print(result) \ No newline at end of file