
Commit e1b0272

Author: WangGLJoseph
Message: add test for gpt.py
1 parent: f78c8c8

File tree

1 file changed: +310 -0 lines


test/app/model/test_gpt.py

Lines changed: 310 additions & 0 deletions
@@ -0,0 +1,310 @@
import json
import sys
import os
import pytest

# Import the classes we want to test.
from app.model.gpt import *
from app.data_structures import FunctionCallIntent
from app.model import common
from openai import BadRequestError


# Dummy classes to simulate the OpenAI response.
class DummyUsage:
    prompt_tokens = 1
    completion_tokens = 2

class DummyMessage:
    def __init__(self, content="Test response", tool_calls=None):
        self.content = content
        self.tool_calls = tool_calls

class DummyChoice:
    def __init__(self):
        self.message = DummyMessage()

class DummyResponse:
    def __init__(self):
        self.usage = DummyUsage()
        self.choices = [DummyChoice()]

class DummyCompletions:
    last_kwargs = {}  # initialize as a class attribute

    def create(self, *args, **kwargs):
        DummyCompletions.last_kwargs = kwargs  # capture the kwargs passed in
        return DummyResponse()

# Create a dummy client that raises BadRequestError.
class DummyBadRequestCompletions:
    def create(self, *args, **kwargs):
        raise BadRequestError("error", code="context_length_exceeded")

class DummyBadRequestClientChat:
    completions = DummyBadRequestCompletions()

class DummyBadRequestClient:
    chat = DummyBadRequestClientChat()

# Dummy client chat now includes a completions attribute.
class DummyClientChat:
    completions = DummyCompletions()

# Dummy client with a chat attribute.
class DummyClient:
    chat = DummyClientChat()

# Dummy thread cost container to capture cost updates.
class DummyThreadCost:
    process_cost = 0.0
    process_input_tokens = 0
    process_output_tokens = 0

# Dummy tool call classes for testing extract_resp_func_calls.
class DummyOpenaiFunction:
    def __init__(self, name="dummy_function", arguments='{"arg": "value"}'):
        self.name = name
        self.arguments = arguments

class DummyToolCall:
    def __init__(self, name="dummy_function", arguments='{"arg": "value"}'):
        self.function = DummyOpenaiFunction(name, arguments)

class DummyToolCallEmpty:
    def __init__(self, name="dummy_function"):
        # Empty string branch.
        self.function = DummyOpenaiFunction(name, "")

class DummyToolCallInvalid:
    def __init__(self, name="dummy_function"):
        # Invalid JSON branch.
        self.function = DummyOpenaiFunction(name, "{invalid_json}")

# Dummy OpenAI client to use in testing.
class DummyOpenAI:
    def __init__(self, api_key):
        self.api_key = api_key

# Dummy check_api_key function.
def dummy_check_api_key(self):
    return "dummy-key"

def dummy_check_api_key_failure(self):
    return ""

# Dummy log_and_print to capture calls.
log_and_print_called = False

def dummy_log_and_print(message):
    global log_and_print_called
    log_and_print_called = True

# To test sys.exit in check_api_key failure.
class SysExitException(Exception):
    pass

def dummy_sys_exit(code):
    raise SysExitException(f"sys.exit called with {code}")


# --- Tests for setup ---

def test_setup_initializes_client(monkeypatch):
    # Patch check_api_key to return a dummy key.
    monkeypatch.setattr(Gpt_o1, "check_api_key", dummy_check_api_key)
    # Patch the OpenAI constructor to return a DummyOpenAI instance.
    monkeypatch.setattr("app.model.gpt.OpenAI", lambda api_key: DummyOpenAI(api_key))

    # Instantiate a model.
    model = Gpt_o1()
    # Ensure client is initially None.
    model.client = None

    # Call setup.
    model.setup()

    # Verify that client was set using our dummy OpenAI.
    assert model.client is not None
    assert isinstance(model.client, DummyOpenAI)
    assert model.client.api_key == "dummy-key"

def test_setup_already_initialized(monkeypatch):
    # Patch check_api_key (should not be called if client is already set).
    monkeypatch.setattr(Gpt_o1, "check_api_key", dummy_check_api_key)
    # Create a dummy client.
    dummy_client = DummyOpenAI("pre-set-key")

    # Instantiate a model and set the client.
    model = Gpt_o1()
    model.client = dummy_client

    # Call setup again.
    model.setup()

    # Verify that client remains unchanged.
    assert model.client is dummy_client


# --- Test edge cases for one model ---

def test_singleton_behavior(monkeypatch):
    # Ensure that __new__ returns the same instance when __init__ is skipped.
    monkeypatch.setattr(OpenaiModel, "check_api_key", dummy_check_api_key)
    # Create an instance and attach an attribute.
    model1 = Gpt_o1()
    model1.some_attr = "initial"
    # Create another instance.
    model2 = Gpt_o1()
    # They should be the same object.
    assert model1 is model2
    # __init__ should not reinitialize; the attribute remains.
    assert hasattr(model2, "some_attr")
    assert model2.some_attr == "initial"

def test_check_api_key_success(monkeypatch):
    monkeypatch.setattr(os, "getenv", lambda key: "dummy-key")
    model = Gpt_o1()
    key = model.check_api_key()
    assert key == "dummy-key"

def test_check_api_key_failure(monkeypatch, capsys):
    # Simulate missing OPENAI_KEY.
    monkeypatch.setattr(os, "getenv", lambda key: "")
    monkeypatch.setattr(sys, "exit", dummy_sys_exit)
    model = Gpt_o1()
    with pytest.raises(SysExitException):
        model.check_api_key()
    captured = capsys.readouterr().out
    assert "Please set the OPENAI_KEY env var" in captured


def test_extract_resp_content(monkeypatch):
    model = Gpt_o1()
    # Test when content exists.
    msg = DummyMessage(content="Hello")
    assert model.extract_resp_content(msg) == "Hello"
    # Test when content is None.
    msg_none = DummyMessage(content=None)
    assert model.extract_resp_content(msg_none) == ""

def test_extract_resp_func_calls(monkeypatch):
    model = Gpt_o1()
    # When tool_calls is None.
    msg = DummyMessage(tool_calls=None)
    assert model.extract_resp_func_calls(msg) == []
    # When arguments is an empty string.
    dummy_call_empty = DummyToolCallEmpty()
    msg_empty = DummyMessage(tool_calls=[dummy_call_empty])
    func_calls = model.extract_resp_func_calls(msg_empty)
    assert len(func_calls) == 1
    assert func_calls[0].func_name == "dummy_function"
    assert func_calls[0].arg_values == {}
    # When arguments is invalid JSON.
    dummy_call_invalid = DummyToolCallInvalid()
    msg_invalid = DummyMessage(tool_calls=[dummy_call_invalid])
    func_calls = model.extract_resp_func_calls(msg_invalid)
    assert len(func_calls) == 1
    assert func_calls[0].func_name == "dummy_function"
    assert func_calls[0].arg_values == {}
    # When arguments is valid.
    dummy_call = DummyToolCall()
    msg_valid = DummyMessage(tool_calls=[dummy_call])
    func_calls = model.extract_resp_func_calls(msg_valid)
    assert len(func_calls) == 1
    assert func_calls[0].func_name == "dummy_function"
    assert func_calls[0].arg_values == {"arg": "value"}


# --- Parametrized test over all models ---

@pytest.mark.parametrize("model_class, expected_name", [
    ("Gpt_o1mini", "o1-mini"),
    ("Gpt_o1", "o1-2024-12-17"),
    ("Gpt4o_20240806", "gpt-4o-2024-08-06"),
    ("Gpt4o_20240513", "gpt-4o-2024-05-13"),
    ("Gpt4_Turbo20240409", "gpt-4-turbo-2024-04-09"),
    ("Gpt4_0125Preview", "gpt-4-0125-preview"),
    ("Gpt4_1106Preview", "gpt-4-1106-preview"),
    ("Gpt35_Turbo0125", "gpt-3.5-turbo-0125"),
    ("Gpt35_Turbo1106", "gpt-3.5-turbo-1106"),
    ("Gpt35_Turbo16k_0613", "gpt-3.5-turbo-16k-0613"),
    ("Gpt35_Turbo0613", "gpt-3.5-turbo-0613"),
    ("Gpt4_0613", "gpt-4-0613"),
    ("Gpt4o_mini_20240718", "gpt-4o-mini-2024-07-18"),
])
def test_openai_model_call(monkeypatch, model_class, expected_name):
    # Dynamically import the model class from the gpt module.
    from app.model import gpt
    model_cls = getattr(gpt, model_class)

    # Patch necessary methods.
    monkeypatch.setattr(model_cls, "check_api_key", dummy_check_api_key)
    monkeypatch.setattr(model_cls, "calc_cost", lambda self, inp, out: 0.5)
    monkeypatch.setattr(common, "thread_cost", DummyThreadCost())

    # Instantiate the model and set the dummy client.
    model = model_cls()
    model.client = DummyClient()

    # Prepare a dummy messages list to simulate a user call.
    messages = [{"role": "user", "content": "Hello"}]

    # Call the model's "call" method.
    result = model.call(messages)
    content, raw_tool_calls, func_call_intents, cost, input_tokens, output_tokens = result

    # Verify the model's name and basic call flow.
    assert model.name == expected_name
    assert content == "Test response"
    assert raw_tool_calls is None  # Because DummyMessage.tool_calls was None.
    assert func_call_intents == []  # No tool calls provided.
    assert cost == 0.5
    assert input_tokens == 1
    assert output_tokens == 2

    # Check that our dummy thread cost was updated.
    assert common.thread_cost.process_cost == 0.5
    assert common.thread_cost.process_input_tokens == 1
    assert common.thread_cost.process_output_tokens == 2

    # Test extract_resp_content separately.
    dummy_msg = DummyMessage()
    extracted_content = model.extract_resp_content(dummy_msg)
    assert extracted_content == "Test response"

    # Test extract_resp_func_calls by simulating a tool call.
    dummy_msg.tool_calls = [DummyToolCall()]
    func_calls = model.extract_resp_func_calls(dummy_msg)
    assert len(func_calls) == 1
    # Check that the function call intent is correctly extracted.
    assert func_calls[0].func_name == "dummy_function"
    assert func_calls[0].arg_values == {"arg": "value"}

def test_call_single_tool_branch(monkeypatch):
    # Patch check_api_key to return a dummy key.
    monkeypatch.setattr(OpenaiModel, "check_api_key", dummy_check_api_key)
    # Patch calc_cost to return a fixed cost.
    monkeypatch.setattr(OpenaiModel, "calc_cost", lambda self, inp, out: 0.5)
    # Patch common.thread_cost with our dummy instance.
    monkeypatch.setattr(common, "thread_cost", DummyThreadCost())

    # Instantiate a model (using Gpt_o1 as an example) and set the dummy client.
    from app.model.gpt import Gpt_o1  # import within test to avoid conflicts
    model = Gpt_o1()
    model.client = DummyClient()

    # Prepare a dummy messages list to simulate a user call.
    messages = [{"role": "user", "content": "Hello"}]
    # Prepare tools with exactly one tool.
    tools = [{"function": {"name": "dummy_tool"}}]

    # Call the model's "call" method.
    result = model.call(messages, tools=tools, temperature=1.0)

    # Access the kwargs passed to DummyCompletions.create.
    kw = DummyCompletions.last_kwargs
    # Assert that tool_choice was added.
    assert "tool_choice" in kw
    # Check that the tool_choice has the expected structure.
    assert kw["tool_choice"]["type"] == "function"
    assert kw["tool_choice"]["function"]["name"] == "dummy_tool"
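
A minimal sketch of how the new test module might be invoked, assuming pytest is installed and the command runs from the repository root so that the app package resolves on the import path:

import pytest

# Run only the newly added test module; equivalent to
# `python -m pytest test/app/model/test_gpt.py -v` from the shell.
exit_code = pytest.main(["test/app/model/test_gpt.py", "-v"])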
