94 changes: 94 additions & 0 deletions js/ragas.test.ts
@@ -234,3 +234,97 @@ describe("ContextRelevancy score clamping", () => {
expect(result.score).toBeGreaterThanOrEqual(0);
});
});

describe("AnswerCorrectness custom embedding model", () => {
const server = setupServer();

beforeAll(() => {
server.listen({
onUnhandledRequest: (req) => {
throw new Error(`Unhandled request ${req.method}, ${req.url}`);
},
});
});

afterEach(() => {
server.resetHandlers();
init();
});

afterAll(() => {
server.close();
});

test("AnswerCorrectness uses custom embedding model", async () => {
let capturedEmbeddingModel: string | undefined;

server.use(
http.post("https://api.openai.com/v1/chat/completions", async () => {
return HttpResponse.json({
id: "test-id",
object: "chat.completion",
created: Date.now(),
model: "gpt-4o",
choices: [
{
index: 0,
message: {
role: "assistant",
tool_calls: [
{
id: "call_test",
type: "function",
function: {
name: "classify_statements",
arguments: JSON.stringify({
TP: ["Paris is the capital"],
FP: [],
FN: [],
}),
},
},
],
},
finish_reason: "tool_calls",
},
],
});
}),
http.post("https://api.openai.com/v1/embeddings", async ({ request }) => {
const body = (await request.json()) as { model: string; input: string };
capturedEmbeddingModel = body.model;
return HttpResponse.json({
object: "list",
data: [
{
object: "embedding",
embedding: new Array(1536).fill(0.1),
index: 0,
},
],
model: body.model,
usage: {
prompt_tokens: 5,
total_tokens: 5,
},
});
}),
);

init({
client: new OpenAI({
apiKey: "test-api-key",
baseURL: "https://api.openai.com/v1",
}),
});

await AnswerCorrectness({
input: "What is the capital of France?",
output: "Paris",
expected: "Paris is the capital of France",
embeddingModel: "text-embedding-3-large",
});

expect(capturedEmbeddingModel).toBe("text-embedding-3-large");
});
});
49 changes: 29 additions & 20 deletions js/ragas.ts
@@ -791,28 +791,31 @@ export const AnswerRelevancy: ScorerWithPartial<
/**
* Scores the semantic similarity between the generated answer and ground truth.
*/
export const AnswerSimilarity: ScorerWithPartial<string, RagasArgs> =
makePartial(async (args) => {
const { ...inputs } = parseArgs(args);
export const AnswerSimilarity: ScorerWithPartial<
string,
RagasArgs & { model?: string }
> = makePartial(async (args) => {
const { ...inputs } = parseArgs(args);

const { output, expected } = checkRequired(
{ output: inputs.output, expected: inputs.expected },
"AnswerSimilarity",
);
const { output, expected } = checkRequired(
{ output: inputs.output, expected: inputs.expected },
"AnswerSimilarity",
);

const { score, error } = await EmbeddingSimilarity({
...extractOpenAIArgs(args),
output,
expected,
expectedMin: 0,
});
const { score, error } = await EmbeddingSimilarity({
...extractOpenAIArgs(args),
output,
expected,
expectedMin: 0,
model: args.model,
});

return {
name: "AnswerSimilarity",
score,
error,
};
}, "AnswerSimilarity");
return {
name: "AnswerSimilarity",
score,
error,
};
}, "AnswerSimilarity");

const CORRECTNESS_PROMPT = `Given a ground truth and an answer, analyze each statement in the answer and classify them in one of the following categories:

@@ -880,6 +883,7 @@ export const AnswerCorrectness: ScorerWithPartial<
factualityWeight?: number;
answerSimilarityWeight?: number;
answerSimilarity?: Scorer<string, object>;
embeddingModel?: string;
}
> = makePartial(async (args) => {
const { chatArgs, client, ...inputs } = parseArgs(args);
@@ -930,7 +934,12 @@ export const AnswerCorrectness: ScorerWithPartial<
}),
answerSimilarityWeight === 0
? null
: answerSimilarity({ output, expected, openAiApiKey: args.openAiApiKey }),
: answerSimilarity({
output,
expected,
openAiApiKey: args.openAiApiKey,
model: args.embeddingModel,
}),
]);

const factuality = answerCorrectnessClassificationSchema.parse(
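Taken together, the js/ragas.ts changes let callers pick the embedding model either directly on AnswerSimilarity (through the new model option) or via AnswerCorrectness (through embeddingModel, which is forwarded to the inner similarity scorer). A minimal usage sketch follows; the import path and the model name are illustrative assumptions rather than part of this diff, and it presumes an OpenAI client or key has already been configured (for example with init, as the tests above do).

// Usage sketch only: assumes the scorers come from the published autoevals package
// and that an OpenAI client/key is already configured (e.g. via init() or the environment).
import { AnswerCorrectness, AnswerSimilarity } from "autoevals";

async function demo(): Promise<void> {
  // embeddingModel is forwarded through answerSimilarity to EmbeddingSimilarity.
  const correctness = await AnswerCorrectness({
    input: "What is the capital of France?",
    output: "Paris",
    expected: "Paris is the capital of France",
    embeddingModel: "text-embedding-3-large",
  });

  // AnswerSimilarity now also accepts the embedding model directly as "model".
  const similarity = await AnswerSimilarity({
    output: "Paris",
    expected: "Paris is the capital of France",
    model: "text-embedding-3-large",
  });

  console.log(correctness.score, similarity.score);
}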
21 changes: 12 additions & 9 deletions py/autoevals/ragas.py
@@ -1245,15 +1245,15 @@ def __init__(
async def _run_eval_async(self, output, expected=None, input=None, **kwargs):
check_required("AnswerSimilarity", expected=expected, output=output)

return await EmbeddingSimilarity(client=self.client).eval_async(
output=output, expected=expected, model=self.model, **self.extra_args
return await EmbeddingSimilarity(client=self.client, model=self.model).eval_async(
output=output, expected=expected, **self.extra_args
)

def _run_eval_sync(self, output, expected=None, input=None, **kwargs):
check_required("AnswerSimilarity", expected=expected, output=output)

return EmbeddingSimilarity(client=self.client).eval(
output=output, expected=expected, model=self.model, **self.extra_args
return EmbeddingSimilarity(client=self.client, model=self.model).eval(
output=output, expected=expected, **self.extra_args
)


@@ -1370,6 +1370,7 @@ class AnswerCorrectness(OpenAILLMScorer):
factuality_weight: Optional float between 0-1 for factual correctness weight
answer_similarity_weight: Optional float between 0-1 for answer similarity weight
answer_similarity: Optional AnswerSimilarity instance for similarity evaluation
embedding_model: Optional model to use for answer similarity embeddings
"""

def __init__(
@@ -1379,13 +1380,17 @@ def __init__(
factuality_weight=0.75,
answer_similarity_weight=0.25,
answer_similarity=None,
embedding_model=None,
client: Client | None = None,
**kwargs,
):
super().__init__(client=client, **kwargs)

self.model = _get_model(model)
self.answer_similarity = answer_similarity or AnswerSimilarity(client=client)
self.answer_similarity = answer_similarity or AnswerSimilarity(
model=embedding_model if embedding_model is not None else DEFAULT_RAGAS_EMBEDDING_MODEL,
client=client,
)

if factuality_weight == 0 and answer_similarity_weight == 0:
raise ValueError("At least one weight must be nonzero")
@@ -1416,14 +1421,12 @@ def _postprocess(self, factuality, similarity):
async def _run_answer_similarity_async(self, output, expected):
if self.answer_similarity_weight == 0:
return None
return await self.answer_similarity.eval_async(
output=output, expected=expected, model=self.model, **self.extra_args
)
return await self.answer_similarity.eval_async(output=output, expected=expected, **self.extra_args)

def _run_answer_similarity_sync(self, output, expected):
if self.answer_similarity_weight == 0:
return None
return self.answer_similarity.eval(output=output, expected=expected, model=self.model, **self.extra_args)
return self.answer_similarity.eval(output=output, expected=expected, **self.extra_args)

async def _run_eval_async(self, output, expected=None, input=None, **kwargs):
check_required("AnswerCorrectness", input=input, expected=expected, output=output)
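For readers tracing the weight handling: the docstring above describes factuality_weight and answer_similarity_weight as 0-1 weights, and both the JS and Python changes skip the similarity call entirely when its weight is zero. The sketch below shows one plausible way such a weighted combination works, written in TypeScript to match the earlier sketch; it is an assumption about the shape of _postprocess, not the verbatim autoevals formula.

// Assumed combination: a normalized weighted average of the two sub-scores, with the
// similarity term dropped when that scorer was skipped. The real _postprocess may differ.
function combineAnswerCorrectness(
  factualityScore: number, // derived from the TP/FP/FN classification
  similarityScore: number | null, // null when answer_similarity_weight is 0
  factualityWeight = 0.75,
  answerSimilarityWeight = 0.25,
): number {
  const simWeight = similarityScore === null ? 0 : answerSimilarityWeight;
  return (
    (factualityWeight * factualityScore + simWeight * (similarityScore ?? 0)) /
    (factualityWeight + simWeight)
  );
}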
77 changes: 77 additions & 0 deletions py/autoevals/test_ragas.py
@@ -2,7 +2,11 @@
import json

import pytest
import respx
from httpx import Response
from openai import OpenAI

from autoevals import init
from autoevals.ragas import *

data = {
@@ -119,3 +123,76 @@ def test_context_relevancy_score_normal_case():
assert result.score == pytest.approx(expected_score, rel=1e-3)
assert result.score <= 1.0
assert result.score >= 0.0


@respx.mock
def test_answer_correctness_uses_custom_embedding_model():
"""Test that AnswerCorrectness passes embedding_model parameter through to embeddings API."""
captured_embedding_model = None

    def capture_embedding_model(request):
        nonlocal captured_embedding_model
        data = json.loads(request.content.decode())
        captured_embedding_model = data.get("model")
return Response(
200,
json={
"object": "list",
"data": [
{
"object": "embedding",
"embedding": [0.1] * 1536,
"index": 0,
}
],
"model": data.get("model"),
"usage": {"prompt_tokens": 5, "total_tokens": 5},
},
)

def mock_chat_completions(request):
return Response(
200,
json={
"id": "test-id",
"object": "chat.completion",
"created": 1234567890,
"model": "gpt-4o",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"tool_calls": [
{
"id": "call_test",
"type": "function",
"function": {
"name": "classify_statements",
"arguments": '{"TP": ["Paris is the capital"], "FP": [], "FN": []}',
},
}
],
},
"finish_reason": "tool_calls",
}
],
},
)

respx.post("https://api.openai.com/v1/chat/completions").mock(side_effect=mock_chat_completions)
respx.post("https://api.openai.com/v1/embeddings").mock(side_effect=capture_embedding_model)

init(OpenAI(api_key="test-api-key", base_url="https://api.openai.com/v1"))

metric = AnswerCorrectness(embedding_model="text-embedding-3-large")
metric.eval(
input="What is the capital of France?",
output="Paris",
expected="Paris is the capital of France",
)

assert captured_embedding_model == "text-embedding-3-large"