diff --git a/js/ragas.test.ts b/js/ragas.test.ts
index a74342e..be5c7d9 100644
--- a/js/ragas.test.ts
+++ b/js/ragas.test.ts
@@ -234,3 +234,97 @@ describe("ContextRelevancy score clamping", () => {
     expect(result.score).toBeGreaterThanOrEqual(0);
   });
 });
+
+describe("AnswerCorrectness custom embedding model", () => {
+  const server = setupServer();
+
+  beforeAll(() => {
+    server.listen({
+      onUnhandledRequest: (req) => {
+        throw new Error(`Unhandled request ${req.method}, ${req.url}`);
+      },
+    });
+  });
+
+  afterEach(() => {
+    server.resetHandlers();
+    init();
+  });
+
+  afterAll(() => {
+    server.close();
+  });
+
+  test("AnswerCorrectness uses custom embedding model", async () => {
+    let capturedEmbeddingModel: string | undefined;
+
+    server.use(
+      http.post("https://api.openai.com/v1/chat/completions", async () => {
+        return HttpResponse.json({
+          id: "test-id",
+          object: "chat.completion",
+          created: Date.now(),
+          model: "gpt-4o",
+          choices: [
+            {
+              index: 0,
+              message: {
+                role: "assistant",
+                tool_calls: [
+                  {
+                    id: "call_test",
+                    type: "function",
+                    function: {
+                      name: "classify_statements",
+                      arguments: JSON.stringify({
+                        TP: ["Paris is the capital"],
+                        FP: [],
+                        FN: [],
+                      }),
+                    },
+                  },
+                ],
+              },
+              finish_reason: "tool_calls",
+            },
+          ],
+        });
+      }),
+      http.post("https://api.openai.com/v1/embeddings", async ({ request }) => {
+        const body = (await request.json()) as { model: string; input: string };
+        capturedEmbeddingModel = body.model;
+        return HttpResponse.json({
+          object: "list",
+          data: [
+            {
+              object: "embedding",
+              embedding: new Array(1536).fill(0.1),
+              index: 0,
+            },
+          ],
+          model: body.model,
+          usage: {
+            prompt_tokens: 5,
+            total_tokens: 5,
+          },
+        });
+      }),
+    );
+
+    init({
+      client: new OpenAI({
+        apiKey: "test-api-key",
+        baseURL: "https://api.openai.com/v1",
+      }),
+    });
+
+    await AnswerCorrectness({
+      input: "What is the capital of France?",
+      output: "Paris",
+      expected: "Paris is the capital of France",
+      embeddingModel: "text-embedding-3-large",
+    });
+
+    expect(capturedEmbeddingModel).toBe("text-embedding-3-large");
+  });
+});
diff --git a/js/ragas.ts b/js/ragas.ts
index 43f93e9..ef2e1f4 100644
--- a/js/ragas.ts
+++ b/js/ragas.ts
@@ -791,28 +791,31 @@ export const AnswerRelevancy: ScorerWithPartial<
 /**
  * Scores the semantic similarity between the generated answer and ground truth.
  */
-export const AnswerSimilarity: ScorerWithPartial<string, RagasArgs> =
-  makePartial(async (args) => {
-    const { ...inputs } = parseArgs(args);
+export const AnswerSimilarity: ScorerWithPartial<
+  string,
+  RagasArgs & { model?: string }
+> = makePartial(async (args) => {
+  const { ...inputs } = parseArgs(args);
 
-    const { output, expected } = checkRequired(
-      { output: inputs.output, expected: inputs.expected },
-      "AnswerSimilarity",
-    );
+  const { output, expected } = checkRequired(
+    { output: inputs.output, expected: inputs.expected },
+    "AnswerSimilarity",
+  );
 
-    const { score, error } = await EmbeddingSimilarity({
-      ...extractOpenAIArgs(args),
-      output,
-      expected,
-      expectedMin: 0,
-    });
+  const { score, error } = await EmbeddingSimilarity({
+    ...extractOpenAIArgs(args),
+    output,
+    expected,
+    expectedMin: 0,
+    model: args.model,
+  });
 
-    return {
-      name: "AnswerSimilarity",
-      score,
-      error,
-    };
-  }, "AnswerSimilarity");
+  return {
+    name: "AnswerSimilarity",
+    score,
+    error,
+  };
+}, "AnswerSimilarity");
 
 const CORRECTNESS_PROMPT = `Given a ground truth and an answer, analyze each statement in the answer and classify them in one of the following categories:
 
@@ -880,6 +883,7 @@ export const AnswerCorrectness: ScorerWithPartial<
     factualityWeight?: number;
     answerSimilarityWeight?: number;
     answerSimilarity?: Scorer;
+    embeddingModel?: string;
   }
 > = makePartial(async (args) => {
   const { chatArgs, client, ...inputs } = parseArgs(args);
@@ -930,7 +934,12 @@ export const AnswerCorrectness: ScorerWithPartial<
     }),
     answerSimilarityWeight === 0
      ? null
-      : answerSimilarity({ output, expected, openAiApiKey: args.openAiApiKey }),
+      : answerSimilarity({
+          output,
+          expected,
+          openAiApiKey: args.openAiApiKey,
+          model: args.embeddingModel,
+        }),
   ]);
 
   const factuality = answerCorrectnessClassificationSchema.parse(
diff --git a/py/autoevals/ragas.py b/py/autoevals/ragas.py
index 8bb6a3a..794ab03 100644
--- a/py/autoevals/ragas.py
+++ b/py/autoevals/ragas.py
@@ -1245,15 +1245,15 @@ def __init__(
     async def _run_eval_async(self, output, expected=None, input=None, **kwargs):
         check_required("AnswerSimilarity", expected=expected, output=output)
 
-        return await EmbeddingSimilarity(client=self.client).eval_async(
-            output=output, expected=expected, model=self.model, **self.extra_args
+        return await EmbeddingSimilarity(client=self.client, model=self.model).eval_async(
+            output=output, expected=expected, **self.extra_args
         )
 
     def _run_eval_sync(self, output, expected=None, input=None, **kwargs):
         check_required("AnswerSimilarity", expected=expected, output=output)
 
-        return EmbeddingSimilarity(client=self.client).eval(
-            output=output, expected=expected, model=self.model, **self.extra_args
+        return EmbeddingSimilarity(client=self.client, model=self.model).eval(
+            output=output, expected=expected, **self.extra_args
         )
 
 
@@ -1370,6 +1370,7 @@ class AnswerCorrectness(OpenAILLMScorer):
         factuality_weight: Optional float between 0-1 for factual correctness weight
         answer_similarity_weight: Optional float between 0-1 for answer similarity weight
         answer_similarity: Optional AnswerSimilarity instance for similarity evaluation
+        embedding_model: Optional model to use for answer similarity embeddings
     """
 
     def __init__(
@@ -1379,13 +1380,17 @@ def __init__(
         factuality_weight=0.75,
         answer_similarity_weight=0.25,
         answer_similarity=None,
+        embedding_model=None,
         client: Client | None = None,
         **kwargs,
     ):
         super().__init__(client=client, **kwargs)
 
         self.model = _get_model(model)
-        self.answer_similarity = answer_similarity or AnswerSimilarity(client=client)
+        self.answer_similarity = answer_similarity or AnswerSimilarity(
+            model=embedding_model if embedding_model is not None else DEFAULT_RAGAS_EMBEDDING_MODEL,
+            client=client,
+        )
 
         if factuality_weight == 0 and answer_similarity_weight == 0:
             raise ValueError("At least one weight must be nonzero")
@@ -1416,14 +1421,12 @@ def _postprocess(self, factuality, similarity):
     async def _run_answer_similarity_async(self, output, expected):
         if self.answer_similarity_weight == 0:
             return None
-        return await self.answer_similarity.eval_async(
-            output=output, expected=expected, model=self.model, **self.extra_args
-        )
+        return await self.answer_similarity.eval_async(output=output, expected=expected, **self.extra_args)
 
     def _run_answer_similarity_sync(self, output, expected):
         if self.answer_similarity_weight == 0:
             return None
-        return self.answer_similarity.eval(output=output, expected=expected, model=self.model, **self.extra_args)
+        return self.answer_similarity.eval(output=output, expected=expected, **self.extra_args)
 
     async def _run_eval_async(self, output, expected=None, input=None, **kwargs):
         check_required("AnswerCorrectness", input=input, expected=expected, output=output)
diff --git a/py/autoevals/test_ragas.py b/py/autoevals/test_ragas.py
index 556b3ae..0f53a28 100644
--- a/py/autoevals/test_ragas.py
+++ b/py/autoevals/test_ragas.py
@@ -2,7 +2,11 @@
 import json
 
 import pytest
+import respx
+from httpx import Response
+from openai import OpenAI
 
+from autoevals import init
 from autoevals.ragas import *
 
 data = {
@@ -119,3 +123,76 @@ def test_context_relevancy_score_normal_case():
     assert result.score == pytest.approx(expected_score, rel=1e-3)
     assert result.score <= 1.0
     assert result.score >= 0.0
+
+
+@respx.mock
+def test_answer_correctness_uses_custom_embedding_model():
+    """Test that AnswerCorrectness passes the embedding_model parameter through to the embeddings API."""
+    captured_embedding_model = None
+
+    def capture_embedding_model(request):
+        nonlocal captured_embedding_model
+        body = request.content.decode()
+        import json
+
+        data = json.loads(body)
+        captured_embedding_model = data.get("model")
+        return Response(
+            200,
+            json={
+                "object": "list",
+                "data": [
+                    {
+                        "object": "embedding",
+                        "embedding": [0.1] * 1536,
+                        "index": 0,
+                    }
+                ],
+                "model": data.get("model"),
+                "usage": {"prompt_tokens": 5, "total_tokens": 5},
+            },
+        )
+
+    def mock_chat_completions(request):
+        return Response(
+            200,
+            json={
+                "id": "test-id",
+                "object": "chat.completion",
+                "created": 1234567890,
+                "model": "gpt-4o",
+                "choices": [
+                    {
+                        "index": 0,
+                        "message": {
+                            "role": "assistant",
+                            "tool_calls": [
+                                {
+                                    "id": "call_test",
+                                    "type": "function",
+                                    "function": {
+                                        "name": "classify_statements",
+                                        "arguments": '{"TP": ["Paris is the capital"], "FP": [], "FN": []}',
+                                    },
+                                }
+                            ],
+                        },
+                        "finish_reason": "tool_calls",
+                    }
+                ],
+            },
+        )
+
+    respx.post("https://api.openai.com/v1/chat/completions").mock(side_effect=mock_chat_completions)
+    respx.post("https://api.openai.com/v1/embeddings").mock(side_effect=capture_embedding_model)
+
+    init(OpenAI(api_key="test-api-key", base_url="https://api.openai.com/v1"))
+
+    metric = AnswerCorrectness(embedding_model="text-embedding-3-large")
+    metric.eval(
+        input="What is the capital of France?",
+        output="Paris",
+        expected="Paris is the capital of France",
+    )
+
+    assert captured_embedding_model == "text-embedding-3-large"
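
Usage note (not part of the diff): the sketch below shows how the new embeddingModel option could be exercised from the JS package, assuming init and AnswerCorrectness are exported from the autoevals entry point the same way the test above uses them; the client setup and the "text-embedding-3-large" model name are examples only.

    import { init, AnswerCorrectness } from "autoevals";
    import OpenAI from "openai";

    async function main() {
      // Register the OpenAI client once; the ragas scorers reuse it for both
      // chat-completion and embedding requests. The client reads OPENAI_API_KEY
      // from the environment by default.
      init({ client: new OpenAI() });

      const result = await AnswerCorrectness({
        input: "What is the capital of France?",
        output: "Paris",
        expected: "Paris is the capital of France",
        // Forwarded to AnswerSimilarity and then to EmbeddingSimilarity as `model`,
        // so only the answer-similarity component switches embedding models.
        embeddingModel: "text-embedding-3-large",
      });

      console.log(result.score);
    }

    main();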