From d616f6737fcb21003af8a989450ca324d31437cd Mon Sep 17 00:00:00 2001 From: Stephen Belanger Date: Tue, 13 Jan 2026 07:15:22 +0800 Subject: [PATCH] Add configurable default model support This change allows users to configure which model to use as the default for all evaluations, replacing the hardcoded gpt-4o default. Changes: - Add `defaultModel` parameter to `init()` in both JS and Python - Add `getDefaultModel()` function to retrieve configured default model - Update LLMClassifier and RAGAS scorers to use configurable default model - Update documentation with examples for different use cases This enables: - Using different OpenAI models (gpt-4-turbo, o1, gpt-3.5-turbo, etc.) - Using non-OpenAI models via Braintrust proxy (Claude, Gemini, Llama, etc.) - Configuring once and having all evaluators use the preferred model Example usage: ```javascript init({ client: new OpenAI({ apiKey: process.env.BRAINTRUST_API_KEY, baseURL: "https://api.braintrust.dev/v1/proxy", }), defaultModel: "claude-3-5-sonnet-20241022", }); ``` Fixes #136 Co-Authored-By: Claude Sonnet 4.5 --- js/index.ts | 3 +- js/llm.test.ts | 83 +++++++++++++++++++++++++++++++++++++++- js/llm.ts | 15 +++++++- js/oai.test.ts | 31 ++++++++++++++- js/oai.ts | 51 +++++++++++++++++++++++- js/ragas.ts | 5 ++- py/autoevals/__init__.py | 17 +++++--- py/autoevals/llm.py | 8 +++- py/autoevals/oai.py | 36 ++++++++++++++++- py/autoevals/ragas.py | 55 ++++++++++++++++++-------- py/autoevals/test_llm.py | 79 +++++++++++++++++++++++++++++++++++++- py/autoevals/test_oai.py | 43 +++++++++++++++++++++ 12 files changed, 391 insertions(+), 35 deletions(-) diff --git a/js/index.ts b/js/index.ts index 49e7863..ced08cc 100644 --- a/js/index.ts +++ b/js/index.ts @@ -29,7 +29,8 @@ export type { Score, ScorerArgs, Scorer } from "./score"; export * from "./llm"; -export { init } from "./oai"; +export { init, getDefaultModel } from "./oai"; +export type { InitOptions } from "./oai"; export * from "./string"; export * from "./list"; export * from "./moderation"; diff --git a/js/llm.test.ts b/js/llm.test.ts index fcdb30f..6f7b6bf 100644 --- a/js/llm.test.ts +++ b/js/llm.test.ts @@ -14,7 +14,7 @@ import { openaiClassifierShouldEvaluateTitles, openaiClassifierShouldEvaluateTitlesWithCoT, } from "./llm.fixtures"; -import { init } from "./oai"; +import { init, getDefaultModel } from "./oai"; export const server = setupServer(); @@ -340,4 +340,85 @@ Issue Description: {{page_content}} expect(capturedRequestBody.max_tokens).toBe(256); expect(capturedRequestBody.temperature).toBe(0.5); }); + + test("LLMClassifierFromTemplate uses configured default model", async () => { + let capturedModel: string | undefined; + + server.use( + http.post( + "https://api.openai.com/v1/chat/completions", + async ({ request }) => { + const body = (await request.json()) as any; + capturedModel = body.model; + + return HttpResponse.json({ + id: "chatcmpl-test", + object: "chat.completion", + created: Date.now(), + model: body.model, + choices: [ + { + index: 0, + message: { + role: "assistant", + content: null, + tool_calls: [ + { + id: "call_test", + type: "function", + function: { + name: "select_choice", + arguments: JSON.stringify({ choice: "1" }), + }, + }, + ], + }, + finish_reason: "tool_calls", + }, + ], + usage: { + prompt_tokens: 10, + completion_tokens: 20, + total_tokens: 30, + }, + }); + }, + ), + ); + + const client = new OpenAI({ + apiKey: "test-api-key", + baseURL: "https://api.openai.com/v1", + }); + + // Test with configured default model + init({ client, defaultModel: "claude-3-5-sonnet-20241022" }); + expect(getDefaultModel()).toBe("claude-3-5-sonnet-20241022"); + + const classifier = LLMClassifierFromTemplate({ + name: "test", + promptTemplate: "Test prompt: {{output}}", + choiceScores: { "1": 1, "2": 0 }, + }); + + await classifier({ + output: "test output", + expected: "test expected", + }); + + expect(capturedModel).toBe("claude-3-5-sonnet-20241022"); + + // Test that explicit model overrides default + capturedModel = undefined; + await classifier({ + output: "test output", + expected: "test expected", + model: "gpt-4-turbo", + }); + + expect(capturedModel).toBe("gpt-4-turbo"); + + // Reset for other tests + init({ client }); + }); }); diff --git a/js/llm.ts b/js/llm.ts index 066e00f..7a63480 100644 --- a/js/llm.ts +++ b/js/llm.ts @@ -1,5 +1,10 @@ import { Score, Scorer, ScorerArgs } from "./score"; -import { ChatCache, OpenAIAuth, cachedChatCompletion } from "./oai"; +import { + ChatCache, + OpenAIAuth, + cachedChatCompletion, + getDefaultModel, +} from "./oai"; import { ModelGradedSpec, templates } from "./templates"; import { ChatCompletionMessage, @@ -20,6 +25,10 @@ export type LLMArgs = { temperature?: number; } & OpenAIAuth; +/** + * The default model to use for LLM-based evaluations. + * @deprecated Use `init({ defaultModel: "..." })` to configure the default model instead. + */ export const DEFAULT_MODEL = "gpt-4o"; const PLAIN_RESPONSE_SCHEMA = { @@ -203,7 +212,7 @@ export function LLMClassifierFromTemplate({ name, promptTemplate, choiceScores, - model = DEFAULT_MODEL, + model: modelArg, useCoT: useCoTArg, temperature, maxTokens: maxTokensArg, @@ -221,6 +230,8 @@ export function LLMClassifierFromTemplate({ runtimeArgs: ScorerArgs>, ) => { const useCoT = runtimeArgs.useCoT ?? useCoTArg ?? true; + // Use runtime model > template model > configured default model + const model = runtimeArgs.model ?? modelArg ?? getDefaultModel(); const prompt = promptTemplate + "\n" + (useCoT ? COT_SUFFIX : NO_COT_SUFFIX); diff --git a/js/oai.test.ts b/js/oai.test.ts index 1c0808e..abf0d59 100644 --- a/js/oai.test.ts +++ b/js/oai.test.ts @@ -10,7 +10,7 @@ import { test, vi, } from "vitest"; -import { buildOpenAIClient, init } from "./oai"; +import { buildOpenAIClient, init, getDefaultModel } from "./oai"; import { setupServer } from "msw/node"; @@ -37,6 +37,9 @@ afterEach(() => { process.env.OPENAI_API_KEY = OPENAI_API_KEY; process.env.OPENAI_BASE_URL = OPENAI_BASE_URL; + + // Reset init state + init({ client: undefined, defaultModel: undefined }); }); afterAll(() => { @@ -257,6 +260,32 @@ describe("OAI", () => { expect(Object.is(builtClient, otherClient)).toBe(true); }); + + test("getDefaultModel returns gpt-4o by default", () => { + expect(getDefaultModel()).toBe("gpt-4o"); + }); + + test("init sets default model", () => { + init({ defaultModel: "claude-3-5-sonnet-20241022" }); + expect(getDefaultModel()).toBe("claude-3-5-sonnet-20241022"); + }); + + test("init can reset default model", () => { + init({ defaultModel: "claude-3-5-sonnet-20241022" }); + expect(getDefaultModel()).toBe("claude-3-5-sonnet-20241022"); + + init({ defaultModel: undefined }); + expect(getDefaultModel()).toBe("gpt-4o"); + }); + + test("init can set both client and default model", () => { + const client = new OpenAI({ apiKey: "test-api-key" }); + init({ client, defaultModel: "gpt-4-turbo" }); + + const builtClient = buildOpenAIClient({}); + expect(Object.is(builtClient, client)).toBe(true); + expect(getDefaultModel()).toBe("gpt-4-turbo"); + }); }); const withMockWrapper = async ( diff --git a/js/oai.ts b/js/oai.ts index cfe5b54..bc0a762 100644 --- a/js/oai.ts +++ b/js/oai.ts @@ -149,10 +149,59 @@ declare global { /* eslint-disable no-var */ var __inherited_braintrust_wrap_openai: ((openai: any) => any) | undefined; var __client: OpenAI | undefined; + var __defaultModel: string | undefined; } -export const init = ({ client }: { client?: OpenAI } = {}) => { +export interface InitOptions { + /** + * An OpenAI-compatible client to use for all evaluations. + * This can be an OpenAI client, or any client that implements the OpenAI API + * (e.g., configured to use the Braintrust proxy with Anthropic, Gemini, etc.) + */ + client?: OpenAI; + /** + * The default model to use for evaluations when not specified per-call. + * Defaults to "gpt-4o" if not set. + * + * When using non-OpenAI providers via the Braintrust proxy, set this to + * the appropriate model string (e.g., "claude-3-5-sonnet-20241022"). + */ + defaultModel?: string; +} + +/** + * Initialize autoevals with a custom client and/or default model. + * + * @example + * // Using with OpenAI (default) + * import { init } from "autoevals"; + * import { OpenAI } from "openai"; + * + * init({ client: new OpenAI() }); + * + * @example + * // Using with Anthropic via Braintrust proxy + * import { init } from "autoevals"; + * import { OpenAI } from "openai"; + * + * init({ + * client: new OpenAI({ + * apiKey: process.env.BRAINTRUST_API_KEY, + * baseURL: "https://api.braintrust.dev/v1/proxy", + * }), + * defaultModel: "claude-3-5-sonnet-20241022", + * }); + */ +export const init = ({ client, defaultModel }: InitOptions = {}) => { globalThis.__client = client; + globalThis.__defaultModel = defaultModel; +}; + +/** + * Get the configured default model, or "gpt-4o" if not set. + */ +export const getDefaultModel = (): string => { + return globalThis.__defaultModel ?? "gpt-4o"; }; export async function cachedChatCompletion( diff --git a/js/ragas.ts b/js/ragas.ts index d5a5285..307bc59 100644 --- a/js/ragas.ts +++ b/js/ragas.ts @@ -2,7 +2,8 @@ import mustache from "mustache"; import { Scorer, ScorerArgs } from "./score"; -import { DEFAULT_MODEL, LLMArgs } from "./llm"; +import { LLMArgs } from "./llm"; +import { getDefaultModel } from "./oai"; import { buildOpenAIClient, extractOpenAIArgs } from "./oai"; import OpenAI from "openai"; import { ListContains } from "./list"; @@ -869,7 +870,7 @@ function parseArgs(args: ScorerArgs): { OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming, "messages" > = { - model: args.model ?? DEFAULT_MODEL, + model: args.model ?? getDefaultModel(), temperature: args.temperature ?? 0, }; if (args.maxTokens) { diff --git a/py/autoevals/__init__.py b/py/autoevals/__init__.py index 6644215..4d19a57 100644 --- a/py/autoevals/__init__.py +++ b/py/autoevals/__init__.py @@ -46,21 +46,26 @@ **Multi-provider support via the Braintrust AI Proxy**: Autoevals supports multiple LLM providers (Anthropic, Azure, etc.) through the Braintrust AI Proxy. -Configure your client to use the proxy: +Configure your client to use the proxy and set the default model: ```python import os from openai import AsyncOpenAI +from autoevals import init from autoevals.llm import Factuality -# Configure client to use Braintrust AI Proxy +# Configure client to use Braintrust AI Proxy with Claude client = AsyncOpenAI( - base_url="https://api.braintrustproxy.com/v1", + base_url="https://api.braintrust.dev/v1/proxy", api_key=os.getenv("BRAINTRUST_API_KEY"), ) -# Use with any evaluator -evaluator = Factuality(client=client) +# Initialize with the client and default model +init(client=client, default_model="claude-3-5-sonnet-20241022") + +# All evaluators will now use Claude by default +evaluator = Factuality() +result = evaluator.eval(input="...", output="...", expected="...") ``` **Braintrust integration**: @@ -125,7 +130,7 @@ async def evaluate_qa(): from .llm import * from .moderation import * from .number import * -from .oai import init +from .oai import get_default_model, init from .ragas import * from .score import Score, Scorer, SerializableDataClass from .string import * diff --git a/py/autoevals/llm.py b/py/autoevals/llm.py index d8a0324..2431429 100644 --- a/py/autoevals/llm.py +++ b/py/autoevals/llm.py @@ -56,7 +56,7 @@ from autoevals.partial import ScorerWithPartial -from .oai import Client, arun_cached_request, run_cached_request +from .oai import Client, arun_cached_request, get_default_model, run_cached_request from .score import Score # Disable HTML escaping in chevron. @@ -78,6 +78,7 @@ "\n", " " ) +# Deprecated: Use init(default_model="...") to configure the default model instead. DEFAULT_MODEL = "gpt-4o" PLAIN_RESPONSE_SCHEMA = { @@ -324,7 +325,7 @@ def __init__( name, prompt_template, choice_scores, - model=DEFAULT_MODEL, + model=None, use_cot=True, max_tokens=None, temperature=None, @@ -335,6 +336,9 @@ def __init__( **extra_render_args, ): choice_strings = list(choice_scores.keys()) + # Use configured default model if not specified + if model is None: + model = get_default_model() prompt = prompt_template + "\n" + (COT_SUFFIX if use_cot else NO_COT_SUFFIX) messages = [ diff --git a/py/autoevals/oai.py b/py/autoevals/oai.py index c439a09..33eef02 100644 --- a/py/autoevals/oai.py +++ b/py/autoevals/oai.py @@ -197,6 +197,7 @@ def is_wrapped(self) -> bool: _client_var = ContextVar[Optional[LLMClient]]("client") +_default_model_var = ContextVar[Optional[str]]("default_model") T = TypeVar("T") @@ -238,8 +239,8 @@ def resolve_client(client: Client, is_async: bool = False) -> LLMClient: return LLMClient(openai=client, is_async=is_async) -def init(client: Client | None = None, is_async: bool = False): - """Initialize Autoevals with an optional custom LLM client. +def init(client: Client | None = None, is_async: bool = False, default_model: str | None = None): + """Initialize Autoevals with an optional custom LLM client and default model. This function sets up the global client context for Autoevals to use. If no client is provided, the default OpenAI client will be used. @@ -252,8 +253,39 @@ def init(client: Client | None = None, is_async: bool = False): - OpenAIV1: Wrapped in a new LLMClient instance (OpenAI SDK v1) is_async: Whether to create a client with async operations. Defaults to False. Deprecated: Use the `client` argument directly with your desired async/sync configuration. + default_model: The default model to use for evaluations when not specified per-call. + Defaults to "gpt-4o" if not set. When using non-OpenAI providers via the Braintrust + proxy, set this to the appropriate model string (e.g., "claude-3-5-sonnet-20241022"). + + Example: + Using with OpenAI (default):: + + from openai import OpenAI + from autoevals import init + + init(client=OpenAI()) + + Using with Anthropic via Braintrust proxy:: + + import os + from openai import OpenAI + from autoevals import init + + init( + client=OpenAI( + api_key=os.environ["BRAINTRUST_API_KEY"], + base_url="https://api.braintrust.dev/v1/proxy", + ), + default_model="claude-3-5-sonnet-20241022", + ) """ _client_var.set(resolve_client(client, is_async=is_async) if client else None) + _default_model_var.set(default_model) + + +def get_default_model() -> str: + """Get the configured default model, or "gpt-4o" if not set.""" + return _default_model_var.get(None) or "gpt-4o" warned_deprecated_api_key_base_url = False diff --git a/py/autoevals/ragas.py b/py/autoevals/ragas.py index 2e432fe..a50ec1a 100644 --- a/py/autoevals/ragas.py +++ b/py/autoevals/ragas.py @@ -17,7 +17,7 @@ **Common arguments**: - - `model`: Model to use for evaluation, defaults to DEFAULT_RAGAS_MODEL (gpt-3.5-turbo-16k) + - `model`: Model to use for evaluation, defaults to the model configured via init(default_model=...) or "gpt-4o" - `client`: Optional Client for API calls. If not provided, uses global client from init() **Example**: @@ -64,7 +64,7 @@ from . import Score from .list import ListContains from .llm import OpenAILLMScorer -from .oai import Client, arun_cached_request, run_cached_request +from .oai import Client, _default_model_var, arun_cached_request, get_default_model, run_cached_request from .string import EmbeddingSimilarity @@ -74,7 +74,30 @@ def check_required(name, **kwargs): raise ValueError(f"{name} requires {key} value") +# Deprecated: Use init(default_model="...") to configure the default model instead. +# This was previously "gpt-4o-mini" but now defaults to the configured model. DEFAULT_RAGAS_MODEL = "gpt-4o-mini" + + +def _get_model(model: str | None) -> str: + """Get the model to use, respecting init(default_model=...) configuration. + + Falls back to DEFAULT_RAGAS_MODEL if no model is specified and no custom + default has been configured. + """ + if model is not None: + return model + + # Check if user configured a custom default via init(default_model=...) + # If they did (even if it's "gpt-4o"), respect it for consistency + configured_default = _default_model_var.get(None) + if configured_default is not None: + return configured_default + + # Fall back to RAGAS-specific default when user hasn't configured anything + return DEFAULT_RAGAS_MODEL + + DEFAULT_RAGAS_EMBEDDING_MODEL = "text-embedding-3-small" ENTITY_PROMPT = """Given a text, extract unique entities without repetition. Ensure you consider different forms or mentions of the same entity as a single entity. @@ -168,10 +191,10 @@ class ContextEntityRecall(OpenAILLMScorer): context: The context document(s) to search for entities in """ - def __init__(self, pairwise_scorer=None, model=DEFAULT_RAGAS_MODEL, client: Client | None = None, **kwargs): + def __init__(self, pairwise_scorer=None, model: str | None = None, client: Client | None = None, **kwargs): super().__init__(client=client, **kwargs) - self.extraction_model = model + self.extraction_model = _get_model(model) self.contains_scorer = ListContains( pairwise_scorer=pairwise_scorer or EmbeddingSimilarity(client=client), allow_extra_entities=True ) @@ -312,10 +335,10 @@ class ContextRelevancy(OpenAILLMScorer): context: The context document(s) to evaluate """ - def __init__(self, pairwise_scorer=None, model=DEFAULT_RAGAS_MODEL, client: Client | None = None, **kwargs): + def __init__(self, pairwise_scorer=None, model: str | None = None, client: Client | None = None, **kwargs): super().__init__(client=client, **kwargs) - self.model = model + self.model = _get_model(model) def _postprocess(self, context, response): sentences = json.loads(response["choices"][0]["message"]["tool_calls"][0]["function"]["arguments"]) @@ -472,10 +495,10 @@ class ContextRecall(OpenAILLMScorer): context: The context document(s) to evaluate """ - def __init__(self, pairwise_scorer=None, model=DEFAULT_RAGAS_MODEL, client: Client | None = None, **kwargs): + def __init__(self, pairwise_scorer=None, model: str | None = None, client: Client | None = None, **kwargs): super().__init__(client=client, **kwargs) - self.model = model + self.model = _get_model(model) def _postprocess(self, response): statements = json.loads(response["choices"][0]["message"]["tool_calls"][0]["function"]["arguments"]) @@ -632,10 +655,10 @@ class ContextPrecision(OpenAILLMScorer): context: The context document(s) to evaluate """ - def __init__(self, pairwise_scorer=None, model=DEFAULT_RAGAS_MODEL, client: Client | None = None, **kwargs): + def __init__(self, pairwise_scorer=None, model: str | None = None, client: Client | None = None, **kwargs): super().__init__(client=client, **kwargs) - self.model = model + self.model = _get_model(model) def _postprocess(self, response): precision = json.loads(response["choices"][0]["message"]["tool_calls"][0]["function"]["arguments"]) @@ -894,10 +917,10 @@ class Faithfulness(OpenAILLMScorer): context: The context document(s) to evaluate against """ - def __init__(self, model=DEFAULT_RAGAS_MODEL, client: Client | None = None, **kwargs): + def __init__(self, model: str | None = None, client: Client | None = None, **kwargs): super().__init__(client=client, **kwargs) - self.model = model + self.model = _get_model(model) async def _run_eval_async(self, output, expected=None, input=None, context=None, **kwargs): check_required("Faithfulness", input=input, output=output, context=context) @@ -1056,7 +1079,7 @@ class AnswerRelevancy(OpenAILLMScorer): def __init__( self, - model=DEFAULT_RAGAS_MODEL, + model: str | None = None, strictness=3, temperature=0.5, embedding_model=DEFAULT_RAGAS_EMBEDDING_MODEL, @@ -1065,7 +1088,7 @@ def __init__( ): super().__init__(temperature=temperature, client=client, **kwargs) - self.model = model + self.model = _get_model(model) self.strictness = strictness self.temperature = temperature self.embedding_model = embedding_model @@ -1301,7 +1324,7 @@ class AnswerCorrectness(OpenAILLMScorer): def __init__( self, pairwise_scorer=None, - model=DEFAULT_RAGAS_MODEL, + model: str | None = None, factuality_weight=0.75, answer_similarity_weight=0.25, answer_similarity=None, @@ -1310,7 +1333,7 @@ def __init__( ): super().__init__(client=client, **kwargs) - self.model = model + self.model = _get_model(model) self.answer_similarity = answer_similarity or AnswerSimilarity(client=client) if factuality_weight == 0 and answer_similarity_weight == 0: diff --git a/py/autoevals/test_llm.py b/py/autoevals/test_llm.py index 543eafb..3b129b3 100644 --- a/py/autoevals/test_llm.py +++ b/py/autoevals/test_llm.py @@ -10,7 +10,7 @@ from autoevals import init from autoevals.llm import Battle, Factuality, LLMClassifier, OpenAILLMClassifier, build_classification_tools -from autoevals.oai import OpenAIV1Module +from autoevals.oai import OpenAIV1Module, get_default_model class TestModel(BaseModel): @@ -470,3 +470,80 @@ def capture_request(request): # Verify that max_tokens and temperature ARE in the request with correct values assert captured_request_body["max_tokens"] == 256 assert captured_request_body["temperature"] == 0.5 + + +@respx.mock +def test_llm_classifier_uses_configured_default_model(): + """Test that LLMClassifier uses the configured default model.""" + captured_model = None + + def capture_model(request): + nonlocal captured_model + captured_model = request.content.decode("utf-8") + # Parse JSON to extract model + import json + + data = json.loads(captured_model) + captured_model = data.get("model") + + return Response( + 200, + json={ + "id": "chatcmpl-test", + "object": "chat.completion", + "created": 1234567890, + "model": captured_model, + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_test", + "type": "function", + "function": {"name": "select_choice", "arguments": '{"choice": "1"}'}, + } + ], + }, + "finish_reason": "tool_calls", + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + }, + ) + + respx.post("https://api.openai.com/v1/chat/completions").mock(side_effect=capture_model) + + client = OpenAI(api_key="test-api-key", base_url="https://api.openai.com/v1") + + # Test with configured default model + init(client, default_model="claude-3-5-sonnet-20241022") + assert get_default_model() == "claude-3-5-sonnet-20241022" + + classifier = LLMClassifier( + name="test", + prompt_template="Test prompt: {{output}}", + choice_scores={"1": 1, "2": 0}, + ) + + classifier.eval(output="test output", expected="test expected") + + assert captured_model == "claude-3-5-sonnet-20241022" + + # Test that explicit model overrides default + captured_model = None + classifier_with_model = LLMClassifier( + name="test", + prompt_template="Test prompt: {{output}}", + choice_scores={"1": 1, "2": 0}, + model="gpt-4-turbo", + ) + + classifier_with_model.eval(output="test output", expected="test expected") + + assert captured_model == "gpt-4-turbo" + + # Reset for other tests + init(None) diff --git a/py/autoevals/test_oai.py b/py/autoevals/test_oai.py index 8f31479..f9a081f 100644 --- a/py/autoevals/test_oai.py +++ b/py/autoevals/test_oai.py @@ -20,6 +20,7 @@ OpenAIV1Module, _named_wrapper, # type: ignore[import] # Accessing private members for testing _wrap_openai, # type: ignore[import] # Accessing private members for testing + get_default_model, get_openai_wrappers, prepare_openai, ) @@ -249,3 +250,45 @@ def test_prepare_openai_v0_with_client(mock_openai_v0: OpenAIV0Module): assert prepared_client.is_wrapped assert prepared_client.openai.api_key is mock_openai_v0.api_key # must be set by the user assert prepared_client.complete.__name__ == "acreate" + + +def test_get_default_model_returns_gpt_4o_by_default(): + """Test that get_default_model returns gpt-4o when no default is configured.""" + # Reset init to clear any previous default model + init(None) + assert get_default_model() == "gpt-4o" + + +def test_init_sets_default_model(): + """Test that init sets the default model correctly.""" + init(None, default_model="claude-3-5-sonnet-20241022") + assert get_default_model() == "claude-3-5-sonnet-20241022" + + # Reset + init(None) + + +def test_init_can_reset_default_model(): + """Test that init can reset the default model to gpt-4o.""" + init(None, default_model="claude-3-5-sonnet-20241022") + assert get_default_model() == "claude-3-5-sonnet-20241022" + + init(None, default_model=None) + assert get_default_model() == "gpt-4o" + + +def test_init_can_set_both_client_and_default_model(): + """Test that init can set both client and default model together.""" + client = openai.OpenAI(api_key="api-key", base_url="http://test") + init(client, default_model="gpt-4-turbo") + + prepared_client = prepare_openai() + assert prepared_client.is_wrapped + + # Unwrap to check it's the same client + assert unwrap_named_wrapper(prepared_client.openai) == client + + assert get_default_model() == "gpt-4-turbo" + + # Reset + init(None)