@@ -37,17 +37,6 @@ class DummyCompletions:
3737 def create (self , * args , ** kwargs ):
3838 DummyCompletions .last_kwargs = kwargs # capture the kwargs passed in
3939 return DummyResponse ()
40-
41- # Create a dummy client that raises BadRequestError.
42- class DummyBadRequestCompletions :
43- def create (self , * args , ** kwargs ):
44- raise BadRequestError ("error" , code = "context_length_exceeded" )
45-
46- class DummyBadRequestClientChat :
47- completions = DummyBadRequestCompletions ()
48-
49- class DummyBadRequestClient :
50- chat = DummyBadRequestClientChat ()
5140
5241# Dummy client chat now includes a completions attribute.
5342class DummyClientChat :
@@ -307,21 +296,10 @@ def test_call_single_tool_branch(monkeypatch):
307296
308297# Define a dummy error class to simulate BadRequestError with a code attribute.
309298class DummyBadRequestError (BadRequestError ):
310- def __init__ (self , message , code ):
299+ def __init__ (self , message ):
311300 # Do not call super().__init__ to avoid unexpected keyword errors.
312- self .code = code
313301 self .message = message
314302
315- # Global flag to capture log_and_print invocation.
316- log_and_print_called = False
317- # Global flag to capture log_and_print invocation.
318- log_and_print_called = False
319-
320- def dummy_log_and_print (message ):
321- global log_and_print_called
322- log_and_print_called = True
323- print (f"dummy_log_and_print called with message: { message } " )
324-
325303class DummyThreadCost :
326304 process_cost = 0.0
327305 process_input_tokens = 0
@@ -336,22 +314,52 @@ def dummy_retry(*args, **kwargs):
336314 print ("dummy_retry decorator applied" )
337315 return lambda f : f
338316
317+ # Define a dummy response object with the required attributes.
318+ class DummyResponseObject :
319+ request = "dummy_request"
320+ status_code = 400 # Provide a dummy status code.
321+ headers = {"content-type" : "application/json" }
322+
339323# Create a dummy client that always raises BadRequestError.
340324class DummyBadRequestCompletions :
341325 def create (self , * args , ** kwargs ):
342326 print ("DummyBadRequestCompletions.create called" )
343- raise BadRequestError ("error" , code = "context_length_exceeded" )
327+ # Instantiate a BadRequestError with a dummy response object.
328+ err = BadRequestError ("error" , response = DummyResponseObject (), body = {})
329+ err .code = "context_length_exceeded"
330+ raise err
344331
345332class DummyBadRequestClientChat :
346333 completions = DummyBadRequestCompletions ()
347334
348335class DummyBadRequestClient :
349336 chat = DummyBadRequestClientChat ()
350337
351- def test_call_bad_request (monkeypatch ):
352- global log_and_print_called
353- log_and_print_called = False
338+ # Create a dummy client that always raises BadRequestError, with a different 'code' message.
339+ class DummyBadRequestCompletionsOther :
340+ def create (self , * args , ** kwargs ):
341+ print ("DummyBadRequestCompletionsOther.create called" )
342+ # Instantiate a BadRequestError with a dummy response object.
343+ err = BadRequestError ("error" , response = DummyResponseObject (), body = {})
344+ err .code = "some_other_code"
345+ raise err
346+
347+ class DummyBadRequestClientChatOther :
348+ completions = DummyBadRequestCompletionsOther ()
349+
350+ class DummyBadRequestClientOther :
351+ chat = DummyBadRequestClientChatOther ()
352+
353+ def extract_exception_chain (exc ):
354+ """Utility to walk the __cause__ chain and return a list of exceptions."""
355+ chain = [exc ]
356+ while exc .__cause__ is not None :
357+ exc = exc .__cause__
358+ chain .append (exc )
359+ return chain
354360
361+ def test_call_bad_request (monkeypatch ):
362+ # Do not patch log_and_print so that the actual lines in the except block execute.
355363 # Disable sleep functions so that no real delays occur.
356364 monkeypatch .setattr ("tenacity.sleep" , dummy_sleep )
357365 monkeypatch .setattr (time , "sleep" , dummy_sleep )
@@ -363,10 +371,7 @@ def test_call_bad_request(monkeypatch):
363371 # Replace common.thread_cost with our dummy instance.
364372 monkeypatch .setattr (common , "thread_cost" , DummyThreadCost ())
365373
366- # Patch log_and_print (imported from app.log) to record its call.
367- monkeypatch .setattr ("app.log.log_and_print" , dummy_log_and_print )
368-
369- # Create a dummy client that always raises DummyBadRequestError.
374+ # Create a dummy client that always raises BadRequestError.
370375 model = Gpt_o1 ()
371376 model .client = DummyBadRequestClient ()
372377
@@ -375,12 +380,39 @@ def test_call_bad_request(monkeypatch):
375380 print ("Calling model.call with messages:" , messages )
376381 with pytest .raises (RetryError ) as exc_info :
377382 model .call (messages , temperature = 1.0 )
378- # Extract the exception from the final attempt.
379- last_exception = exc_info .value .last_attempt .exception ()
380- print ("Last exception caught:" , last_exception )
381383
382- # Verify that the last exception has the expected code.
383- assert isinstance (last_exception , RetryError )
384- # assert last_exception.code == "context_length_exceeded"
385- # # Verify that our dummy log_and_print was invoked.
386- # assert log_and_print_called
384+ # Extract the last exception from the RetryError chain.
385+ last_exc = exc_info .value .last_attempt .exception ()
386+ print ("Final exception from last attempt:" , last_exc )
387+
388+ # Walk the cause chain to see if BadRequestError is present.
389+ chain = extract_exception_chain (last_exc )
390+ for i , e in enumerate (chain ):
391+ print (f"Exception in chain [{ i } ]: type={ type (e )} , message={ getattr (e , 'message' , str (e ))} , code={ getattr (e , 'code' , None )} " )
392+
393+ # Assert that one exception in the chain is a BadRequestError with the expected code.
394+ found = any (isinstance (e , BadRequestError ) and getattr (e , "code" , None ) == "context_length_exceeded"
395+ for e in chain )
396+ assert found , "BadRequestError with expected code not found in exception chain."
397+
398+ # Other tests with different error codes.
399+ model .client = DummyBadRequestClientOther ()
400+ messages = [{"role" : "user" , "content" : "Hello" }]
401+
402+ print ("Calling model.call with messages:" , messages )
403+ with pytest .raises (RetryError ) as exc_info :
404+ model .call (messages , temperature = 1.0 )
405+
406+ # Extract the last exception from the RetryError chain.
407+ last_exc = exc_info .value .last_attempt .exception ()
408+ print ("Final exception from last attempt:" , last_exc )
409+
410+ # Walk the cause chain to see if BadRequestError is present.
411+ chain = extract_exception_chain (last_exc )
412+ for i , e in enumerate (chain ):
413+ print (f"Exception in chain [{ i } ]: type={ type (e )} , message={ getattr (e , 'message' , str (e ))} , code={ getattr (e , 'code' , None )} " )
414+
415+ # Assert that one exception in the chain is a BadRequestError with the expected code.
416+ found = any (isinstance (e , BadRequestError ) and getattr (e , "code" , None ) == "some_other_code"
417+ for e in chain )
418+ assert found , "BadRequestError with expected code not found in exception chain."
0 commit comments