From 6bdfe1f71b59a53de2258ee264c43f5b1f3c121d Mon Sep 17 00:00:00 2001 From: Assaf Toledo Date: Tue, 6 Jan 2026 22:40:06 +0200 Subject: [PATCH] lazy import of scipy Signed-off-by: Assaf Toledo --- src/unitxt/metrics.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py index 0ce7a4d1b8..bf88fc2a52 100644 --- a/src/unitxt/metrics.py +++ b/src/unitxt/metrics.py @@ -32,8 +32,6 @@ import numpy as np import pandas as pd import requests -from scipy.stats import bootstrap -from scipy.stats._warnings_errors import DegenerateDataWarning from .artifact import Artifact from .base_metric import Metric @@ -76,8 +74,6 @@ logger = get_logger() settings = get_settings() -warnings.filterwarnings("ignore", category=DegenerateDataWarning) - @retry_connection_with_exponential_backoff(backoff_factor=2) def hf_evaluate_load(path: str, *args, **kwargs): @@ -221,6 +217,11 @@ def _sample_to_scores(self, sample: List[Any]) -> Dict[str, Any]: pass def bootstrap(self, data: List[Any], score_names: List[str]): + from scipy.stats import bootstrap + from scipy.stats._warnings_errors import DegenerateDataWarning + + warnings.filterwarnings("ignore", category=DegenerateDataWarning) + if self.ci_score_names is not None: score_names = self.ci_score_names @@ -1349,6 +1350,11 @@ def score_based_confidence_interval( Returns: Dict of confidence interval values """ + from scipy.stats import bootstrap + from scipy.stats._warnings_errors import DegenerateDataWarning + + warnings.filterwarnings("ignore", category=DegenerateDataWarning) + result = {} if not self._can_compute_confidence_intervals(num_predictions=len(instances)): @@ -1433,6 +1439,11 @@ def compute_global_confidence_intervals( self, references, predictions, task_data, score_name ): """Computed confidence intervals for a set of references and predictions.""" + from scipy.stats import bootstrap + from scipy.stats._warnings_errors import DegenerateDataWarning + + warnings.filterwarnings("ignore", category=DegenerateDataWarning) + random_gen = self.new_random_generator() def statistic(arr, axis):