NOTE(review): this patch was recovered from a copy whose line breaks had been
collapsed. Line breaks, Python indentation and blank context lines were
re-inserted so that every hunk matches its header counts; the positions of the
blank context lines are a best guess -- confirm against the target files
before applying. The post-image blob id on the api_providers.py index line
predates the edits made to ollama_completion below.

diff --git a/games/2048/2048_agent.py b/games/2048/2048_agent.py
index 9fb03668..9a47e883 100644
--- a/games/2048/2048_agent.py
+++ b/games/2048/2048_agent.py
@@ -4,7 +4,7 @@ import argparse
 import numpy as np
 
 from tools.utils import encode_image, log_output
-from tools.serving.api_providers import anthropic_completion, openai_completion, gemini_completion
+from tools.serving.api_providers import anthropic_completion, openai_completion, gemini_completion, ollama_completion
 import subprocess
 import multiprocessing
 import re
@@ -142,6 +142,8 @@ def get_best_move(system_prompt, api_provider, model_name, move_history):
         response = openai_completion(system_prompt, model_name, base64_image, move_prompt)
     elif api_provider == "gemini":
         response = gemini_completion(system_prompt, model_name, base64_image, move_prompt)
+    elif api_provider == "ollama":
+        response = ollama_completion(system_prompt, model_name, base64_image, move_prompt)
     else:
         raise NotImplementedError(f"API provider '{api_provider}' is not supported.")
 
diff --git a/requirements.txt b/requirements.txt
index 06030b58..5d714a27 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
 anthropic==0.49.0
 mss==10.0.0
 numpy==1.24.4
+ollama==0.4.7
 openai==1.65.4
 opencv_python==4.8.1.78
 opencv_python_headless==4.11.0.86
diff --git a/tools/serving/api_providers.py b/tools/serving/api_providers.py
index 3827a8e2..694ac578 100644
--- a/tools/serving/api_providers.py
+++ b/tools/serving/api_providers.py
@@ -4,6 +4,51 @@ import anthropic
 import google.generativeai as genai
 from google.genai import types
 
+from ollama import Client
+import time
+
+def ollama_completion(system_prompt, model_name, base64_image, prompt,
+                      host='http://127.0.0.1:11434', max_retries=5, retry_delay=1):
+    """Query a chat model served by Ollama and return the text of its reply.
+
+    Args:
+        system_prompt: Content of the ``system`` message.
+        model_name: Name of the Ollama model to run.
+        base64_image: Base64-encoded image attached to the user message;
+            the ``images`` field is omitted when this is empty or None.
+        prompt: Content of the ``user`` message.
+        host: URL of the Ollama server.
+        max_retries: Number of extra attempts made after the first failure.
+        retry_delay: Seconds to sleep between attempts.
+
+    Returns:
+        The ``content`` string of the message returned by the model.
+
+    Raises:
+        Exception: The error from the final attempt, re-raised unchanged.
+    """
+    client = Client(host=host)
+    user_message = {'role': 'user', 'content': prompt}
+    if base64_image:
+        user_message['images'] = [base64_image]
+    messages = [
+        {'role': 'system', 'content': system_prompt},
+        user_message,
+    ]
+
+    # One initial attempt plus max_retries retries; never fewer than one attempt.
+    attempts = max(max_retries, 0) + 1
+    for attempt in range(1, attempts + 1):
+        try:
+            response = client.chat(model=model_name, messages=messages)
+        except Exception as e:
+            if attempt == attempts:
+                print(f"[Ollama API] Error: {e}, aborting...")
+                raise
+            print(f"[Ollama API] Error: {e}, retrying...")
+            time.sleep(retry_delay)
+        else:
+            return response['message']['content']
 
 def openai_completion(system_prompt, model_name, base64_image, prompt, temperature=0):
     client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))