3 changes: 2 additions & 1 deletion js/index.ts
@@ -29,7 +29,8 @@

export type { Score, ScorerArgs, Scorer } from "./score";
export * from "./llm";
export { init } from "./oai";
export { init, getDefaultModel } from "./oai";
export type { InitOptions } from "./oai";
export * from "./string";
export * from "./list";
export * from "./moderation";
83 changes: 82 additions & 1 deletion js/llm.test.ts
@@ -14,7 +14,7 @@ import {
openaiClassifierShouldEvaluateTitles,
openaiClassifierShouldEvaluateTitlesWithCoT,
} from "./llm.fixtures";
import { init } from "./oai";
import { init, getDefaultModel } from "./oai";

export const server = setupServer();

@@ -340,4 +340,85 @@ Issue Description: {{page_content}}
expect(capturedRequestBody.max_tokens).toBe(256);
expect(capturedRequestBody.temperature).toBe(0.5);
});

test("LLMClassifierFromTemplate uses configured default model", async () => {
let capturedModel: string | undefined;

server.use(
http.post(
"https://api.openai.com/v1/chat/completions",
async ({ request }) => {
const body = (await request.json()) as any;
capturedModel = body.model;

return HttpResponse.json({
id: "chatcmpl-test",
object: "chat.completion",
created: Date.now(),
model: body.model,
choices: [
{
index: 0,
message: {
role: "assistant",
content: null,
tool_calls: [
{
id: "call_test",
type: "function",
function: {
name: "select_choice",
arguments: JSON.stringify({ choice: "1" }),
},
},
],
},
finish_reason: "tool_calls",
},
],
usage: {
prompt_tokens: 10,
completion_tokens: 20,
total_tokens: 30,
},
});
},
),
);

const client = new OpenAI({
apiKey: "test-api-key",
baseURL: "https://api.openai.com/v1",
});

// Test with configured default model
init({ client, defaultModel: "claude-3-5-sonnet-20241022" });
expect(getDefaultModel()).toBe("claude-3-5-sonnet-20241022");

const classifier = LLMClassifierFromTemplate({
name: "test",
promptTemplate: "Test prompt: {{output}}",
choiceScores: { "1": 1, "2": 0 },
});

await classifier({
output: "test output",
expected: "test expected",
});

expect(capturedModel).toBe("claude-3-5-sonnet-20241022");

// Test that explicit model overrides default
capturedModel = undefined;
await classifier({
output: "test output",
expected: "test expected",
model: "gpt-4-turbo",
});

expect(capturedModel).toBe("gpt-4-turbo");

// Reset for other tests
init({ client });
});
});
15 changes: 13 additions & 2 deletions js/llm.ts
@@ -1,5 +1,10 @@
import { Score, Scorer, ScorerArgs } from "./score";
import { ChatCache, OpenAIAuth, cachedChatCompletion } from "./oai";
import {
ChatCache,
OpenAIAuth,
cachedChatCompletion,
getDefaultModel,
} from "./oai";
import { ModelGradedSpec, templates } from "./templates";
import {
ChatCompletionMessage,
@@ -20,6 +25,10 @@ export type LLMArgs = {
temperature?: number;
} & OpenAIAuth;

/**
* The default model to use for LLM-based evaluations.
* @deprecated Use `init({ defaultModel: "..." })` to configure the default model instead.
*/
export const DEFAULT_MODEL = "gpt-4o";

const PLAIN_RESPONSE_SCHEMA = {
@@ -203,7 +212,7 @@ export function LLMClassifierFromTemplate<RenderArgs>({
name,
promptTemplate,
choiceScores,
model = DEFAULT_MODEL,
model: modelArg,
useCoT: useCoTArg,
temperature,
maxTokens: maxTokensArg,
@@ -221,6 +230,8 @@ export function LLMClassifierFromTemplate<RenderArgs>({
runtimeArgs: ScorerArgs<string, LLMClassifierArgs<RenderArgs>>,
) => {
const useCoT = runtimeArgs.useCoT ?? useCoTArg ?? true;
// Use runtime model > template model > configured default model
const model = runtimeArgs.model ?? modelArg ?? getDefaultModel();

const prompt =
promptTemplate + "\n" + (useCoT ? COT_SUFFIX : NO_COT_SUFFIX);
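A minimal sketch of the precedence this hunk introduces (runtime `model` argument, then the template's `model`, then the configured default), assuming the package is consumed as `autoevals` and treating the model names as placeholders rather than anything prescribed by this PR:

```typescript
import { OpenAI } from "openai";
import { init, getDefaultModel, LLMClassifierFromTemplate } from "autoevals";

// Configure a shared client and a process-wide default model once.
init({
  client: new OpenAI({ apiKey: process.env.OPENAI_API_KEY }),
  defaultModel: "gpt-4o-mini", // placeholder default, chosen only for illustration
});

const politeness = LLMClassifierFromTemplate({
  name: "politeness",
  promptTemplate: "Is this response polite? {{output}}",
  choiceScores: { Y: 1, N: 0 },
  // No `model` here, so each call falls back to getDefaultModel().
});

// A per-call model overrides both the template model and the configured default.
await politeness({
  output: "Thanks so much for your help!",
  expected: "A polite reply",
  model: "gpt-4-turbo",
});

console.log(getDefaultModel()); // "gpt-4o-mini"
```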
31 changes: 30 additions & 1 deletion js/oai.test.ts
@@ -10,7 +10,7 @@ import {
test,
vi,
} from "vitest";
import { buildOpenAIClient, init } from "./oai";
import { buildOpenAIClient, init, getDefaultModel } from "./oai";

import { setupServer } from "msw/node";

@@ -37,6 +37,9 @@ afterEach(() => {

process.env.OPENAI_API_KEY = OPENAI_API_KEY;
process.env.OPENAI_BASE_URL = OPENAI_BASE_URL;

// Reset init state
init({ client: undefined, defaultModel: undefined });
});

afterAll(() => {
@@ -257,6 +260,32 @@ describe("OAI", () => {

expect(Object.is(builtClient, otherClient)).toBe(true);
});

test("getDefaultModel returns gpt-4o by default", () => {
expect(getDefaultModel()).toBe("gpt-4o");
});

test("init sets default model", () => {
init({ defaultModel: "claude-3-5-sonnet-20241022" });
expect(getDefaultModel()).toBe("claude-3-5-sonnet-20241022");
});

test("init can reset default model", () => {
init({ defaultModel: "claude-3-5-sonnet-20241022" });
expect(getDefaultModel()).toBe("claude-3-5-sonnet-20241022");

init({ defaultModel: undefined });
expect(getDefaultModel()).toBe("gpt-4o");
});

test("init can set both client and default model", () => {
const client = new OpenAI({ apiKey: "test-api-key" });
init({ client, defaultModel: "gpt-4-turbo" });

const builtClient = buildOpenAIClient({});
expect(Object.is(builtClient, client)).toBe(true);
expect(getDefaultModel()).toBe("gpt-4-turbo");
});
});

const withMockWrapper = async (
51 changes: 50 additions & 1 deletion js/oai.ts
@@ -149,10 +149,59 @@ declare global {
/* eslint-disable no-var */
var __inherited_braintrust_wrap_openai: ((openai: any) => any) | undefined;
var __client: OpenAI | undefined;
var __defaultModel: string | undefined;
}

export const init = ({ client }: { client?: OpenAI } = {}) => {
export interface InitOptions {
/**
* An OpenAI-compatible client to use for all evaluations.
* This can be an OpenAI client, or any client that implements the OpenAI API
* (e.g., configured to use the Braintrust proxy with Anthropic, Gemini, etc.)
*/
client?: OpenAI;
/**
* The default model to use for evaluations when not specified per-call.
* Defaults to "gpt-4o" if not set.
*
* When using non-OpenAI providers via the Braintrust proxy, set this to
* the appropriate model string (e.g., "claude-3-5-sonnet-20241022").
*/
defaultModel?: string;
}

/**
* Initialize autoevals with a custom client and/or default model.
*
* @example
* // Using with OpenAI (default)
* import { init } from "autoevals";
* import { OpenAI } from "openai";
*
* init({ client: new OpenAI() });
*
* @example
* // Using with Anthropic via Braintrust proxy
* import { init } from "autoevals";
* import { OpenAI } from "openai";
*
* init({
* client: new OpenAI({
* apiKey: process.env.BRAINTRUST_API_KEY,
* baseURL: "https://api.braintrust.dev/v1/proxy",
* }),
* defaultModel: "claude-3-5-sonnet-20241022",
* });
*/
export const init = ({ client, defaultModel }: InitOptions = {}) => {
globalThis.__client = client;
globalThis.__defaultModel = defaultModel;
};

/**
* Get the configured default model, or "gpt-4o" if not set.
*/
export const getDefaultModel = (): string => {
return globalThis.__defaultModel ?? "gpt-4o";
};

export async function cachedChatCompletion(
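For quick reference, a small sketch of the fallback behavior that `getDefaultModel` and the tests above describe; the Anthropic model string is only an example of a proxy-routed model:

```typescript
import { init, getDefaultModel } from "autoevals";

console.log(getDefaultModel()); // "gpt-4o" until a default is configured

init({ defaultModel: "claude-3-5-sonnet-20241022" });
console.log(getDefaultModel()); // "claude-3-5-sonnet-20241022"

// Passing undefined (or omitting defaultModel) clears the override.
init({ defaultModel: undefined });
console.log(getDefaultModel()); // back to "gpt-4o"
```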
5 changes: 3 additions & 2 deletions js/ragas.ts
@@ -2,7 +2,8 @@
import mustache from "mustache";

import { Scorer, ScorerArgs } from "./score";
import { DEFAULT_MODEL, LLMArgs } from "./llm";
import { LLMArgs } from "./llm";
import { getDefaultModel } from "./oai";
import { buildOpenAIClient, extractOpenAIArgs } from "./oai";
import OpenAI from "openai";
import { ListContains } from "./list";
@@ -869,7 +870,7 @@ function parseArgs(args: ScorerArgs<string, RagasArgs>): {
OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming,
"messages"
> = {
model: args.model ?? DEFAULT_MODEL,
model: args.model ?? getDefaultModel(),
temperature: args.temperature ?? 0,
};
if (args.maxTokens) {
17 changes: 11 additions & 6 deletions py/autoevals/__init__.py
@@ -46,21 +46,26 @@
**Multi-provider support via the Braintrust AI Proxy**:

Autoevals supports multiple LLM providers (Anthropic, Azure, etc.) through the Braintrust AI Proxy.
Configure your client to use the proxy:
Configure your client to use the proxy and set the default model:

```python
import os
from openai import AsyncOpenAI
from autoevals import init
from autoevals.llm import Factuality

# Configure client to use Braintrust AI Proxy
# Configure client to use Braintrust AI Proxy with Claude
client = AsyncOpenAI(
base_url="https://api.braintrustproxy.com/v1",
base_url="https://api.braintrust.dev/v1/proxy",
api_key=os.getenv("BRAINTRUST_API_KEY"),
)

# Use with any evaluator
evaluator = Factuality(client=client)
# Initialize with the client and default model
init(client=client, default_model="claude-3-5-sonnet-20241022")

# All evaluators will now use Claude by default
evaluator = Factuality()
result = evaluator.eval(input="...", output="...", expected="...")
```

**Braintrust integration**:
@@ -125,7 +130,7 @@ async def evaluate_qa():
from .llm import *
from .moderation import *
from .number import *
from .oai import init
from .oai import get_default_model, init
from .ragas import *
from .score import Score, Scorer, SerializableDataClass
from .string import *
8 changes: 6 additions & 2 deletions py/autoevals/llm.py
@@ -56,7 +56,7 @@

from autoevals.partial import ScorerWithPartial

from .oai import Client, arun_cached_request, run_cached_request
from .oai import Client, arun_cached_request, get_default_model, run_cached_request
from .score import Score

# Disable HTML escaping in chevron.
@@ -78,6 +78,7 @@
"\n", " "
)

# Deprecated: Use init(default_model="...") to configure the default model instead.
DEFAULT_MODEL = "gpt-4o"

PLAIN_RESPONSE_SCHEMA = {
@@ -324,7 +325,7 @@ def __init__(
name,
prompt_template,
choice_scores,
model=DEFAULT_MODEL,
model=None,
use_cot=True,
max_tokens=None,
temperature=None,
Expand All @@ -335,6 +336,9 @@ def __init__(
**extra_render_args,
):
choice_strings = list(choice_scores.keys())
# Use configured default model if not specified
if model is None:
model = get_default_model()

prompt = prompt_template + "\n" + (COT_SUFFIX if use_cot else NO_COT_SUFFIX)
messages = [