import json
import os
import sys

import httpx
import pytest
from openai import BadRequestError

from app.data_structures import FunctionCallIntent
from app.model import common

# Import the classes we want to test.
from app.model.gpt import Gpt_o1, OpenaiModel


# Dummy classes to simulate the OpenAI response.
class DummyUsage:
    prompt_tokens = 1
    completion_tokens = 2

class DummyMessage:
    def __init__(self, content="Test response", tool_calls=None):
        self.content = content
        self.tool_calls = tool_calls

class DummyChoice:
    def __init__(self):
        self.message = DummyMessage()

class DummyResponse:
    def __init__(self):
        self.usage = DummyUsage()
        self.choices = [DummyChoice()]

class DummyCompletions:
    last_kwargs = {}  # initialize as a class attribute

    def create(self, *args, **kwargs):
        DummyCompletions.last_kwargs = kwargs  # capture the kwargs passed in
        return DummyResponse()

# Create a dummy client that raises BadRequestError.
class DummyBadRequestCompletions:
    def create(self, *args, **kwargs):
        # The openai>=1.x BadRequestError constructor takes a message plus keyword-only
        # `response` and `body`; the error code is read from the body mapping.
        request = httpx.Request("POST", "https://api.openai.com/v1/chat/completions")
        response = httpx.Response(400, request=request)
        raise BadRequestError(
            "error",
            response=response,
            body={"code": "context_length_exceeded"},
        )

class DummyBadRequestClientChat:
    completions = DummyBadRequestCompletions()

class DummyBadRequestClient:
    chat = DummyBadRequestClientChat()

# Dummy client chat now includes a completions attribute.
class DummyClientChat:
    completions = DummyCompletions()

# Dummy client with a chat attribute.
class DummyClient:
    chat = DummyClientChat()

# Dummy thread cost container to capture cost updates.
class DummyThreadCost:
    process_cost = 0.0
    process_input_tokens = 0
    process_output_tokens = 0

# Dummy tool call classes for testing extract_resp_func_calls.
class DummyOpenaiFunction:
    def __init__(self, name="dummy_function", arguments='{"arg": "value"}'):
        self.name = name
        self.arguments = arguments

class DummyToolCall:
    def __init__(self, name="dummy_function", arguments='{"arg": "value"}'):
        self.function = DummyOpenaiFunction(name, arguments)

class DummyToolCallEmpty:
    def __init__(self, name="dummy_function"):
        # Empty string branch.
        self.function = DummyOpenaiFunction(name, "")

class DummyToolCallInvalid:
    def __init__(self, name="dummy_function"):
        # Invalid JSON branch.
        self.function = DummyOpenaiFunction(name, "{invalid_json}")

# Dummy OpenAI client to use in testing.
class DummyOpenAI:
    def __init__(self, api_key):
        self.api_key = api_key

# Dummy check_api_key function.
def dummy_check_api_key(self):
    return "dummy-key"

def dummy_check_api_key_failure(self):
    return ""

# Dummy log_and_print to capture calls.
log_and_print_called = False
def dummy_log_and_print(message):
    global log_and_print_called
    log_and_print_called = True

# To test sys.exit in check_api_key failure.
class SysExitException(Exception):
    pass

def dummy_sys_exit(code):
    raise SysExitException(f"sys.exit called with {code}")


# --- Tests for setup ---

def test_setup_initializes_client(monkeypatch):
    # Patch check_api_key to return a dummy key.
    monkeypatch.setattr(Gpt_o1, "check_api_key", dummy_check_api_key)
    # Patch the OpenAI constructor to return a DummyOpenAI instance.
    monkeypatch.setattr("app.model.gpt.OpenAI", lambda api_key: DummyOpenAI(api_key))

    # Instantiate a model.
    model = Gpt_o1()
    # Ensure client is initially None.
    model.client = None

    # Call setup.
    model.setup()

    # Verify that client was set using our dummy OpenAI.
    assert model.client is not None
    assert isinstance(model.client, DummyOpenAI)
    assert model.client.api_key == "dummy-key"

def test_setup_already_initialized(monkeypatch):
    # Patch check_api_key (should not be called if client is already set).
    monkeypatch.setattr(Gpt_o1, "check_api_key", dummy_check_api_key)
    # Create a dummy client.
    dummy_client = DummyOpenAI("pre-set-key")

    # Instantiate a model and set the client.
    model = Gpt_o1()
    model.client = dummy_client

    # Call setup again.
    model.setup()

    # Verify that client remains unchanged.
    assert model.client is dummy_client


# --- Test edge cases for one model ---

def test_singleton_behavior(monkeypatch):
    # Ensure that __new__ returns the same instance when __init__ is skipped.
    monkeypatch.setattr(OpenaiModel, "check_api_key", dummy_check_api_key)
    # Create an instance and attach an attribute.
    model1 = Gpt_o1()
    model1.some_attr = "initial"
    # Create another instance.
    model2 = Gpt_o1()
    # They should be the same object.
    assert model1 is model2
    # __init__ should not reinitialize; the attribute remains.
    assert hasattr(model2, "some_attr")
    assert model2.some_attr == "initial"

def test_check_api_key_success(monkeypatch):
    # Accept the optional default argument that os.getenv also supports.
    monkeypatch.setattr(os, "getenv", lambda key, default=None: "dummy-key")
    model = Gpt_o1()
    key = model.check_api_key()
    assert key == "dummy-key"

def test_check_api_key_failure(monkeypatch, capsys):
    # Simulate missing OPENAI_KEY.
    monkeypatch.setattr(os, "getenv", lambda key, default=None: "")
    monkeypatch.setattr(sys, "exit", dummy_sys_exit)
    model = Gpt_o1()
    with pytest.raises(SysExitException):
        model.check_api_key()
    captured = capsys.readouterr().out
    assert "Please set the OPENAI_KEY env var" in captured


def test_extract_resp_content(monkeypatch):
    model = Gpt_o1()
    # Test when content exists.
    msg = DummyMessage(content="Hello")
    assert model.extract_resp_content(msg) == "Hello"
    # Test when content is None.
    msg_none = DummyMessage(content=None)
    assert model.extract_resp_content(msg_none) == ""

def test_extract_resp_func_calls(monkeypatch):
    model = Gpt_o1()
    # When tool_calls is None.
    msg = DummyMessage(tool_calls=None)
    assert model.extract_resp_func_calls(msg) == []
    # When arguments is an empty string.
    dummy_call_empty = DummyToolCallEmpty()
    msg_empty = DummyMessage(tool_calls=[dummy_call_empty])
    func_calls = model.extract_resp_func_calls(msg_empty)
    assert len(func_calls) == 1
    assert func_calls[0].func_name == "dummy_function"
    assert func_calls[0].arg_values == {}
    # When arguments is invalid JSON.
    dummy_call_invalid = DummyToolCallInvalid()
    msg_invalid = DummyMessage(tool_calls=[dummy_call_invalid])
    func_calls = model.extract_resp_func_calls(msg_invalid)
    assert len(func_calls) == 1
    assert func_calls[0].func_name == "dummy_function"
    assert func_calls[0].arg_values == {}
    # When arguments is valid.
    dummy_call = DummyToolCall()
    msg_valid = DummyMessage(tool_calls=[dummy_call])
    func_calls = model.extract_resp_func_calls(msg_valid)
    assert len(func_calls) == 1
    assert func_calls[0].func_name == "dummy_function"
    assert func_calls[0].arg_values == {"arg": "value"}


# --- Parametrized Test Over All Models ---

@pytest.mark.parametrize("model_class, expected_name", [
    ("Gpt_o1mini", "o1-mini"),
    ("Gpt_o1", "o1-2024-12-17"),
    ("Gpt4o_20240806", "gpt-4o-2024-08-06"),
    ("Gpt4o_20240513", "gpt-4o-2024-05-13"),
    ("Gpt4_Turbo20240409", "gpt-4-turbo-2024-04-09"),
    ("Gpt4_0125Preview", "gpt-4-0125-preview"),
    ("Gpt4_1106Preview", "gpt-4-1106-preview"),
    ("Gpt35_Turbo0125", "gpt-3.5-turbo-0125"),
    ("Gpt35_Turbo1106", "gpt-3.5-turbo-1106"),
    ("Gpt35_Turbo16k_0613", "gpt-3.5-turbo-16k-0613"),
    ("Gpt35_Turbo0613", "gpt-3.5-turbo-0613"),
    ("Gpt4_0613", "gpt-4-0613"),
    ("Gpt4o_mini_20240718", "gpt-4o-mini-2024-07-18"),
])
def test_openai_model_call(monkeypatch, model_class, expected_name):
    # Dynamically import the model class from the gpt module.
    from app.model import gpt
    model_cls = getattr(gpt, model_class)

    # Patch necessary methods.
    monkeypatch.setattr(model_cls, "check_api_key", dummy_check_api_key)
    monkeypatch.setattr(model_cls, "calc_cost", lambda self, inp, out: 0.5)
    monkeypatch.setattr(common, "thread_cost", DummyThreadCost())

    # Instantiate the model and set the dummy client.
    model = model_cls()
    model.client = DummyClient()

    # Prepare a dummy messages list to simulate a user call.
    messages = [{"role": "user", "content": "Hello"}]

    # Call the model's "call" method.
    result = model.call(messages)
    content, raw_tool_calls, func_call_intents, cost, input_tokens, output_tokens = result

    # Verify the model's name and basic call flow.
    assert model.name == expected_name
    assert content == "Test response"
    assert raw_tool_calls is None  # Because DummyMessage.tool_calls was None.
    assert func_call_intents == []  # No tool calls provided.
    assert cost == 0.5
    assert input_tokens == 1
    assert output_tokens == 2

    # Check that our dummy thread cost was updated.
    assert common.thread_cost.process_cost == 0.5
    assert common.thread_cost.process_input_tokens == 1
    assert common.thread_cost.process_output_tokens == 2

    # Test extract_resp_content separately.
    dummy_msg = DummyMessage()
    extracted_content = model.extract_resp_content(dummy_msg)
    assert extracted_content == "Test response"

    # Test extract_resp_func_calls by simulating a tool call.
    dummy_msg.tool_calls = [DummyToolCall()]
    func_calls = model.extract_resp_func_calls(dummy_msg)
    assert len(func_calls) == 1
    # Check that the function call intent is correctly extracted.
    assert func_calls[0].func_name == "dummy_function"
    assert func_calls[0].arg_values == {"arg": "value"}

def test_call_single_tool_branch(monkeypatch):
    # Patch check_api_key to return a dummy key.
    monkeypatch.setattr(OpenaiModel, "check_api_key", dummy_check_api_key)
    # Patch calc_cost to return a fixed cost.
    monkeypatch.setattr(OpenaiModel, "calc_cost", lambda self, inp, out: 0.5)
    # Patch common.thread_cost with our dummy instance.
    monkeypatch.setattr(common, "thread_cost", DummyThreadCost())

    # Instantiate a model (using Gpt_o1 as an example) and set the dummy client.
    from app.model.gpt import Gpt_o1  # import within test to avoid conflicts
    model = Gpt_o1()
    model.client = DummyClient()

    # Prepare a dummy messages list to simulate a user call.
    messages = [{"role": "user", "content": "Hello"}]
    # Prepare tools with exactly one tool.
    tools = [{"function": {"name": "dummy_tool"}}]

    # Call the model's "call" method.
    result = model.call(messages, tools=tools, temperature=1.0)

    # Access the kwargs passed to DummyCompletions.create.
    kw = DummyCompletions.last_kwargs
    # Assert that tool_choice was added.
    assert "tool_choice" in kw
    # Check that the tool_choice has the expected structure.
    assert kw["tool_choice"]["type"] == "function"
    assert kw["tool_choice"]["function"]["name"] == "dummy_tool"
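
# --- Hedged sketch: context-length-exceeded branch (not part of the original suite) ---
# The BadRequest dummies above are defined but never exercised by the tests in this file.
# The test below is a minimal sketch of how they could be used. It assumes that
# OpenaiModel.call lets the BadRequestError raised by the client propagate to the caller
# (i.e. it is not swallowed or silently retried inside call()); verify against the real
# behavior of call() before relying on this test.
def test_call_bad_request_propagates(monkeypatch):
    monkeypatch.setattr(OpenaiModel, "check_api_key", dummy_check_api_key)
    monkeypatch.setattr(common, "thread_cost", DummyThreadCost())

    model = Gpt_o1()
    # The dummy client's create() raises BadRequestError with code "context_length_exceeded".
    model.client = DummyBadRequestClient()

    with pytest.raises(BadRequestError):
        model.call([{"role": "user", "content": "Hello"}])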