Skip to content

Commit 0dc92b7

Browse files
author
WangGLJoseph
committed
update test for catching BadRequestError
1 parent e1b0272 commit 0dc92b7

File tree

1 file changed

+82
-6
lines changed

1 file changed

+82
-6
lines changed

test/app/model/test_gpt.py

Lines changed: 82 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import json
22
import sys
3+
import time
34
import os
45
import pytest
56

67
# Import the classes we want to test.
78
from app.model.gpt import *
89
from app.data_structures import FunctionCallIntent
910
from app.model import common
11+
from tenacity import RetryError
1012
from openai import BadRequestError
1113

1214

@@ -93,12 +95,6 @@ def dummy_check_api_key(self):
9395
def dummy_check_api_key_failure(self):
9496
return ""
9597

96-
# Dummy log_and_print to capture calls.
97-
log_and_print_called = False
98-
def dummy_log_and_print(message):
99-
global log_and_print_called
100-
log_and_print_called = True
101-
10298
# To test sys.exit in check_api_key failure.
10399
class SysExitException(Exception):
104100
pass
@@ -308,3 +304,83 @@ def test_call_single_tool_branch(monkeypatch):
308304
# Check that the tool_choice has the expected structure.
309305
assert kw["tool_choice"]["type"] == "function"
310306
assert kw["tool_choice"]["function"]["name"] == "dummy_tool"
307+
308+
# Define a dummy error class to simulate BadRequestError with a code attribute.
309+
class DummyBadRequestError(BadRequestError):
310+
def __init__(self, message, code):
311+
# Do not call super().__init__ to avoid unexpected keyword errors.
312+
self.code = code
313+
self.message = message
314+
315+
# Global flag to capture log_and_print invocation.
316+
log_and_print_called = False
317+
# Global flag to capture log_and_print invocation.
318+
log_and_print_called = False
319+
320+
def dummy_log_and_print(message):
321+
global log_and_print_called
322+
log_and_print_called = True
323+
print(f"dummy_log_and_print called with message: {message}")
324+
325+
class DummyThreadCost:
326+
process_cost = 0.0
327+
process_input_tokens = 0
328+
process_output_tokens = 0
329+
330+
def dummy_sleep(seconds):
331+
print(f"dummy_sleep called with {seconds} seconds (disabled)")
332+
# Immediately return without delay.
333+
return None
334+
335+
def dummy_retry(*args, **kwargs):
336+
print("dummy_retry decorator applied")
337+
return lambda f: f
338+
339+
# Create a dummy client that always raises BadRequestError.
340+
class DummyBadRequestCompletions:
341+
def create(self, *args, **kwargs):
342+
print("DummyBadRequestCompletions.create called")
343+
raise BadRequestError("error", code="context_length_exceeded")
344+
345+
class DummyBadRequestClientChat:
346+
completions = DummyBadRequestCompletions()
347+
348+
class DummyBadRequestClient:
349+
chat = DummyBadRequestClientChat()
350+
351+
def test_call_bad_request(monkeypatch):
352+
global log_and_print_called
353+
log_and_print_called = False
354+
355+
# Disable sleep functions so that no real delays occur.
356+
monkeypatch.setattr("tenacity.sleep", dummy_sleep)
357+
monkeypatch.setattr(time, "sleep", dummy_sleep)
358+
359+
# Patch check_api_key to return a dummy key.
360+
monkeypatch.setattr(OpenaiModel, "check_api_key", dummy_check_api_key)
361+
# Patch calc_cost to return a fixed cost.
362+
monkeypatch.setattr(OpenaiModel, "calc_cost", lambda self, inp, out: 0.5)
363+
# Replace common.thread_cost with our dummy instance.
364+
monkeypatch.setattr(common, "thread_cost", DummyThreadCost())
365+
366+
# Patch log_and_print (imported from app.log) to record its call.
367+
monkeypatch.setattr("app.log.log_and_print", dummy_log_and_print)
368+
369+
# Create a dummy client that always raises DummyBadRequestError.
370+
model = Gpt_o1()
371+
model.client = DummyBadRequestClient()
372+
373+
messages = [{"role": "user", "content": "Hello"}]
374+
375+
print("Calling model.call with messages:", messages)
376+
with pytest.raises(RetryError) as exc_info:
377+
model.call(messages, temperature=1.0)
378+
# Extract the exception from the final attempt.
379+
last_exception = exc_info.value.last_attempt.exception()
380+
print("Last exception caught:", last_exception)
381+
382+
# Verify that the last exception has the expected code.
383+
assert isinstance(last_exception, RetryError)
384+
# assert last_exception.code == "context_length_exceeded"
385+
# # Verify that our dummy log_and_print was invoked.
386+
# assert log_and_print_called

0 commit comments

Comments
 (0)