Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/agent/agentq.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,15 @@ class AgentQ(ABC):
Abstract class for the AI agent. Extend the class to create a new agent with specific tools and prompts.

"""
def __init__(self, model: Model, agent_prompt: str, tools=None):
def __init__(self, model: Model, agent_prompt: str, tools=None, **kwargs):
if tools is None:
tools = []
self.agent: CodeAgent = None
self.model = model

for key, value in kwargs.items():
agent_prompt = agent_prompt.replace("{{" + key + "}}", value)

self.system_prompt = system_prompt + "\n" + agent_prompt
self.tools = tools

Expand All @@ -70,7 +74,7 @@ def init_agent(self):
tools=self.tools,
model=self.model,
add_base_tools=False,
additional_authorized_imports=["pytest", "time"]
additional_authorized_imports=["pytest", "time", "json"]
)

def do(self, task: str, force_regenerate: bool = False) -> any:
Expand All @@ -84,7 +88,7 @@ def do(self, task: str, force_regenerate: bool = False) -> any:
if not force_regenerate and task_code_exists(task):
python_code_snippet = get_task_code(task)
return self.__perform_code(python_code_snippet)
result = self.agent.run(task=self.system_prompt + "\n" + task, single_step=False)
result = self.agent.run(task=self.system_prompt + "\n\n" + task, single_step=False)
code_identifier = save_code()
if code_identifier:
add_to_store(code_identifier, task)
Expand Down
10 changes: 9 additions & 1 deletion src/agent/rest_agent/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,18 @@
- If the API requires authentication (e.g., API keys, Bearer tokens), extract the authentication method from the Swagger JSON and ensure the appropriate headers or parameters are included in the request.

### 8. Returned Result:
- Always return a result that contains the response body as a text string.
- Always return a result that contains the response body as a text string.
- After receiving the response to the request, only return the response body as text string. Do not continue performing any additional actions.

### 9. What to do in case of an error in the response:
- Even if the response status code is other than 200 - you only need to return the original response body as text.
- Even if the response body contains the word 'error' - you only need to return the original response body as text.
- Do not attempt to modify the request in order to get a different response. Just return the original response bpdy as text.


Use this Base URL: {{base_url}}

Use this Swagger JSON:

{{swagger_json}}
"""
7 changes: 2 additions & 5 deletions src/agent/rest_agent/rest_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,12 @@

class RestAgent(AgentQ):

def __init__(self, base_url: str, swagger_json_file: str, model: Model):
super().__init__(model=model, agent_prompt=agent_prompt, tools=tools)
self.base_url = base_url
self.swagger_json = Path(swagger_json_file).read_text()
def __init__(self, model: Model, base_url: str, swagger_json: str):
super().__init__(model=model, agent_prompt=agent_prompt, tools=tools, base_url=base_url, swagger_json=swagger_json)

def get_code_imports(self):
return "from src.agent.rest_agent.rest_tools import *"

def do(self, task: str, force_regenerate: bool = False) -> any:
task = f"{task}\n Use the following Base URL: {self.base_url}\n Use the following Swagger JSON:\n {self.swagger_json}"
log.debug(f"Running task: {task}")
return super().do(task=task, force_regenerate=force_regenerate)
39 changes: 36 additions & 3 deletions src/agent/rest_agent/rest_tools.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Tuple
import json

import requests

from smolagents import tool


Expand Down Expand Up @@ -37,4 +36,38 @@ def post_request(url: str, request_body: dict, headers: dict) -> str:
return response.text


tools = [get_request, post_request]
@tool
def find_json_key(json_str: str, key: str) -> any:
"""
Searches for a given key in a JSON string and returns its value.
Supports nested dictionaries and lists.

Args:
json_str: A string representation of a JSON object.
key: The key to search for.

:return: The value of the key if found, otherwise None.
"""
def search(data, key):
if isinstance(data, dict):
if key in data:
return data[key]
for value in data.values():
result = search(value, key)
if result is not None:
return result
elif isinstance(data, list):
for item in data:
result = search(item, key)
if result is not None:
return result
return None

try:
json_data = json.loads(json_str)
return search(json_data, key)
except json.JSONDecodeError:
raise ValueError("Invalid JSON string")


tools = [get_request, post_request, find_json_key]
13 changes: 10 additions & 3 deletions tests/examples/rest_agent_tests.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
from pathlib import Path

import pytest
from dotenv import load_dotenv
from smolagents import LiteLLMModel
Expand All @@ -10,18 +12,23 @@
@pytest.fixture
def q():
load_dotenv()
model = LiteLLMModel(model_id="gpt-4o-mini", api_key=os.getenv("OPEN_AI_API_KEY"))
model = LiteLLMModel(model_id="gpt-4o-mini", api_key=os.getenv("OPENAI_API_KEY"))
base_url = "https://petstore.swagger.io/v2"
agent = RestAgent(base_url=base_url, swagger_json_file="../../swagger/swagger_petstore.json", model=model)
swagger_json = Path("../../swagger/swagger_petstore.json").read_text()
agent = RestAgent(model=model, base_url=base_url, swagger_json=swagger_json)
agent.init_agent()
return agent


def test_get_pet_by_id(q: AgentQ):
result = q.do("Get pet with ID=2")
result = q.do("Get pet with ID=3")
print(result)


def test_add_new_pet(q: AgentQ):
result = q.do("Add a new pet with the name 'Yoshi'")
print(result)

def test_add_new_pet_and_get_it_by_id(q: AgentQ):
result = q.do("Add a new pet with a random pet name, then find in the response JSON the ID of the newly created pet and get the pet details from the API")
print(result)