diff --git a/README.md b/README.md
index ca25df8b..d7391d6f 100644
--- a/README.md
+++ b/README.md
@@ -139,6 +139,10 @@ Currently we support gaming agents based on the following models:
 - Deepseek:
   - chat (V3)
   - reasoner (R1)
+- Ollama:
+  - deepseek-r1:8b
+  - llama3.1:8b
+  - gemma3:12b
 
 Set your API keys with:
 
@@ -231,9 +235,20 @@ You should be able to see the first level:
 
 2. Open another terminal screen, launch your agent in terminal with
 
-```
+```shell
 python games/sokoban/sokoban_agent.py
 ```
+
+Gemini example
+```shell
+python games/sokoban/sokoban_agent.py --api_provider gemini --model_name gemini-2.0-flash --modality text-only --starting_level 1
+```
+
+Ollama example
+```shell
+python games/sokoban/sokoban_agent.py --api_provider ollama --model_name gemma3:12b --modality text-only --starting_level 1
+```
+
 #### Other command options
 ```
 --api_provider: API provider to use.
diff --git a/games/sokoban/workers.py b/games/sokoban/workers.py
index 27755f92..a866dc31 100644
--- a/games/sokoban/workers.py
+++ b/games/sokoban/workers.py
@@ -4,7 +4,9 @@ import numpy as np
 
 
 from tools.utils import encode_image, log_output, get_annotate_img
-from tools.serving.api_providers import anthropic_completion, anthropic_text_completion, openai_completion, openai_text_reasoning_completion, gemini_completion, gemini_text_completion, deepseek_text_reasoning_completion
+from tools.serving.api_providers import anthropic_completion, anthropic_text_completion, openai_completion, \
+    openai_text_reasoning_completion, gemini_completion, gemini_text_completion, deepseek_text_reasoning_completion, \
+    ollama_text_completion
 import re
 import json
 
@@ -173,6 +175,8 @@ def sokoban_worker(system_prompt, api_provider, model_name,
         response = gemini_completion(system_prompt, model_name, base64_image, prompt)
     elif api_provider == "deepseek":
         response = deepseek_text_reasoning_completion(system_prompt, model_name, prompt)
+    elif api_provider == "ollama":
+        response = ollama_text_completion(system_prompt, model_name, prompt)
     else:
         raise NotImplementedError(f"API provider: {api_provider} is not supported.")
 
diff --git a/requirements.txt b/requirements.txt
index 06030b58..e38a424f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,5 +10,4 @@ pygame==2.6.1
 PyGetWindow==0.0.9
 google-generativeai==0.8.4
 google.genai==1.5.0
-
-
+ollama==0.4.7
\ No newline at end of file
diff --git a/tools/serving/api_providers.py b/tools/serving/api_providers.py
index 5664f464..f63079ef 100644
--- a/tools/serving/api_providers.py
+++ b/tools/serving/api_providers.py
@@ -3,7 +3,8 @@ from openai import OpenAI
 
 import anthropic
 import google.generativeai as genai
-from google.generativeai import types
+from ollama import chat
+from ollama import ChatResponse
 
 def anthropic_completion(system_prompt, model_name, base64_image, prompt, thinking=False):
     print(f"anthropic vision-text activated... thinking: {thinking}")
@@ -435,17 +436,9 @@ def openai_multiimage_completion(system_prompt, model_name, prompt, list_content
 def gemini_text_completion(system_prompt, model_name, prompt):
     genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
     model = genai.GenerativeModel(model_name=model_name)
-
     messages = [
         prompt,
     ]
-
-    try:
-        response = model.generate_content(
-            messages
-        )
-    except Exception as e:
-        print(f"error: {e}")
 
     try:
         response = model.generate_content(messages)
@@ -505,35 +498,6 @@ def anthropic_text_completion(system_prompt, model_name, prompt, thinking=False)
     return generated_str
 
 
-def gemini_text_completion(system_prompt, model_name, prompt):
-    genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
-    model = genai.GenerativeModel(model_name=model_name)
-
-    messages = [
-        prompt,
-    ]
-
-    try:
-        response = model.generate_content(
-            messages
-        )
-    except Exception as e:
-        print(f"error: {e}")
-
-    try:
-        response = model.generate_content(messages)
-
-        # Ensure response is valid and contains candidates
-        if not response or not hasattr(response, "candidates") or not response.candidates:
-            print("Warning: Empty or invalid response")
-            return ""
-
-        return response.text  # Access response.text safely
-
-    except Exception as e:
-        print(f"Error: {e}")
-        return ""
-
 def gemini_completion(system_prompt, model_name, base64_image, prompt):
     genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
     model = genai.GenerativeModel(model_name=model_name)
@@ -631,3 +595,23 @@ def deepseek_text_reasoning_completion(system_prompt, model_name, prompt):
     # generated_str = response.choices[0].message.content
     print(content)
     return content
+
+
+def ollama_text_completion(system_prompt, model_name, prompt):
+    try:
+        response: ChatResponse = chat(model=model_name, messages=[
+            {
+                'role': 'user',
+                'content': prompt,
+            },
+        ])
+
+        if not response or not hasattr(response, "message") or not response.message:
+            print("Warning: Empty or invalid response")
+            return ""
+
+        return response.message.content
+
+    except Exception as e:
+        print(f"Error: {e}")
+        return ""