diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index c31d612a3a..2b05ff4a6b 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -1461,7 +1461,18 @@ def _infer( options=args, ) results.append(response) - + if return_meta_data: + return [ + TextGenerationInferenceOutput( + prediction=element["message"]["content"], + generated_text=element["message"]["content"], + input_tokens=element.get("prompt_eval_count", 0), + output_tokens=element.get("eval_count", 0), + model_name=self.model, + inference_type=self.label, + ) + for element in results + ] return [element["message"]["content"] for element in results] diff --git a/tests/inference/test_inference_engine.py b/tests/inference/test_inference_engine.py index 48261b8f01..f70a3be1c0 100644 --- a/tests/inference/test_inference_engine.py +++ b/tests/inference/test_inference_engine.py @@ -159,7 +159,7 @@ def test_llava_inference_engine(self): def test_watsonx_inference(self): model = WMLInferenceEngineGeneration( - model_name="google/flan-t5-xl", + model_name="ibm/granite-3-8b-instruct", data_classification_policy=["public"], random_seed=111, min_new_tokens=1, @@ -193,7 +193,7 @@ def test_watsonx_inference_with_external_client(self): from ibm_watsonx_ai.client import APIClient, Credentials model = WMLInferenceEngineGeneration( - model_name="google/flan-t5-xl", + model_name="ibm/granite-3-8b-instruct", data_classification_policy=["public"], random_seed=111, min_new_tokens=1, @@ -279,7 +279,7 @@ def test_option_selecting_by_log_prob_inference_engines(self): ] watsonx_engine = WMLInferenceEngineGeneration( - model_name="meta-llama/llama-3-2-1b-instruct" + model_name="ibm/granite-3-8b-instruct" ) for engine in [watsonx_engine]: @@ -383,7 +383,7 @@ def test_lite_llm_inference_engine(self): def test_lite_llm_inference_engine_without_task_data_not_failing(self): LiteLLMInferenceEngine( - model="watsonx/meta-llama/llama-3-2-1b-instruct", + model="watsonx/meta-llama/llama-3-2-11b-vision-instruct", max_tokens=2, temperature=0, top_p=1, diff --git a/tests/library/test_metrics.py b/tests/library/test_metrics.py index 160b9d2543..097856d267 100644 --- a/tests/library/test_metrics.py +++ b/tests/library/test_metrics.py @@ -2708,7 +2708,7 @@ def test_perplexity(self): metric=perplexity_question, predictions=prediction, references=references ) self.assertAlmostEqual( - first_instance_target, outputs[0]["score"]["instance"]["score"] + first_instance_target, outputs[0]["score"]["instance"]["score"], places=5 ) def test_fuzzyner(self):