Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions src/unitxt/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@
import numpy as np
import pandas as pd
import requests
from scipy.stats import bootstrap
from scipy.stats._warnings_errors import DegenerateDataWarning

from .artifact import Artifact
from .base_metric import Metric
Expand Down Expand Up @@ -76,8 +74,6 @@
logger = get_logger()
settings = get_settings()

warnings.filterwarnings("ignore", category=DegenerateDataWarning)


@retry_connection_with_exponential_backoff(backoff_factor=2)
def hf_evaluate_load(path: str, *args, **kwargs):
Expand Down Expand Up @@ -221,6 +217,11 @@ def _sample_to_scores(self, sample: List[Any]) -> Dict[str, Any]:
pass

def bootstrap(self, data: List[Any], score_names: List[str]):
from scipy.stats import bootstrap
from scipy.stats._warnings_errors import DegenerateDataWarning

warnings.filterwarnings("ignore", category=DegenerateDataWarning)

if self.ci_score_names is not None:
score_names = self.ci_score_names

Expand Down Expand Up @@ -1349,6 +1350,11 @@ def score_based_confidence_interval(
Returns:
Dict of confidence interval values
"""
from scipy.stats import bootstrap
from scipy.stats._warnings_errors import DegenerateDataWarning

warnings.filterwarnings("ignore", category=DegenerateDataWarning)

result = {}

if not self._can_compute_confidence_intervals(num_predictions=len(instances)):
Expand Down Expand Up @@ -1433,6 +1439,11 @@ def compute_global_confidence_intervals(
self, references, predictions, task_data, score_name
):
"""Computed confidence intervals for a set of references and predictions."""
from scipy.stats import bootstrap
from scipy.stats._warnings_errors import DegenerateDataWarning

warnings.filterwarnings("ignore", category=DegenerateDataWarning)

random_gen = self.new_random_generator()

def statistic(arr, axis):
Expand Down