From 4d5d4962df221a71d420154fb6dc88d512961d30 Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Thu, 9 Oct 2025 12:54:51 +0000 Subject: [PATCH 01/42] feat: 2 vbench dimensions and vbench dependencies --- .pre-commit-config.yaml | 2 +- pyproject.toml | 2 + src/pruna/evaluation/metrics/__init__.py | 4 + .../metric_vbench_background_consistency.py | 140 +++++++ .../metrics/metric_vbench_dynamic_degree.py | 169 ++++++++ src/pruna/evaluation/metrics/vbench_utils.py | 365 ++++++++++++++++++ 6 files changed, 681 insertions(+), 1 deletion(-) create mode 100644 src/pruna/evaluation/metrics/metric_vbench_background_consistency.py create mode 100644 src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py create mode 100644 src/pruna/evaluation/metrics/vbench_utils.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3a4072ad..b899e7cb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -39,7 +39,7 @@ repos: grep -v "^D" | cut -f2- | while IFS= read -r file; do - if [ -f "$file" ] && ["$file" != ".pre-commit-config.yaml"] && grep -q "pruna_pro" "$file"; then + if [ -f "$file" ] && [ "$file" != ".pre-commit-config.yaml" ] && grep -q "pruna_pro" "$file"; then echo "Error: pruna_pro found in staged file $file" exit 1 fi diff --git a/pyproject.toml b/pyproject.toml index 3aca3631..88a880e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,7 @@ gptqmodel = [ { index = "pruna_internal", marker = "sys_platform != 'darwin' or platform_machine != 'arm64'"}, { index = "pypi", marker = "sys_platform == 'darwin' and platform_machine == 'arm64'"}, ] +clip = {git = "https://github.com/openai/CLIP.git", rev = "dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1"} [project] name = "pruna" @@ -186,6 +187,7 @@ dev = [ ] cpu = [] + [build-system] requires = ["hatchling"] build-backend = "hatchling.build" diff --git a/src/pruna/evaluation/metrics/__init__.py b/src/pruna/evaluation/metrics/__init__.py index 77ccef6a..f39de5c4 100644 --- a/src/pruna/evaluation/metrics/__init__.py +++ b/src/pruna/evaluation/metrics/__init__.py @@ -24,6 +24,8 @@ from pruna.evaluation.metrics.metric_pairwise_clip import PairwiseClipScore from pruna.evaluation.metrics.metric_sharpness import SharpnessMetric from pruna.evaluation.metrics.metric_torch import TorchMetricWrapper +from pruna.evaluation.metrics.metric_vbench_background_consistency import VBenchBackgroundConsistency +from pruna.evaluation.metrics.metric_vbench_dynamic_degree import VBenchDynamicDegree __all__ = [ "MetricRegistry", @@ -43,4 +45,6 @@ "DinoScore", "SharpnessMetric", "AestheticLAION", + "VBenchBackgroundConsistency", + "VBenchDynamicDegree", ] diff --git a/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py b/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py new file mode 100644 index 00000000..acf3df56 --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py @@ -0,0 +1,140 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any, List + +import clip +import torch +import torch.nn.functional as F # noqa: N812 +from vbench.utils import clip_transform, init_submodules + +from pruna.engine.utils import set_to_best_available_device +from pruna.evaluation.metrics.metric_stateful import StatefulMetric +from pruna.evaluation.metrics.registry import MetricRegistry +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.utils import PAIRWISE, SINGLE, get_call_type_for_single_metric, metric_data_processor +from pruna.evaluation.metrics.vbench_utils import VBenchMixin +from pruna.logging.logger import pruna_logger + +METRIC_VBENCH_BACKGROUND_CONSISTENCY = "background_consistency" + + +@MetricRegistry.register(METRIC_VBENCH_BACKGROUND_CONSISTENCY) +class VBenchBackgroundConsistency(StatefulMetric, VBenchMixin): + """ + Background Consistency metric for VBench. + + Parameters + ---------- + *args : Any + The arguments to pass to the metric. + device : str | None + The device to run the metric on. + call_type : str + The call type to use for the metric. + **kwargs : Any + The keyword arguments to pass to the metric. + """ + + metric_name: str = METRIC_VBENCH_BACKGROUND_CONSISTENCY + default_call_type: str = "y" # We just need the outputs + higher_is_better: bool = True + # https://github.com/Vchitect/VBench/blob/dc62783c0fb4fd333249c0b669027fe102696682/evaluate.py#L111 + # explicitly sets the device to cuda. We respect this here. + runs_on: List[str] = ["cuda"] + modality: List[str] = ["video"] + # state + similarity_scores: torch.Tensor + n_samples: torch.Tensor + + def __init__( + self, + *args: Any, + device: str | None = None, + call_type: str = SINGLE, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + + if device is not None and str(device).split(":")[0] not in self.runs_on: + pruna_logger.error(f"Unsupported device {device}; supported: {self.runs_on}") + raise ValueError() + + if call_type == PAIRWISE: + # VBench itself does not support pairwise. + # We can work on this in the future. + pruna_logger.error("VBench does not support pairwise metrics. Please use single mode.") + raise ValueError() + + submodules_dict = init_submodules([METRIC_VBENCH_BACKGROUND_CONSISTENCY]) + model_path = submodules_dict[METRIC_VBENCH_BACKGROUND_CONSISTENCY][0] + + self.device = set_to_best_available_device(device) + self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) + + self.clip_model, self.preprocessor = clip.load(model_path, device=self.device) + self.video_transform = clip_transform(224) + + self.add_state("similarity_scores", torch.tensor(0.0)) + self.add_state("n_samples", torch.tensor(0)) + + def update(self, x: List[str], gt: Any, outputs: Any) -> None: + """ + Update the similarity scores for the batch. + + Parameters + ---------- + x : List[Any] | torch.Tensor + The input data. + gt : torch.Tensor + The ground truth / cached images. + outputs : torch.Tensor + The output images. + """ + outputs = metric_data_processor(x, gt, outputs, self.call_type, device=self.device) + # Background consistency metric only supports a batch size of 1. + # To support larger batch sizes, we stack the outputs. + outputs = torch.stack([self.video_transform(output) for output in outputs[0]]) + features = torch.stack([self.clip_model.encode_image(output) for output in outputs]) + features = F.normalize(features, dim=-1, p=2) + + first_feature = features[0].unsqueeze(0) + + similarity_to_first = F.cosine_similarity(first_feature, features[1:]).clamp(min=0.0) + similarity_to_prev = F.cosine_similarity(features[:-1], features[1:]).clamp(min=0.0) + + similarities = (similarity_to_first + similarity_to_prev) / 2 + + # Update stats + self.similarity_scores += similarities.sum().item() + self.n_samples += similarities.numel() + + def compute(self) -> MetricResult: + """ + Aggregate the final score. + + Returns + ------- + MetricResult + The final score. + """ + score = self.similarity_scores / self.n_samples + return MetricResult(self.metric_name, self.__dict__, score) + + def reset(self) -> None: + """Reset the metric states.""" + self.similarity_scores = torch.tensor(0.0) + self.n_samples = torch.tensor(0) diff --git a/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py b/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py new file mode 100644 index 00000000..dd3f1218 --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py @@ -0,0 +1,169 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any, List + +import numpy as np +import torch +from easydict import EasyDict +from vbench.dynamic_degree import DynamicDegree +from vbench.third_party.RAFT.core.utils_core.utils import InputPadder +from vbench.utils import init_submodules + +from pruna.engine.utils import set_to_best_available_device +from pruna.evaluation.metrics.metric_stateful import StatefulMetric +from pruna.evaluation.metrics.registry import MetricRegistry +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.utils import PAIRWISE, SINGLE, get_call_type_for_single_metric, metric_data_processor +from pruna.evaluation.metrics.vbench_utils import VBenchMixin +from pruna.logging.logger import pruna_logger + +METRIC_VBENCH_DYNAMIC_DEGREE = "dynamic_degree" + + +class PrunaDynamicDegree(DynamicDegree): + """Helper class to compute Dynamic Degree score for a given video.""" + + def infer(self, frames: torch.Tensor) -> bool: + """ + Compute Dynamic Degree score for a given video. + + Uses the RAFT Model for given video frames. + + Parameters + ---------- + frames: torch.Tensor + The video frames to compute the Dynamic Degree score for. + + Returns + ------- + bool + Whether the video contains large motions. + """ + self.set_params(frame=frames[0], count=len(frames)) + static_score = [] + for image1, image2 in zip(frames[:-1], frames[1:]): + padder = InputPadder(image1.shape) + image1, image2 = padder.pad(image1, image2) + _, flow_up = self.model(image1, image2, iters=20, test_mode=True) + max_rad = self.get_score(image1, flow_up) + static_score.append(max_rad) + whether_move = self.check_move(static_score) + return whether_move + + +@MetricRegistry.register(METRIC_VBENCH_DYNAMIC_DEGREE) +class VBenchDynamicDegree(StatefulMetric, VBenchMixin): + """ + Dynamic Degree Dimension from the Vbench video benchmark suite. + + It measures the degree of dynamics (i.e. whether it contains large motions) generated by the model. + + This is important since a completely static video can score well + in temporal quality metrics but is not actually useful. + + Parameters + ---------- + device: str | None, optional + The device to be used, e.g., 'cuda' or 'cpu'. Default is None. + If None, the best available device will be used. + call_type: str, default="y" + The call type to be used, e.g., 'y' or 'y_gt'. Default is "y". + """ + + metric_name: str = METRIC_VBENCH_DYNAMIC_DEGREE + default_call_type: str = "y" # We just need the outputs + higher_is_better: bool = True + # https://github.com/Vchitect/VBench/blob/dc62783c0fb4fd333249c0b669027fe102696682/evaluate.py#L111 + # explicitly sets the device to cuda. We respect this here. + runs_on: List[str] = ["cuda"] + modality: List[str] = ["video"] + # state + scores: List[float] + + def __init__( + self, + *args: Any, + device: str | None = None, + call_type: str = SINGLE, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + + if device is not None and str(device).split(":")[0] not in self.runs_on: + pruna_logger.error(f"Unsupported device {device}; supported: {self.runs_on}") + raise ValueError() + + if call_type == PAIRWISE: + # VBench itself does not support pairwise. + # We can work on this in the future. + pruna_logger.error("VBench does not support pairwise metrics. Please use single mode.") + raise ValueError() + + submodules_dict = init_submodules([METRIC_VBENCH_DYNAMIC_DEGREE]) + model_path = submodules_dict[METRIC_VBENCH_DYNAMIC_DEGREE]["model"] + + self.device = set_to_best_available_device(device) + self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) + # RAFT models expect arguments to be passed as an object with attributes. + # So we need to convert the arguments to an EasyDict. + args_new = EasyDict({"model": model_path, "small": False, "mixed_precision": False, "alternate_corr": False}) + self.DynamicDegree = PrunaDynamicDegree(args_new, device) + self.add_state("scores", []) + + @torch.no_grad() + def update(self, x: List[str], gt: Any, outputs: Any) -> None: + """ + Calculate the dynamic degree score for the given video. + + The video is preprocessed to have approx. 8 frames per second. + + Then passed to the RAFT model to calculate the dynamic degree score. + + Each video is ranked as a 1 or 0 based on whether it contains large motions. + The final score is the mean of the scores for all videos. + + Parameters + ---------- + x: List[str] + The list of input videos. + gt: Any + The ground truth videos. + outputs: Any + The generated videos. + """ + outputs = metric_data_processor(x, gt, outputs, self.call_type, device=self.device) + score = self.DynamicDegree.infer(outputs) + self.scores.append(score) + + def compute(self) -> MetricResult: + """ + Calculate the final dynamic degree score. + + The final score is the mean of the scores for all videos. + + Returns + ------- + MetricResult + The dynamic degree score. + """ + final_score = np.mean(self.scores) + return MetricResult(name=self.metric_name, params=self.__dict__, result=final_score) + + def reset(self) -> None: + """Reset the state variables for the metric.""" + super().reset() + self.scores.clear() diff --git a/src/pruna/evaluation/metrics/vbench_utils.py b/src/pruna/evaluation/metrics/vbench_utils.py new file mode 100644 index 00000000..c64a9bba --- /dev/null +++ b/src/pruna/evaluation/metrics/vbench_utils.py @@ -0,0 +1,365 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import re +from pathlib import Path +from typing import Any, Callable, Iterable, List + +import numpy as np +import torch +from diffusers.utils import export_to_gif, export_to_video, load_video +from PIL.Image import Image +from torchvision.transforms import ToTensor +from vbench import VBench + +from pruna.data.pruna_datamodule import PrunaDataModule +from pruna.engine.utils import safe_memory_cleanup, set_to_best_available_device +from pruna.logging.logger import pruna_logger + + +class VBenchMixin: + """ + Mixin class for VBench metrics. + + Handles benchmark specific initilizations and artifact saving conventions. + + Parameters + ---------- + *args: Any + The arguments to pass to the metric. + **kwargs: Any + The keyword arguments to pass to the metric. + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + if VBench is None: + # Additional debug info to debug installation issues + pruna_logger.debug("Initialization failed: VBench is None") + raise ImportError("VBench is not installed. Please check your pruna installation.") + + def create_filename(self, prompt: str, idx: int, file_extension: str, special_str: str = "") -> str: + """ + Create filename according to VBench formatting conventions. + + Parameters + ---------- + prompt: str + The prompt to create the filename from. + idx: int + The index of the video. Vbench uses 5 seeds for each prompt. + file_extension: str + The file extension to use. Vbench supports mp4 and gif. + special_str: str + A special string to add to the filename if you wish to add a specific identifier. + + Returns + ------- + str + The filename. + """ + return create_vbench_file_name(sanitize_prompt(prompt), idx, special_str, file_extension) + + +def load_videos(path: str | Path, return_type: str = "pt") -> List[Image] | np.ndarray | torch.Tensor: + """ + Load videos from a path. + + Parameters: + ---------- + path: str | Path + The path to the videos. + return_type: str + The type to return the videos as. Can be "pt", "np", "pil". + + Returns: + ------- + List[torch.Tensor] + The videos. + """ + video = load_video(str(path)) + if return_type == "pt": + return torch.stack([ToTensor()(frame) for frame in video]) + elif return_type == "np": + return np.stack([np.array(frame) for frame in video]) + elif return_type == "pil": + return video + else: + raise ValueError(f"Invalid return_type: {return_type}. Use 'pt', 'np', or 'pil'.") + + +def sanitize_prompt(prompt: str) -> str: + """ + Return a filesystem-safe version of a prompt. + + Replaces characters illegal in filenames and collapses whitespace so that + generated files are portable across file systems. + + Parameters: + ---------- + prompt : str + The prompt to sanitize. + + Returns: + ------- + str + The sanitized prompt. + """ + prompt = re.sub(r"[\\/:*?\"<>|]", " ", prompt) # remove illegal chars + prompt = re.sub(r"\s+", " ", prompt) # collapse multiple spaces + prompt = prompt.strip() # remove leading and trailing whitespace + return prompt + + +def prepare_batch(batch: str | tuple[str | List[str], Any]) -> str: + """ + Prepare the batch to be used in the generate_videos function. + + Pruna datamodules are expected to yield tuples where the first element is + a sequence of inputs; this utility enforces batch_size == 1 for simplicity. + + + Parameters: + ---------- + batch: str | tuple[str | List[str], Any] + The batch to prepare. + + Returns: + ------- + str + The prompt string. + """ + if isinstance(batch, str): + return batch + # for pruna datamodule. always returns a tuple where the first element is the input to the model. + elif isinstance(batch, tuple): + if not hasattr(batch[0], "__len__"): + raise ValueError(f"Batch[0] is not a sequence (got {type(batch[0])})") + if len(batch[0]) != 1: + raise ValueError(f"Only batch size 1 is supported; got {len(batch[0])}") + return batch[0][0] + else: + raise ValueError(f"Invalid batch type: {type(batch)}") + + +def _normalize_save_format(save_format: str) -> tuple[str, Callable]: + """ + Normalize the save format to be used in the generate_videos function. + + Parameters: + ---------- + save_format : str + The format to save the videos in. VBench supports mp4 and gif. + + Returns: + ------- + tuple[str, Callable] + The normalized save format and the save function. + """ + save_format = save_format.lower().strip() + if save_format == "mp4": + return ".mp4", export_to_video + if save_format == "gif": + return ".gif", export_to_gif + raise ValueError(f"Invalid save_format: {save_format}. Use 'mp4' or 'gif'.") + + +def _normalize_prompts( + prompts: str | List[str] | PrunaDataModule, split: str = "test", batch_size: int = 1 +) -> Iterable[str]: + """ + Normalize prompts to an iterable format to be used in the generate_videos function. + + Parameters: + ---------- + prompts : str | List[str] | PrunaDataModule + The prompts to normalize. + + Returns: + ------- + Iterable[str] + The normalized prompts. + """ + if isinstance(prompts, str): + return [prompts] + elif isinstance(prompts, PrunaDataModule): + return getattr(prompts, f"{split}_dataloader")(batch_size=batch_size) + else: # list of prompts, already iterable + return prompts + + +def _ensure_dir(p: Path) -> None: + """ + Ensure the directory exists. + + Parameters: + ---------- + p : Path + The path to ensure the directory exists. + """ + p.mkdir(parents=True, exist_ok=True) + + +def create_vbench_file_name(prompt: str, idx: int, special_str: str = "", postfix: str = ".mp4") -> str: + """ + Create a file name for the video in accordance with the VBench format. + + Parameters: + ---------- + prompt: str + The prompt to create the file name from. + idx: int + The index of the video. Vbench uses 5 seeds for each prompt. + special_str: str + A special string to add to the file name if you wish to add a specific identifier. + postfix: str + The format of the video file. Vbench supports mp4 and gif. + + Returns: + ------- + str + The file name for the video. + """ + return f"{prompt}{special_str}-{str(idx)}{postfix}" + + +def sample_video_from_pipelines(pipeline: Any, seeder: Any, prompt: str, device: str | torch.device = None, **kwargs): + """ + Sample a video from diffusers pipeline. + + Parameters: + ---------- + pipeline: Any + The pipeline to sample from. + prompt: str + The prompt to sample from. + seeder: Any + The seeding generator. + **kwargs: Any + Additional keyword arguments to pass to the pipeline. + + Returns: + ------- + torch.Tensor + The video tensor. + """ + with torch.inference_mode(): + out = pipeline(prompt=prompt, generator=seeder, **kwargs).frames[0] + + return out + + +def _wrap_sampler(model: Any, sampling_fn: Callable[..., Any]) -> Callable[..., Any]: + """ + Wrap a user-provided sampling function into a uniform callable. + + The returned callable has a keyword-only signature: + sampler(*, prompt: str, seed: int, device: str|torch.device, **kwargs) + + This wrapper always passes `model` as the first positional argument, so + custom functions can name their first parameter `model` or `pipeline`, etc. + + Parameters: + ---------- + model: Any + The model to sample from. + sampling_fn: Callable[..., Any] + The sampling function to wrap. + + """ + if sampling_fn != sample_video_from_pipelines: + pruna_logger.info( + "Using custom sampling function. Ensure it accepts (model, *, prompt, seed, device, **kwargs)." + ) + + # The sampling function may expect the model as "pipeline" so we pass it as an arg and not a kwarg. + def sampler(*, prompt: str, seeder: Any, device: str | torch.device, **kwargs: Any) -> Any: + return sampling_fn(model, prompt=prompt, seeder=seeder, device=device, **kwargs) + + return sampler + + +def generate_videos( + model: Any, + prompts: str | List[str] | PrunaDataModule, + split: str = "test", + unique_sample_per_video_count: int = 5, + global_seed: int = 42, + sampling_fn: Callable[..., Any] = sample_video_from_pipelines, + fps: int = 16, + save_dir: str | Path = "./saved_videos", + save_format: str = "mp4", + special_str: str = "", + device: str | torch.device = None, + **model_kwargs, +) -> None: + """ + Generate N samples per prompt and save them to disk with seed tracking. + + This function: + 1) Normalizes prompts (string, list, or datamodule). + 2) Uses an RNG seeded with `global_seed` for reproducibility across runs. + 3) Saves videos as MP4 or GIF. + + Parameters: + ---------- + model : Any + The model to sample from. + prompts : str | List[str] | PrunaDataModule + The prompts to sample from. + split : str + The split to sample from. + Default is "test" since most benchmarking datamodules in Pruna are configured to use the test split. + unique_sample_per_video_count : int + The number of unique samples per video. Default is 5 by VBench requirements. + global_seed : int + The global seed to sample from. + sampling_fn : Callable[..., Any] + The sampling function to use. + fps : int + The frames per second of the video. + save_dir : str | Path + The directory to save the videos to. + save_format : str + The format to save the videos in. VBench supports mp4 and gif. + special_str : str + A special string to add to the file name if you wish to add a specific identifier. + **model_kwargs : Any + Additional keyword arguments to pass to the sampling function. + """ + file_extension, save_fn = _normalize_save_format(save_format) + + device = set_to_best_available_device(device) + + prompt_iterable = _normalize_prompts(prompts, split, batch_size=1) + + save_dir = Path(save_dir) + _ensure_dir(save_dir) + + # set a run-level seed (VBench suggests this) (important for reproducibility) + seed_rng = torch.Generator().manual_seed(global_seed) + sampler = _wrap_sampler(model=model, sampling_fn=sampling_fn) + + for batch in prompt_iterable: + prompt = prepare_batch(batch) + for idx in range(unique_sample_per_video_count): + file_name = create_vbench_file_name(sanitize_prompt(prompt), idx, special_str, file_extension) + out_path = save_dir / file_name + + vid = sampler(prompt=prompt, seeder=seed_rng, device=device, **model_kwargs) + save_fn(vid, out_path, fps=fps) + + del vid + safe_memory_cleanup() From b8d039283c31e667fcf605aae8d57563c3b0faad Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Fri, 10 Oct 2025 15:34:25 +0000 Subject: [PATCH 02/42] test: vbench metric tests --- .../metrics/metric_vbench_dynamic_degree.py | 34 ++++++++--- tests/evaluation/test_vbench_metrics.py | 58 +++++++++++++++++++ 2 files changed, 85 insertions(+), 7 deletions(-) create mode 100644 tests/evaluation/test_vbench_metrics.py diff --git a/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py b/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py index dd3f1218..c51be4f0 100644 --- a/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py +++ b/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py @@ -37,7 +37,7 @@ class PrunaDynamicDegree(DynamicDegree): """Helper class to compute Dynamic Degree score for a given video.""" - def infer(self, frames: torch.Tensor) -> bool: + def infer(self, frames: torch.Tensor, interval: int) -> bool: """ Compute Dynamic Degree score for a given video. @@ -53,7 +53,11 @@ def infer(self, frames: torch.Tensor) -> bool: bool Whether the video contains large motions. """ + frames = [fr.unsqueeze(0) for fr in frames] + + frames = self.extract_frame(frames, interval=max(1, interval)) self.set_params(frame=frames[0], count=len(frames)) + static_score = [] for image1, image2 in zip(frames[:-1], frames[1:]): padder = InputPadder(image1.shape) @@ -77,11 +81,22 @@ class VBenchDynamicDegree(StatefulMetric, VBenchMixin): Parameters ---------- + *args : Any + The arguments to be passed to the DynamicDegree class. device: str | None, optional The device to be used, e.g., 'cuda' or 'cpu'. Default is None. If None, the best available device will be used. call_type: str, default="y" The call type to be used, e.g., 'y' or 'y_gt'. Default is "y". + interval : int, default=3 + The interval to be used to extract frames from the video. + The default Vbench dimension loads videos from file and preprocesses them to have 8 frames per second. + For instance, if the video is 24fps, Vbench will only get every 3rd frame. + Here, we deal directly with the model outputs, so we initialize the interval to be 3, + which is a reasonable skip interval. + Feel free to change this to your needs. + **kwargs : Any + The keyword arguments to be passed to the DynamicDegree class. """ metric_name: str = METRIC_VBENCH_DYNAMIC_DEGREE @@ -99,6 +114,7 @@ def __init__( *args: Any, device: str | None = None, call_type: str = SINGLE, + interval: int = 3, **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) @@ -113,6 +129,7 @@ def __init__( pruna_logger.error("VBench does not support pairwise metrics. Please use single mode.") raise ValueError() + self.interval = interval submodules_dict = init_submodules([METRIC_VBENCH_DYNAMIC_DEGREE]) model_path = submodules_dict[METRIC_VBENCH_DYNAMIC_DEGREE]["model"] @@ -121,7 +138,7 @@ def __init__( # RAFT models expect arguments to be passed as an object with attributes. # So we need to convert the arguments to an EasyDict. args_new = EasyDict({"model": model_path, "small": False, "mixed_precision": False, "alternate_corr": False}) - self.DynamicDegree = PrunaDynamicDegree(args_new, device) + self.DynamicDegree = PrunaDynamicDegree(args_new, self.device) self.add_state("scores", []) @torch.no_grad() @@ -129,8 +146,6 @@ def update(self, x: List[str], gt: Any, outputs: Any) -> None: """ Calculate the dynamic degree score for the given video. - The video is preprocessed to have approx. 8 frames per second. - Then passed to the RAFT model to calculate the dynamic degree score. Each video is ranked as a 1 or 0 based on whether it contains large motions. @@ -143,11 +158,16 @@ def update(self, x: List[str], gt: Any, outputs: Any) -> None: gt: Any The ground truth videos. outputs: Any - The generated videos. + The generated videos. Should be a tensor of shape (B, T, C, H, W). + where B is the batch size, T is the number of frames, C is the number of channels, H is the height, + and W is the width. """ outputs = metric_data_processor(x, gt, outputs, self.call_type, device=self.device) - score = self.DynamicDegree.infer(outputs) - self.scores.append(score) + videos = outputs[0] + + for video in videos: + score = self.DynamicDegree.infer(video, self.interval) + self.scores.append(score) def compute(self) -> MetricResult: """ diff --git a/tests/evaluation/test_vbench_metrics.py b/tests/evaluation/test_vbench_metrics.py new file mode 100644 index 00000000..8b3874ed --- /dev/null +++ b/tests/evaluation/test_vbench_metrics.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import pytest +import torch +import numpy as np + + +from pruna.evaluation.metrics.metric_vbench_background_consistency import VBenchBackgroundConsistency +from pruna.evaluation.metrics.metric_vbench_dynamic_degree import VBenchDynamicDegree + + + +def test_metric_background_consistency(): + """Test metric background consistency.""" + # let's create a batch of 2 random RGB videos with 2 frames each that's 16 x 16 pixels. + random_input_video_batch = torch.randn(2, 2, 3, 16,16) + # let's create a batch of 2 all black RGB videos with 2 frames each that's 16 x 16 pixels. + all_black_input_video_batch = torch.zeros(2, 2, 3, 16,16) + metric = VBenchBackgroundConsistency() + metric.update(random_input_video_batch, random_input_video_batch, random_input_video_batch) + random_result = metric.compute() + metric.reset() + metric.update(all_black_input_video_batch, all_black_input_video_batch, all_black_input_video_batch) + all_black_result = metric.compute() + metric.reset() + # Background consistency checks for the cosine similarity between frames. + # Therefore we would expect a completely blacked out video (even though meaningless) to have a much higher score + # than a completely random set of frames. + assert all_black_result.result >= random_result.result + # Since the all black video is completely black, the cosine similarity between frames should be 1.0 + assert np.isclose(all_black_result.result, 1.0) + +@pytest.mark.cuda +@pytest.mark.parametrize("model_fixture", ["wan_tiny_random"], indirect=["model_fixture"]) +def test_metrtic_dynamic_degree_cuda(model_fixture): + """Test metric dynamic degree.""" + model, smash_config = model_fixture + model.to("cuda") + model.to(torch.float32) + # this is a prompt from the vbench dataset under dynamic degree dimension. + output_video = model("a dog running happily", num_inference_steps=10, output_type="pt").frames[0].unsqueeze(0) + + metric = VBenchDynamicDegree(interval=4) + metric.update(output_video, output_video, output_video) + result = metric.compute() + # a video of a dog running ideally should have a dynamic degree of 1.0 (since it contains large movements) + assert result.result == 1.0 + + +def test_metrtic_dynamic_degree_cpu(): + """Test metric dynamic degree on CPU.""" + # Testing for a lack of movement is much easier than testing for movement. + # We can easily do this on CPU, thus also checking the metric on CPU. + video = torch.zeros(1,4,3,64,64) # a completely black video doesn't contain any movements + metric = VBenchDynamicDegree(interval=1) + metric.update(video, video, video) + result = metric.compute() + assert result.result == 0.0 From d350a70d09b460e0d180c55dc42247642faad2b8 Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Fri, 10 Oct 2025 15:53:34 +0000 Subject: [PATCH 03/42] docs: add more comprehensive docstring explanations for important parameters --- .../metrics/metric_vbench_dynamic_degree.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py b/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py index c51be4f0..d681a6bf 100644 --- a/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py +++ b/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py @@ -47,6 +47,10 @@ def infer(self, frames: torch.Tensor, interval: int) -> bool: ---------- frames: torch.Tensor The video frames to compute the Dynamic Degree score for. + interval: int + The interval to skip frames. It's possible for each consecutive frame to not have extreme motion, + even though the video itself contains large dynamic changes. + Therefore it's important to set the inteval to skip frames correctly. Returns ------- @@ -83,10 +87,10 @@ class VBenchDynamicDegree(StatefulMetric, VBenchMixin): ---------- *args : Any The arguments to be passed to the DynamicDegree class. - device: str | None, optional + device : str | None, optional The device to be used, e.g., 'cuda' or 'cpu'. Default is None. If None, the best available device will be used. - call_type: str, default="y" + call_type : str, default="y" The call type to be used, e.g., 'y' or 'y_gt'. Default is "y". interval : int, default=3 The interval to be used to extract frames from the video. @@ -153,11 +157,11 @@ def update(self, x: List[str], gt: Any, outputs: Any) -> None: Parameters ---------- - x: List[str] + x : List[str] The list of input videos. - gt: Any + gt : Any The ground truth videos. - outputs: Any + outputs : Any The generated videos. Should be a tensor of shape (B, T, C, H, W). where B is the batch size, T is the number of frames, C is the number of channels, H is the height, and W is the width. From d16c583799a5323bc299a3fd013cc50932d9e084 Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Mon, 13 Oct 2025 16:01:02 +0000 Subject: [PATCH 04/42] feat: add additional helper tools to utilities --- .../metric_vbench_background_consistency.py | 15 ++-- .../metrics/metric_vbench_dynamic_degree.py | 4 +- src/pruna/evaluation/metrics/vbench_utils.py | 74 ++++++++++++++++++- 3 files changed, 83 insertions(+), 10 deletions(-) diff --git a/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py b/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py index acf3df56..93c7f600 100644 --- a/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py +++ b/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py @@ -19,6 +19,7 @@ import clip import torch import torch.nn.functional as F # noqa: N812 +from torchvision.transforms.functional import convert_image_dtype from vbench.utils import clip_transform, init_submodules from pruna.engine.utils import set_to_best_available_device @@ -107,14 +108,18 @@ def update(self, x: List[str], gt: Any, outputs: Any) -> None: outputs = metric_data_processor(x, gt, outputs, self.call_type, device=self.device) # Background consistency metric only supports a batch size of 1. # To support larger batch sizes, we stack the outputs. - outputs = torch.stack([self.video_transform(output) for output in outputs[0]]) + outputs = super().validate_batch(outputs[0]) + # This metric depends on the outputs being uint8. + outputs = torch.stack([convert_image_dtype(output, dtype=torch.uint8) for output in outputs]) + outputs = torch.stack([self.video_transform(output) for output in outputs]) features = torch.stack([self.clip_model.encode_image(output) for output in outputs]) - features = F.normalize(features, dim=-1, p=2) + features = torch.stack([F.normalize(feature, dim=-1, p=2) for feature in features]) - first_feature = features[0].unsqueeze(0) + # We vectorize the calculation to avoid for loops. + first_feature = features[:, 0, ...].unsqueeze(1).repeat(1, features.shape[1] - 1, 1) - similarity_to_first = F.cosine_similarity(first_feature, features[1:]).clamp(min=0.0) - similarity_to_prev = F.cosine_similarity(features[:-1], features[1:]).clamp(min=0.0) + similarity_to_first = F.cosine_similarity(first_feature, features[:, 1:, ...], dim=-1).clamp(min=0.0) + similarity_to_prev = F.cosine_similarity(features[:, :-1, ...], features[:, 1:, ...], dim=-1).clamp(min=0.0) similarities = (similarity_to_first + similarity_to_prev) / 2 diff --git a/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py b/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py index d681a6bf..afcdeb63 100644 --- a/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py +++ b/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py @@ -162,12 +162,12 @@ def update(self, x: List[str], gt: Any, outputs: Any) -> None: gt : Any The ground truth videos. outputs : Any - The generated videos. Should be a tensor of shape (B, T, C, H, W). + The generated videos. Should be a tensor of shape (T, C, H, W) or (B, T, C, H, W). where B is the batch size, T is the number of frames, C is the number of channels, H is the height, and W is the width. """ outputs = metric_data_processor(x, gt, outputs, self.call_type, device=self.device) - videos = outputs[0] + videos = super().validate_batch(outputs[0]) for video in videos: score = self.DynamicDegree.infer(video, self.interval) diff --git a/src/pruna/evaluation/metrics/vbench_utils.py b/src/pruna/evaluation/metrics/vbench_utils.py index c64a9bba..13f6041e 100644 --- a/src/pruna/evaluation/metrics/vbench_utils.py +++ b/src/pruna/evaluation/metrics/vbench_utils.py @@ -20,13 +20,16 @@ import numpy as np import torch -from diffusers.utils import export_to_gif, export_to_video, load_video +from diffusers.utils import export_to_gif, export_to_video +from diffusers.utils import load_video as diffusers_load_video from PIL.Image import Image from torchvision.transforms import ToTensor from vbench import VBench from pruna.data.pruna_datamodule import PrunaDataModule from pruna.engine.utils import safe_memory_cleanup, set_to_best_available_device +from pruna.evaluation.metrics.metric_stateful import StatefulMetric +from pruna.evaluation.metrics.result import MetricResult from pruna.logging.logger import pruna_logger @@ -72,8 +75,28 @@ def create_filename(self, prompt: str, idx: int, file_extension: str, special_st """ return create_vbench_file_name(sanitize_prompt(prompt), idx, special_str, file_extension) + def validate_batch(self, batch: torch.Tensor) -> torch.Tensor: + """ + Make sure that the video tensor has correct dimensions. + + Parameters: + ---------- + batch: torch.Tensor + The video tensor. + + Returns: + ------- + torch.Tensor + The video tensor. + """ + if batch.ndim == 4: + return batch.unsqueeze(0) + elif batch.ndim != 5: + raise ValueError(f"Batch must be 4 or 5 dimensional video tensor with B,T,C,H,W, got {batch.ndim}") + return batch + -def load_videos(path: str | Path, return_type: str = "pt") -> List[Image] | np.ndarray | torch.Tensor: +def load_video(path: str | Path, return_type: str = "pt") -> List[Image] | np.ndarray | torch.Tensor: """ Load videos from a path. @@ -89,7 +112,7 @@ def load_videos(path: str | Path, return_type: str = "pt") -> List[Image] | np.n List[torch.Tensor] The videos. """ - video = load_video(str(path)) + video = diffusers_load_video(str(path)) if return_type == "pt": return torch.stack([ToTensor()(frame) for frame in video]) elif return_type == "np": @@ -100,6 +123,25 @@ def load_videos(path: str | Path, return_type: str = "pt") -> List[Image] | np.n raise ValueError(f"Invalid return_type: {return_type}. Use 'pt', 'np', or 'pil'.") +def load_videos_from_path(path: str | Path) -> List[List[Image] | np.ndarray | torch.Tensor]: + """ + Load entire directory of videos. + + Parameters: + ---------- + path : str | Path + The path to the directory of videos. + + Returns: + ------- + List[List[Image] | np.ndarray | torch.Tensor] + The videos. + """ + path = Path(str(path)) + videos = torch.stack([load_video(p) for p in path.glob("*.mp4")]) + return videos + + def sanitize_prompt(prompt: str) -> str: """ Return a filesystem-safe version of a prompt. @@ -363,3 +405,29 @@ def generate_videos( del vid safe_memory_cleanup() + + +def evaluate_videos(data: Any, metrics: StatefulMetric | List[StatefulMetric]) -> List[MetricResult]: + """ + Evaluation loop helper. + + Parameters: + ---------- + data : Any + The data to evaluate. + metrics : StatefulMetric | List[StatefulMetric] + The metrics to evaluate. + + Returns: + ------- + List[MetricResult] + The results of the evaluation. + """ + results = [] + if isinstance(metrics, StatefulMetric): + metrics = [metrics] + for metric in metrics: + for batch in data: + metric.update(batch, batch, batch) + results.append(metric.compute()) + return results From 82342cb7e866f7f29a640ea79e60ed1068cbdad3 Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Tue, 14 Oct 2025 09:44:59 +0000 Subject: [PATCH 05/42] refactor: small updates to utilities and docstrings --- .../metric_vbench_background_consistency.py | 3 +++ src/pruna/evaluation/metrics/vbench_utils.py | 25 +++++++------------ tests/evaluation/test_vbench_metrics.py | 14 +++++------ 3 files changed, 19 insertions(+), 23 deletions(-) diff --git a/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py b/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py index 93c7f600..3c5f7c9e 100644 --- a/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py +++ b/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py @@ -136,10 +136,13 @@ def compute(self) -> MetricResult: MetricResult The final score. """ + if self.n_samples == 0: + return MetricResult(self.metric_name, self.__dict__, 0.0) score = self.similarity_scores / self.n_samples return MetricResult(self.metric_name, self.__dict__, score) def reset(self) -> None: """Reset the metric states.""" + super().reset() self.similarity_scores = torch.tensor(0.0) self.n_samples = torch.tensor(0) diff --git a/src/pruna/evaluation/metrics/vbench_utils.py b/src/pruna/evaluation/metrics/vbench_utils.py index 13f6041e..2915baa8 100644 --- a/src/pruna/evaluation/metrics/vbench_utils.py +++ b/src/pruna/evaluation/metrics/vbench_utils.py @@ -24,7 +24,6 @@ from diffusers.utils import load_video as diffusers_load_video from PIL.Image import Image from torchvision.transforms import ToTensor -from vbench import VBench from pruna.data.pruna_datamodule import PrunaDataModule from pruna.engine.utils import safe_memory_cleanup, set_to_best_available_device @@ -47,12 +46,6 @@ class VBenchMixin: The keyword arguments to pass to the metric. """ - def __init__(self, *args: Any, **kwargs: Any) -> None: - if VBench is None: - # Additional debug info to debug installation issues - pruna_logger.debug("Initialization failed: VBench is None") - raise ImportError("VBench is not installed. Please check your pruna installation.") - def create_filename(self, prompt: str, idx: int, file_extension: str, special_str: str = "") -> str: """ Create filename according to VBench formatting conventions. @@ -123,9 +116,9 @@ def load_video(path: str | Path, return_type: str = "pt") -> List[Image] | np.nd raise ValueError(f"Invalid return_type: {return_type}. Use 'pt', 'np', or 'pil'.") -def load_videos_from_path(path: str | Path) -> List[List[Image] | np.ndarray | torch.Tensor]: +def load_videos_from_path(path: str | Path) -> torch.Tensor: """ - Load entire directory of videos. + Load entire directory of mp4 videos as a single tensor ready to be passed to evaluation. Parameters: ---------- @@ -185,7 +178,7 @@ def prepare_batch(batch: str | tuple[str | List[str], Any]) -> str: """ if isinstance(batch, str): return batch - # for pruna datamodule. always returns a tuple where the first element is the input to the model. + # for pruna datamodule. always returns a tuple where the first element is the input (list of prompts) to the model. elif isinstance(batch, tuple): if not hasattr(batch[0], "__len__"): raise ValueError(f"Batch[0] is not a sequence (got {type(batch[0])})") @@ -308,7 +301,7 @@ def _wrap_sampler(model: Any, sampling_fn: Callable[..., Any]) -> Callable[..., Wrap a user-provided sampling function into a uniform callable. The returned callable has a keyword-only signature: - sampler(*, prompt: str, seed: int, device: str|torch.device, **kwargs) + sampler(*, prompt: str, seeder: Any, device: str|torch.device, **kwargs) This wrapper always passes `model` as the first positional argument, so custom functions can name their first parameter `model` or `pipeline`, etc. @@ -323,7 +316,7 @@ def _wrap_sampler(model: Any, sampling_fn: Callable[..., Any]) -> Callable[..., """ if sampling_fn != sample_video_from_pipelines: pruna_logger.info( - "Using custom sampling function. Ensure it accepts (model, *, prompt, seed, device, **kwargs)." + "Using custom sampling function. Ensure it accepts (model, *, prompt, seeder, device, **kwargs)." ) # The sampling function may expect the model as "pipeline" so we pass it as an arg and not a kwarg. @@ -337,7 +330,7 @@ def generate_videos( model: Any, prompts: str | List[str] | PrunaDataModule, split: str = "test", - unique_sample_per_video_count: int = 5, + unique_sample_per_video_count: int = 1, global_seed: int = 42, sampling_fn: Callable[..., Any] = sample_video_from_pipelines, fps: int = 16, @@ -351,9 +344,9 @@ def generate_videos( Generate N samples per prompt and save them to disk with seed tracking. This function: - 1) Normalizes prompts (string, list, or datamodule). - 2) Uses an RNG seeded with `global_seed` for reproducibility across runs. - 3) Saves videos as MP4 or GIF. + - Normalizes prompts (string, list, or datamodule). + - Uses an RNG seeded with `global_seed` for reproducibility across runs. + - Saves videos as MP4 or GIF. Parameters: ---------- diff --git a/tests/evaluation/test_vbench_metrics.py b/tests/evaluation/test_vbench_metrics.py index 8b3874ed..67a7d816 100644 --- a/tests/evaluation/test_vbench_metrics.py +++ b/tests/evaluation/test_vbench_metrics.py @@ -9,7 +9,7 @@ from pruna.evaluation.metrics.metric_vbench_dynamic_degree import VBenchDynamicDegree - +@pytest.mark.cuda def test_metric_background_consistency(): """Test metric background consistency.""" # let's create a batch of 2 random RGB videos with 2 frames each that's 16 x 16 pixels. @@ -32,8 +32,8 @@ def test_metric_background_consistency(): @pytest.mark.cuda @pytest.mark.parametrize("model_fixture", ["wan_tiny_random"], indirect=["model_fixture"]) -def test_metrtic_dynamic_degree_cuda(model_fixture): - """Test metric dynamic degree.""" +def test_metrtic_dynamic_degree_dynamic(model_fixture): + """Test metric dynamic degree with an example sample from the vbench dataset that returns a dynamic video.""" model, smash_config = model_fixture model.to("cuda") model.to(torch.float32) @@ -46,11 +46,11 @@ def test_metrtic_dynamic_degree_cuda(model_fixture): # a video of a dog running ideally should have a dynamic degree of 1.0 (since it contains large movements) assert result.result == 1.0 - -def test_metrtic_dynamic_degree_cpu(): - """Test metric dynamic degree on CPU.""" +@pytest.mark.cuda +def test_metric_dynamic_degree_static(): + """Test metric dynamic degree fail case.""" # Testing for a lack of movement is much easier than testing for movement. - # We can easily do this on CPU, thus also checking the metric on CPU. + # We create a completely black video to test the metric. video = torch.zeros(1,4,3,64,64) # a completely black video doesn't contain any movements metric = VBenchDynamicDegree(interval=1) metric.update(video, video, video) From 102e91f4da36804e7387fa0c0fb6e5208ef7008f Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Wed, 15 Oct 2025 09:08:37 +0000 Subject: [PATCH 06/42] refactor: add support for more calltypes in video eval utils --- src/pruna/evaluation/metrics/vbench_utils.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/pruna/evaluation/metrics/vbench_utils.py b/src/pruna/evaluation/metrics/vbench_utils.py index 2915baa8..1c0c4468 100644 --- a/src/pruna/evaluation/metrics/vbench_utils.py +++ b/src/pruna/evaluation/metrics/vbench_utils.py @@ -400,7 +400,9 @@ def generate_videos( safe_memory_cleanup() -def evaluate_videos(data: Any, metrics: StatefulMetric | List[StatefulMetric]) -> List[MetricResult]: +def evaluate_videos( + data: Any, metrics: StatefulMetric | List[StatefulMetric], prompts: Any | None = None +) -> List[MetricResult]: """ Evaluation loop helper. @@ -419,8 +421,14 @@ def evaluate_videos(data: Any, metrics: StatefulMetric | List[StatefulMetric]) - results = [] if isinstance(metrics, StatefulMetric): metrics = [metrics] + if any(metric.call_type != "y" for metric in metrics) and prompts is None: + raise ValueError( + "You are trying to evaluate metrics that require more than the outputs,but didn't provide prompts." + ) for metric in metrics: for batch in data: - metric.update(batch, batch, batch) + if prompts is None: + prompts = batch + metric.update(prompts, batch, batch) results.append(metric.compute()) return results From 2f579fd12ba0074eeb4e5e6e82e1985067e6fc86 Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Tue, 28 Oct 2025 15:50:42 +0000 Subject: [PATCH 07/42] refactor: make utilities more vbench independent and fix small things in the metric implementations. --- .../metric_vbench_background_consistency.py | 11 ++--- .../metrics/metric_vbench_dynamic_degree.py | 3 +- src/pruna/evaluation/metrics/vbench_utils.py | 41 +++++++++++++++---- 3 files changed, 41 insertions(+), 14 deletions(-) diff --git a/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py b/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py index 3c5f7c9e..bea0e8dd 100644 --- a/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py +++ b/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py @@ -55,7 +55,7 @@ class VBenchBackgroundConsistency(StatefulMetric, VBenchMixin): higher_is_better: bool = True # https://github.com/Vchitect/VBench/blob/dc62783c0fb4fd333249c0b669027fe102696682/evaluate.py#L111 # explicitly sets the device to cuda. We respect this here. - runs_on: List[str] = ["cuda"] + runs_on: List[str] = ["cuda, cpu"] modality: List[str] = ["video"] # state similarity_scores: torch.Tensor @@ -87,9 +87,10 @@ def __init__( self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) self.clip_model, self.preprocessor = clip.load(model_path, device=self.device) + # Cropping for the CLIP encoder. self.video_transform = clip_transform(224) - self.add_state("similarity_scores", torch.tensor(0.0)) + self.add_state("similarity_scores_cumsum", torch.tensor(0.0)) self.add_state("n_samples", torch.tensor(0)) def update(self, x: List[str], gt: Any, outputs: Any) -> None: @@ -124,7 +125,7 @@ def update(self, x: List[str], gt: Any, outputs: Any) -> None: similarities = (similarity_to_first + similarity_to_prev) / 2 # Update stats - self.similarity_scores += similarities.sum().item() + self.similarity_scores_cumsum += similarities.sum().item() self.n_samples += similarities.numel() def compute(self) -> MetricResult: @@ -138,11 +139,11 @@ def compute(self) -> MetricResult: """ if self.n_samples == 0: return MetricResult(self.metric_name, self.__dict__, 0.0) - score = self.similarity_scores / self.n_samples + score = self.similarity_scores_cumsum / self.n_samples return MetricResult(self.metric_name, self.__dict__, score) def reset(self) -> None: """Reset the metric states.""" super().reset() - self.similarity_scores = torch.tensor(0.0) + self.similarity_scores_cumsum = torch.tensor(0.0) self.n_samples = torch.tensor(0) diff --git a/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py b/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py index afcdeb63..4156b10a 100644 --- a/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py +++ b/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py @@ -66,6 +66,7 @@ def infer(self, frames: torch.Tensor, interval: int) -> bool: for image1, image2 in zip(frames[:-1], frames[1:]): padder = InputPadder(image1.shape) image1, image2 = padder.pad(image1, image2) + # 20 iterations as the original DynamicDegree implementation. _, flow_up = self.model(image1, image2, iters=20, test_mode=True) max_rad = self.get_score(image1, flow_up) static_score.append(max_rad) @@ -111,7 +112,7 @@ class VBenchDynamicDegree(StatefulMetric, VBenchMixin): runs_on: List[str] = ["cuda"] modality: List[str] = ["video"] # state - scores: List[float] + scores: List[bool] def __init__( self, diff --git a/src/pruna/evaluation/metrics/vbench_utils.py b/src/pruna/evaluation/metrics/vbench_utils.py index 1c0c4468..26a809f8 100644 --- a/src/pruna/evaluation/metrics/vbench_utils.py +++ b/src/pruna/evaluation/metrics/vbench_utils.py @@ -127,7 +127,7 @@ def load_videos_from_path(path: str | Path) -> torch.Tensor: Returns: ------- - List[List[Image] | np.ndarray | torch.Tensor] + torch.Tensor The videos. """ path = Path(str(path)) @@ -221,6 +221,8 @@ def _normalize_prompts( ---------- prompts : str | List[str] | PrunaDataModule The prompts to normalize. + split : str + The dataset split to sample from. Returns: ------- @@ -247,7 +249,9 @@ def _ensure_dir(p: Path) -> None: p.mkdir(parents=True, exist_ok=True) -def create_vbench_file_name(prompt: str, idx: int, special_str: str = "", postfix: str = ".mp4") -> str: +def create_vbench_file_name( + prompt: str, idx: int, special_str: str = "", save_format: str = ".mp4", max_filename_length: int = 255 +) -> str: """ Create a file name for the video in accordance with the VBench format. @@ -259,18 +263,26 @@ def create_vbench_file_name(prompt: str, idx: int, special_str: str = "", postfi The index of the video. Vbench uses 5 seeds for each prompt. special_str: str A special string to add to the file name if you wish to add a specific identifier. - postfix: str + save_format: str The format of the video file. Vbench supports mp4 and gif. + max_filename_length: int + The maximum length allowed for the file name. Returns: ------- str The file name for the video. """ - return f"{prompt}{special_str}-{str(idx)}{postfix}" + filename = f"{prompt}{special_str}-{str(idx)}{save_format}" + if len(filename) > max_filename_length: + pruna_logger.debug( + f"File name {filename} is too long. Maximum length is {max_filename_length} characters. Truncating filename." + ) + filename = filename[:max_filename_length] + return filename -def sample_video_from_pipelines(pipeline: Any, seeder: Any, prompt: str, device: str | torch.device = None, **kwargs): +def sample_video_from_pipelines(pipeline: Any, seeder: Any, prompt: str, **kwargs): """ Sample a video from diffusers pipeline. @@ -290,8 +302,13 @@ def sample_video_from_pipelines(pipeline: Any, seeder: Any, prompt: str, device: torch.Tensor The video tensor. """ + is_return_dict = kwargs.pop("return_dict", True) with torch.inference_mode(): - out = pipeline(prompt=prompt, generator=seeder, **kwargs).frames[0] + if is_return_dict: + out = pipeline(prompt=prompt, generator=seeder, **kwargs).frames[0] + else: + # If return_dict is False, the pipeline returns a tuple of (frames, metadata). + out = pipeline(prompt=prompt, generator=seeder, **kwargs)[0] return out @@ -336,6 +353,7 @@ def generate_videos( fps: int = 16, save_dir: str | Path = "./saved_videos", save_format: str = "mp4", + filename_fn: Callable = create_vbench_file_name, special_str: str = "", device: str | torch.device = None, **model_kwargs, @@ -369,8 +387,12 @@ def generate_videos( The directory to save the videos to. save_format : str The format to save the videos in. VBench supports mp4 and gif. + filename_fn: Callable + The function to create the file name. special_str : str A special string to add to the file name if you wish to add a specific identifier. + device : str | torch.device | None + The device to sample on. If None, the best available device will be used. **model_kwargs : Any Additional keyword arguments to pass to the sampling function. """ @@ -390,7 +412,7 @@ def generate_videos( for batch in prompt_iterable: prompt = prepare_batch(batch) for idx in range(unique_sample_per_video_count): - file_name = create_vbench_file_name(sanitize_prompt(prompt), idx, special_str, file_extension) + file_name = filename_fn(sanitize_prompt(prompt), idx, special_str, file_extension) out_path = save_dir / file_name vid = sampler(prompt=prompt, seeder=seed_rng, device=device, **model_kwargs) @@ -412,6 +434,8 @@ def evaluate_videos( The data to evaluate. metrics : StatefulMetric | List[StatefulMetric] The metrics to evaluate. + prompts : Any | None + The prompts to evaluate. Returns: ------- @@ -423,12 +447,13 @@ def evaluate_videos( metrics = [metrics] if any(metric.call_type != "y" for metric in metrics) and prompts is None: raise ValueError( - "You are trying to evaluate metrics that require more than the outputs,but didn't provide prompts." + "You are trying to evaluate metrics that require more than the outputs, but didn't provide prompts." ) for metric in metrics: for batch in data: if prompts is None: prompts = batch metric.update(prompts, batch, batch) + prompts = None results.append(metric.compute()) return results From 9abf89873b41f5991ff540605a476dc165c5f35b Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Mon, 3 Nov 2025 13:46:10 +0000 Subject: [PATCH 08/42] refactor: address PR comments --- .../metrics/metric_vbench_background_consistency.py | 6 +++--- .../evaluation/metrics/metric_vbench_dynamic_degree.py | 6 +++--- tests/evaluation/test_vbench_metrics.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py b/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py index bea0e8dd..127ff159 100644 --- a/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py +++ b/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py @@ -22,7 +22,7 @@ from torchvision.transforms.functional import convert_image_dtype from vbench.utils import clip_transform, init_submodules -from pruna.engine.utils import set_to_best_available_device +from pruna.engine.utils import get_device_type, set_to_best_available_device from pruna.evaluation.metrics.metric_stateful import StatefulMetric from pruna.evaluation.metrics.registry import MetricRegistry from pruna.evaluation.metrics.result import MetricResult @@ -70,8 +70,8 @@ def __init__( ) -> None: super().__init__(*args, **kwargs) - if device is not None and str(device).split(":")[0] not in self.runs_on: - pruna_logger.error(f"Unsupported device {device}; supported: {self.runs_on}") + if device is not None and get_device_type(device) not in self.runs_on: + pruna_logger.error(f"Unsupported device {get_device_type(device)}; supported: {self.runs_on}") raise ValueError() if call_type == PAIRWISE: diff --git a/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py b/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py index 4156b10a..46d06a72 100644 --- a/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py +++ b/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py @@ -23,7 +23,7 @@ from vbench.third_party.RAFT.core.utils_core.utils import InputPadder from vbench.utils import init_submodules -from pruna.engine.utils import set_to_best_available_device +from pruna.engine.utils import get_device_type, set_to_best_available_device from pruna.evaluation.metrics.metric_stateful import StatefulMetric from pruna.evaluation.metrics.registry import MetricRegistry from pruna.evaluation.metrics.result import MetricResult @@ -124,8 +124,8 @@ def __init__( ) -> None: super().__init__(*args, **kwargs) - if device is not None and str(device).split(":")[0] not in self.runs_on: - pruna_logger.error(f"Unsupported device {device}; supported: {self.runs_on}") + if device is not None and get_device_type(device) not in self.runs_on: + pruna_logger.error(f"Unsupported device {get_device_type(device)}; supported: {self.runs_on}") raise ValueError() if call_type == PAIRWISE: diff --git a/tests/evaluation/test_vbench_metrics.py b/tests/evaluation/test_vbench_metrics.py index 67a7d816..5890dd9f 100644 --- a/tests/evaluation/test_vbench_metrics.py +++ b/tests/evaluation/test_vbench_metrics.py @@ -32,7 +32,7 @@ def test_metric_background_consistency(): @pytest.mark.cuda @pytest.mark.parametrize("model_fixture", ["wan_tiny_random"], indirect=["model_fixture"]) -def test_metrtic_dynamic_degree_dynamic(model_fixture): +def test_metric_dynamic_degree_dynamic(model_fixture): """Test metric dynamic degree with an example sample from the vbench dataset that returns a dynamic video.""" model, smash_config = model_fixture model.to("cuda") From e4bc717bcf274981872f65b60e24c2ea4a164052 Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Mon, 10 Nov 2025 13:09:51 +0000 Subject: [PATCH 09/42] test: adding more tests for dynamic degree and background consistency --- pyproject.toml | 4 +- .../metric_vbench_background_consistency.py | 15 +-- .../metrics/metric_vbench_dynamic_degree.py | 30 ++--- src/pruna/evaluation/metrics/vbench_utils.py | 100 ++++++++-------- tests/evaluation/test_vbench_metrics.py | 108 +++++++++++++++++- 5 files changed, 176 insertions(+), 81 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 88a880e0..bed2e6ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,9 @@ unsupported-operator = "ignore" # mypy supports | syntax with from __future__ im invalid-argument-type = "ignore" # mypy is more permissive with argument types invalid-return-type = "ignore" # mypy is more permissive with return types invalid-parameter-default = "ignore" # mypy is more permissive with parameter defaults +possibly-missing-attribute = "ignore" # mypy is more permissive with attribute access +possibly-unbound-attribute = "ignore" +possibly-missing-import = "ignore" # mypy is more permissive with imports no-matching-overload = "ignore" # mypy is more permissive with overloads unresolved-reference = "ignore" # mypy is more permissive with references possibly-unbound-import = "ignore" @@ -187,7 +190,6 @@ dev = [ ] cpu = [] - [build-system] requires = ["hatchling"] build-backend = "hatchling.build" diff --git a/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py b/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py index 127ff159..87e6ff2b 100644 --- a/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py +++ b/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py @@ -22,7 +22,6 @@ from torchvision.transforms.functional import convert_image_dtype from vbench.utils import clip_transform, init_submodules -from pruna.engine.utils import get_device_type, set_to_best_available_device from pruna.evaluation.metrics.metric_stateful import StatefulMetric from pruna.evaluation.metrics.registry import MetricRegistry from pruna.evaluation.metrics.result import MetricResult @@ -42,8 +41,6 @@ class VBenchBackgroundConsistency(StatefulMetric, VBenchMixin): ---------- *args : Any The arguments to pass to the metric. - device : str | None - The device to run the metric on. call_type : str The call type to use for the metric. **kwargs : Any @@ -53,9 +50,7 @@ class VBenchBackgroundConsistency(StatefulMetric, VBenchMixin): metric_name: str = METRIC_VBENCH_BACKGROUND_CONSISTENCY default_call_type: str = "y" # We just need the outputs higher_is_better: bool = True - # https://github.com/Vchitect/VBench/blob/dc62783c0fb4fd333249c0b669027fe102696682/evaluate.py#L111 - # explicitly sets the device to cuda. We respect this here. - runs_on: List[str] = ["cuda, cpu"] + runs_on: List[str] = ["cuda", "cpu"] modality: List[str] = ["video"] # state similarity_scores: torch.Tensor @@ -64,15 +59,10 @@ class VBenchBackgroundConsistency(StatefulMetric, VBenchMixin): def __init__( self, *args: Any, - device: str | None = None, call_type: str = SINGLE, **kwargs: Any, ) -> None: - super().__init__(*args, **kwargs) - - if device is not None and get_device_type(device) not in self.runs_on: - pruna_logger.error(f"Unsupported device {get_device_type(device)}; supported: {self.runs_on}") - raise ValueError() + super().__init__(kwargs.pop("device", None)) if call_type == PAIRWISE: # VBench itself does not support pairwise. @@ -83,7 +73,6 @@ def __init__( submodules_dict = init_submodules([METRIC_VBENCH_BACKGROUND_CONSISTENCY]) model_path = submodules_dict[METRIC_VBENCH_BACKGROUND_CONSISTENCY][0] - self.device = set_to_best_available_device(device) self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) self.clip_model, self.preprocessor = clip.load(model_path, device=self.device) diff --git a/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py b/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py index 46d06a72..83a505f4 100644 --- a/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py +++ b/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py @@ -23,7 +23,6 @@ from vbench.third_party.RAFT.core.utils_core.utils import InputPadder from vbench.utils import init_submodules -from pruna.engine.utils import get_device_type, set_to_best_available_device from pruna.evaluation.metrics.metric_stateful import StatefulMetric from pruna.evaluation.metrics.registry import MetricRegistry from pruna.evaluation.metrics.result import MetricResult @@ -35,7 +34,16 @@ class PrunaDynamicDegree(DynamicDegree): - """Helper class to compute Dynamic Degree score for a given video.""" + """ + Helper class to compute Dynamic Degree score for a given video. + + Parameters + ---------- + args : EasyDict + The arguments to pass to the RAFT model. + device : str | torch.device + The device to use for the model. + """ def infer(self, frames: torch.Tensor, interval: int) -> bool: """ @@ -45,9 +53,9 @@ def infer(self, frames: torch.Tensor, interval: int) -> bool: Parameters ---------- - frames: torch.Tensor + frames : torch.Tensor The video frames to compute the Dynamic Degree score for. - interval: int + interval : int The interval to skip frames. It's possible for each consecutive frame to not have extreme motion, even though the video itself contains large dynamic changes. Therefore it's important to set the inteval to skip frames correctly. @@ -88,9 +96,6 @@ class VBenchDynamicDegree(StatefulMetric, VBenchMixin): ---------- *args : Any The arguments to be passed to the DynamicDegree class. - device : str | None, optional - The device to be used, e.g., 'cuda' or 'cpu'. Default is None. - If None, the best available device will be used. call_type : str, default="y" The call type to be used, e.g., 'y' or 'y_gt'. Default is "y". interval : int, default=3 @@ -117,16 +122,11 @@ class VBenchDynamicDegree(StatefulMetric, VBenchMixin): def __init__( self, *args: Any, - device: str | None = None, call_type: str = SINGLE, interval: int = 3, **kwargs: Any, ) -> None: - super().__init__(*args, **kwargs) - - if device is not None and get_device_type(device) not in self.runs_on: - pruna_logger.error(f"Unsupported device {get_device_type(device)}; supported: {self.runs_on}") - raise ValueError() + super().__init__(device=kwargs.pop("device", None)) if call_type == PAIRWISE: # VBench itself does not support pairwise. @@ -138,7 +138,6 @@ def __init__( submodules_dict = init_submodules([METRIC_VBENCH_DYNAMIC_DEGREE]) model_path = submodules_dict[METRIC_VBENCH_DYNAMIC_DEGREE]["model"] - self.device = set_to_best_available_device(device) self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) # RAFT models expect arguments to be passed as an object with attributes. # So we need to convert the arguments to an EasyDict. @@ -185,6 +184,9 @@ def compute(self) -> MetricResult: MetricResult The dynamic degree score. """ + if len(self.scores) == 0: + pruna_logger.warning("No scores have been computed. Returning 0.0.") + return MetricResult(name=self.metric_name, params=self.__dict__, result=0.0) final_score = np.mean(self.scores) return MetricResult(name=self.metric_name, params=self.__dict__, result=final_score) diff --git a/src/pruna/evaluation/metrics/vbench_utils.py b/src/pruna/evaluation/metrics/vbench_utils.py index 26a809f8..180f4c4d 100644 --- a/src/pruna/evaluation/metrics/vbench_utils.py +++ b/src/pruna/evaluation/metrics/vbench_utils.py @@ -37,13 +37,6 @@ class VBenchMixin: Mixin class for VBench metrics. Handles benchmark specific initilizations and artifact saving conventions. - - Parameters - ---------- - *args: Any - The arguments to pass to the metric. - **kwargs: Any - The keyword arguments to pass to the metric. """ def create_filename(self, prompt: str, idx: int, file_extension: str, special_str: str = "") -> str: @@ -52,13 +45,13 @@ def create_filename(self, prompt: str, idx: int, file_extension: str, special_st Parameters ---------- - prompt: str + prompt : str The prompt to create the filename from. - idx: int + idx : int The index of the video. Vbench uses 5 seeds for each prompt. - file_extension: str + file_extension : str The file extension to use. Vbench supports mp4 and gif. - special_str: str + special_str : str A special string to add to the filename if you wish to add a specific identifier. Returns @@ -72,12 +65,12 @@ def validate_batch(self, batch: torch.Tensor) -> torch.Tensor: """ Make sure that the video tensor has correct dimensions. - Parameters: + Parameters ---------- - batch: torch.Tensor + batch : torch.Tensor The video tensor. - Returns: + Returns ------- torch.Tensor The video tensor. @@ -93,14 +86,14 @@ def load_video(path: str | Path, return_type: str = "pt") -> List[Image] | np.nd """ Load videos from a path. - Parameters: + Parameters ---------- - path: str | Path + path : str | Path The path to the videos. - return_type: str + return_type : str The type to return the videos as. Can be "pt", "np", "pil". - Returns: + Returns ------- List[torch.Tensor] The videos. @@ -120,12 +113,12 @@ def load_videos_from_path(path: str | Path) -> torch.Tensor: """ Load entire directory of mp4 videos as a single tensor ready to be passed to evaluation. - Parameters: + Parameters ---------- path : str | Path The path to the directory of videos. - Returns: + Returns ------- torch.Tensor The videos. @@ -142,12 +135,12 @@ def sanitize_prompt(prompt: str) -> str: Replaces characters illegal in filenames and collapses whitespace so that generated files are portable across file systems. - Parameters: + Parameters ---------- prompt : str The prompt to sanitize. - Returns: + Returns ------- str The sanitized prompt. @@ -165,13 +158,12 @@ def prepare_batch(batch: str | tuple[str | List[str], Any]) -> str: Pruna datamodules are expected to yield tuples where the first element is a sequence of inputs; this utility enforces batch_size == 1 for simplicity. - - Parameters: + Parameters ---------- - batch: str | tuple[str | List[str], Any] + batch : str | tuple[str | List[str], Any] The batch to prepare. - Returns: + Returns ------- str The prompt string. @@ -193,12 +185,12 @@ def _normalize_save_format(save_format: str) -> tuple[str, Callable]: """ Normalize the save format to be used in the generate_videos function. - Parameters: + Parameters ---------- save_format : str The format to save the videos in. VBench supports mp4 and gif. - Returns: + Returns ------- tuple[str, Callable] The normalized save format and the save function. @@ -217,14 +209,14 @@ def _normalize_prompts( """ Normalize prompts to an iterable format to be used in the generate_videos function. - Parameters: + Parameters ---------- prompts : str | List[str] | PrunaDataModule The prompts to normalize. split : str The dataset split to sample from. - Returns: + Returns ------- Iterable[str] The normalized prompts. @@ -241,7 +233,7 @@ def _ensure_dir(p: Path) -> None: """ Ensure the directory exists. - Parameters: + Parameters ---------- p : Path The path to ensure the directory exists. @@ -255,20 +247,20 @@ def create_vbench_file_name( """ Create a file name for the video in accordance with the VBench format. - Parameters: + Parameters ---------- - prompt: str + prompt : str The prompt to create the file name from. - idx: int + idx : int The index of the video. Vbench uses 5 seeds for each prompt. - special_str: str + special_str : str A special string to add to the file name if you wish to add a specific identifier. - save_format: str + save_format : str The format of the video file. Vbench supports mp4 and gif. - max_filename_length: int + max_filename_length : int The maximum length allowed for the file name. - Returns: + Returns ------- str The file name for the video. @@ -286,18 +278,18 @@ def sample_video_from_pipelines(pipeline: Any, seeder: Any, prompt: str, **kwarg """ Sample a video from diffusers pipeline. - Parameters: + Parameters ---------- - pipeline: Any + pipeline : Any The pipeline to sample from. - prompt: str - The prompt to sample from. - seeder: Any + seeder : Any The seeding generator. - **kwargs: Any + prompt : str + The prompt to sample from. + **kwargs : Any Additional keyword arguments to pass to the pipeline. - Returns: + Returns ------- torch.Tensor The video tensor. @@ -323,13 +315,17 @@ def _wrap_sampler(model: Any, sampling_fn: Callable[..., Any]) -> Callable[..., This wrapper always passes `model` as the first positional argument, so custom functions can name their first parameter `model` or `pipeline`, etc. - Parameters: + Parameters ---------- - model: Any + model : Any The model to sample from. - sampling_fn: Callable[..., Any] + sampling_fn : Callable[..., Any] The sampling function to wrap. + Returns + ------- + Callable[..., Any] + The wrapped sampling function. """ if sampling_fn != sample_video_from_pipelines: pruna_logger.info( @@ -366,7 +362,7 @@ def generate_videos( - Uses an RNG seeded with `global_seed` for reproducibility across runs. - Saves videos as MP4 or GIF. - Parameters: + Parameters ---------- model : Any The model to sample from. @@ -387,7 +383,7 @@ def generate_videos( The directory to save the videos to. save_format : str The format to save the videos in. VBench supports mp4 and gif. - filename_fn: Callable + filename_fn : Callable The function to create the file name. special_str : str A special string to add to the file name if you wish to add a specific identifier. @@ -428,7 +424,7 @@ def evaluate_videos( """ Evaluation loop helper. - Parameters: + Parameters ---------- data : Any The data to evaluate. @@ -437,7 +433,7 @@ def evaluate_videos( prompts : Any | None The prompts to evaluate. - Returns: + Returns ------- List[MetricResult] The results of the evaluation. diff --git a/tests/evaluation/test_vbench_metrics.py b/tests/evaluation/test_vbench_metrics.py index 5890dd9f..c1b98d0a 100644 --- a/tests/evaluation/test_vbench_metrics.py +++ b/tests/evaluation/test_vbench_metrics.py @@ -1,12 +1,14 @@ from __future__ import annotations +from __future__ import annotations import pytest import torch import numpy as np - from pruna.evaluation.metrics.metric_vbench_background_consistency import VBenchBackgroundConsistency from pruna.evaluation.metrics.metric_vbench_dynamic_degree import VBenchDynamicDegree +from pruna.evaluation.metrics.utils import PAIRWISE +from pruna.evaluation.metrics.registry import MetricRegistry @pytest.mark.cuda @@ -56,3 +58,107 @@ def test_metric_dynamic_degree_static(): metric.update(video, video, video) result = metric.compute() assert result.result == 0.0 + +@pytest.mark.cuda +@pytest.mark.parametrize("vbench_metric", [VBenchDynamicDegree, VBenchBackgroundConsistency]) +def test_metric_pairwise_call_type(vbench_metric): + """Test that VBenchBackgroundConsistency raises ValueError for PAIRWISE call_type.""" + with pytest.raises(ValueError): + vbench_metric(call_type=PAIRWISE) + + +@pytest.mark.cuda +@pytest.mark.parametrize("vbench_metric", ['background_consistency', 'dynamic_degree']) +def test_vbench_metrics_invalid_tensor_dimensions(vbench_metric): + """Test that validate_batch raises ValueError for invalid tensor dimensions.""" + metric = MetricRegistry.get_metric(vbench_metric) + # Test 3D tensor (should fail) + invalid_tensor = torch.randn(2, 3, 16) + with pytest.raises(ValueError, match="4 or 5 dimensional"): + metric.validate_batch(invalid_tensor) + + +@pytest.mark.cuda +@pytest.mark.parametrize("vbench_metric", ["background_consistency", "dynamic_degree"]) +def test_vbench_metrics_compute_without_updates(vbench_metric): + """Test compute() returns 0.0 when no updates have been made.""" + metric = MetricRegistry.get_metric(vbench_metric) + result = metric.compute() + assert result.result == 0.0 + assert result.name == vbench_metric + + +@pytest.mark.cuda +@pytest.mark.parametrize("vbench_metric", ["background_consistency", "dynamic_degree"]) +def test_vbench_metrics_4d_tensor(vbench_metric): + """Test that 4D tensors (B, C, H, W) are properly converted to 5D.""" + # 4D tensor should be converted to 5D by validate_batch + four_d_video = torch.randn(2, 3, 16, 16) # Missing time dimension + metric = MetricRegistry.get_metric(vbench_metric) + metric.update(four_d_video, four_d_video, four_d_video) + result = metric.compute() + assert result.result >= 0.0 + assert result.name == vbench_metric + + +@pytest.mark.cuda +@pytest.mark.parametrize("vbench_metric", ["background_consistency", "dynamic_degree"]) +def test_vbench_metrics_different_batch_sizes(vbench_metric): + """Test background consistency with different batch sizes.""" + metric = MetricRegistry.get_metric(vbench_metric) + for batch_size in [1, 2, 3]: + video_batch = torch.randn(batch_size, 5, 3, 64, 64) + metric.update(video_batch, video_batch, video_batch) + result = metric.compute() + assert 0.0 <= result.result <= 1.0 + metric.reset() + + + +@pytest.mark.cuda +@pytest.mark.parametrize("vbench_metric", ["background_consistency", "dynamic_degree"]) +def test_vbench_metrics_different_resolutions(vbench_metric): + """Test background consistency with different video resolutions.""" + resolutions = [(64, 64), (128, 128), (224, 224)] + for h, w in resolutions: + video = torch.randn(1, 5, 3, h, w) + metric = MetricRegistry.get_metric(vbench_metric) + metric.update(video, video, video) + result = metric.compute() + assert 0.0 <= result.result <= 1.0 + metric.reset() + + + +@pytest.mark.cuda +@pytest.mark.parametrize("vbench_metric", ["background_consistency", "dynamic_degree"]) +def test_vbench_metrics_reset_clears_state(vbench_metric): + """Test that reset() properly clears the metric state.""" + metric = MetricRegistry.get_metric(vbench_metric) + video = torch.randn(1, 5, 3, 64, 64) + + # First update and compute + metric.update(video, video, video) + result1 = metric.compute() + + # Reset and verify state is cleared + metric.reset() + result2 = metric.compute() + assert result2.result == 0.0 # Should be 0.0 after reset with no updates + + # Update again and verify it works + metric.update(video, video, video) + result3 = metric.compute() + assert result3.result >= 0.0 + + + +@pytest.mark.cuda +@pytest.mark.parametrize("interval", [1, 3, 5, 10]) +def test_metric_dynamic_degree_different_intervals(interval): + """Test dynamic degree with different interval values.""" + video = torch.zeros(1, 12, 3, 64, 64) # 12 frames to allow various intervals + metric = VBenchDynamicDegree(interval=interval) + metric.update(video, video, video) + result = metric.compute() + assert 0.0 <= result.result <= 1.0 # Score should be in valid range From 2f37e88b264231dc2d5556e9426f0f8bd93412fc Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Mon, 6 Oct 2025 16:01:32 +0000 Subject: [PATCH 10/42] feat: artifact saving and vbench related agent updates --- .../artifactsavers/artifactsaver.py | 98 +++++++++++++++++++ src/pruna/evaluation/artifactsavers/utils.py | 46 +++++++++ .../artifactsavers/video_artifactsaver.py | 73 ++++++++++++++ src/pruna/evaluation/evaluation_agent.py | 72 ++++++++++---- src/pruna/evaluation/metrics/metric_cmmd.py | 1 + .../metrics/metric_pairwise_clip.py | 1 + .../evaluation/metrics/metric_sharpness.py | 1 + .../evaluation/metrics/metric_stateful.py | 20 ++++ src/pruna/evaluation/metrics/metric_torch.py | 30 +++--- src/pruna/evaluation/task.py | 25 +++++ 10 files changed, 332 insertions(+), 35 deletions(-) create mode 100644 src/pruna/evaluation/artifactsavers/artifactsaver.py create mode 100644 src/pruna/evaluation/artifactsavers/utils.py create mode 100644 src/pruna/evaluation/artifactsavers/video_artifactsaver.py diff --git a/src/pruna/evaluation/artifactsavers/artifactsaver.py b/src/pruna/evaluation/artifactsavers/artifactsaver.py new file mode 100644 index 00000000..51bbb3d5 --- /dev/null +++ b/src/pruna/evaluation/artifactsavers/artifactsaver.py @@ -0,0 +1,98 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any + + +class ArtifactSaver(ABC): + """ + Abstract class for artifact savers. + + The artifact saver is responsible for saving the inference outputs during evaluation. + + There needs to be a subclass for each metric modality (e.g. video, image, text, etc.). + + Parameters + ---------- + export_format: str | None + The format to export the artifacts in. + root: Path | str | None + The root directory to save the artifacts in. + """ + + export_format: str | None = None + root: Path | str | None = None + + @abstractmethod + def save_artifact(self, data: Any) -> Path: + """ + Implement this method to save the artifact. + + Parameters + ---------- + data: Any + The data to save. + + Returns + ------- + Path + The full path to the saved artifact. + """ + pass + + def create_alias(self, source_path: Path | str, filename: str) -> Path: + """ + Create an alias for the artifact. + + The evaluation agent will save the inference outputs with a canonical file + formatting style that makes sense for the general case. + + If your metric requires a different file naming convention for evaluation, + you can use this method to create an alias for the artifact. + + This way we prevent duplicate artifacts from being saved and save storage space. + + By default, the alias will be created as a hardlink to the source artifact. + If the hardlink fails, a symlink will be created. + + Parameters + ---------- + source_path : Path | str + The path to the source artifact. + filename : str + The filename to create the alias for. + + Returns + ------- + Path + The full path to the alias. + """ + alias = Path(str(self.root)) / f"{filename}.{self.export_format}" + alias.parent.mkdir(parents=True, exist_ok=True) + try: + if alias.exists(): + alias.unlink() + alias.hardlink_to(source_path) + except Exception: + try: + if alias.exists(): + alias.unlink() + alias.symlink_to(source_path) + except Exception as e: + raise e + return alias diff --git a/src/pruna/evaluation/artifactsavers/utils.py b/src/pruna/evaluation/artifactsavers/utils.py new file mode 100644 index 00000000..bf9787f9 --- /dev/null +++ b/src/pruna/evaluation/artifactsavers/utils.py @@ -0,0 +1,46 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from pathlib import Path + +from pruna.evaluation.artifactsavers.artifactsaver import ArtifactSaver +from pruna.evaluation.artifactsavers.video_artifactsaver import VideoArtifactSaver + + +def assign_artifact_saver( + modality: str, root: Path | str | None = None, export_format: str | None = None +) -> ArtifactSaver: + """ + Assign the appropriate artifact saver based on the modality. + + Parameters + ---------- + modality: str + The modality of the data. + root: str + The root directory to save the artifacts. + export_format: str + The format to save the artifacts. + + Returns + ------- + ArtifactSaver + The appropriate artifact saver. + """ + if modality != "video": + raise ValueError(f"Modality {modality} is not supported") + else: + return VideoArtifactSaver(root=root, export_format=export_format) diff --git a/src/pruna/evaluation/artifactsavers/video_artifactsaver.py b/src/pruna/evaluation/artifactsavers/video_artifactsaver.py new file mode 100644 index 00000000..de0efcef --- /dev/null +++ b/src/pruna/evaluation/artifactsavers/video_artifactsaver.py @@ -0,0 +1,73 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import secrets +import time +from pathlib import Path +from typing import Any + +import numpy as np +import torch +from diffusers.utils import export_to_gif, export_to_video + +from pruna.evaluation.artifactsavers.artifactsaver import ArtifactSaver + + +class VideoArtifactSaver(ArtifactSaver): + """ + Save video artifacts. + + Parameters + ---------- + root: Path | str | None = None + The root directory to save the artifacts. + export_format: str | None = "mp4" + The format to save the artifacts. + """ + + export_format: str | None + root: Path | str | None + + def __init__(self, root: Path | str | None = None, export_format: str | None = "mp4") -> None: + self.root = Path(root) if root else Path.cwd() + (self.root / "canonical").mkdir(parents=True, exist_ok=True) + self.export_format = export_format if export_format else "mp4" + + def save_artifact(self, data: Any, saving_kwargs: dict = dict()) -> Path: + """ + Save the video artifact. + + Parameters + ---------- + data: Any + The data to save. + + Returns + ------- + Path + The path to the saved artifact. + """ + canonical_filename = f"{int(time.time())}_{secrets.token_hex(4)}.{self.export_format}" + canonical_path = Path(str(self.root)) / "canonical" / canonical_filename + if isinstance(data, torch.Tensor): + data = np.transpose(data.cpu().numpy(), (0, 2, 3, 1)) + if self.export_format == "mp4": + export_to_video(data, canonical_path, **saving_kwargs) + elif self.export_format == "gif": + export_to_gif(data, canonical_path, **saving_kwargs) + else: + raise ValueError(f"Invalid format: {self.export_format}") + return canonical_path diff --git a/src/pruna/evaluation/evaluation_agent.py b/src/pruna/evaluation/evaluation_agent.py index 20cc2fd5..18ccc694 100644 --- a/src/pruna/evaluation/evaluation_agent.py +++ b/src/pruna/evaluation/evaluation_agent.py @@ -14,7 +14,9 @@ from __future__ import annotations -from typing import Any, List +import tempfile +from pathlib import Path +from typing import Any, List, Literal import torch from torch import Tensor @@ -25,7 +27,8 @@ from pruna.data.pruna_datamodule import PrunaDataModule from pruna.data.utils import move_batch_to_device from pruna.engine.pruna_model import PrunaModel -from pruna.engine.utils import get_device, move_to_device, safe_memory_cleanup, set_to_best_available_device +from pruna.engine.utils import get_device, move_to_device, get_device, safe_memory_cleanup, set_to_best_available_device +from pruna.evaluation.artifactsavers.utils import assign_artifact_saver from pruna.evaluation.metrics.metric_base import BaseMetric from pruna.evaluation.metrics.metric_stateful import StatefulMetric from pruna.evaluation.metrics.result import MetricResult @@ -58,6 +61,12 @@ def __init__( request: str | List[str | BaseMetric | StatefulMetric] | None = None, datamodule: PrunaDataModule | None = None, device: str | torch.device | None = None, + save_artifacts: bool = False, + root_dir: str | Path | None = None, + num_samples_per_input: int = 1, + seed_strategy: Literal["per_evaluation", "per_sample", "no_seed"] = "no_seed", + global_seed: int | None = None, + saving_kwargs: dict = dict(), ) -> None: if task is not None: if request is not None or datamodule is not None or device is not None: @@ -70,12 +79,19 @@ def __init__( if request is None or datamodule is None: raise ValueError("When not using 'task' parameter, both 'request' and 'datamodule' must be provided.") self.task = Task(request=request, datamodule=datamodule, device=device) - self.first_model_results: List[MetricResult] = [] self.subsequent_model_results: List[MetricResult] = [] + self.seed_strategy = seed_strategy + self.num_samples_per_input = num_samples_per_input + self.global_seed = global_seed self.device = set_to_best_available_device(self.task.device) self.cache: List[Tensor] = [] self.evaluation_for_first_model: bool = True + self.save_artifacts: bool = save_artifacts + if save_artifacts: + self.root_dir = root_dir if root_dir is not None else tempfile.mkdtemp(prefix="inference_outputs") + self.artifact_saver = assign_artifact_saver(self.task.modality, self.root_dir) + self.saving_kwargs = saving_kwargs def evaluate(self, model: Any) -> List[MetricResult]: """ @@ -153,7 +169,7 @@ def prepare_model(self, model: Any) -> PrunaModel: ) else: - smash_config = SmashConfig(device="cpu") + smash_config = SmashConfig(device=get_device(model)) model = PrunaModel(model, smash_config=smash_config) pruna_logger.info("Evaluating a base model.") is_base = True @@ -183,6 +199,9 @@ def prepare_model(self, model: Any) -> PrunaModel: self.device = self.task.device # Keeping the device map to move model back to the original device, when the agent is finished. self.device_map = get_device_map(model) + model.set_to_eval() + model.inference_handler.configure_seed(self.seed_strategy, self.global_seed) + return model def update_stateful_metrics( @@ -212,22 +231,35 @@ def update_stateful_metrics( move_to_device(model, self.device, device_map=self.device_map) for batch_idx, batch in enumerate(tqdm(self.task.dataloader, desc="Processing batches", unit="batch")): - processed_outputs = model.run_inference(batch) - - batch = move_batch_to_device(batch, self.device) - processed_outputs = move_batch_to_device(processed_outputs, self.device) - (x, gt) = batch - # Non-pairwise (aka single) metrics have regular update. - for stateful_metric in single_stateful_metrics: - stateful_metric.update(x, gt, processed_outputs) - - # Cache outputs once in the agent for pairwise metrics to save compute time and memory. - if self.task.is_pairwise_evaluation(): - if self.evaluation_for_first_model: - self.cache.append(processed_outputs) - else: - for pairwise_metric in pairwise_metrics: - pairwise_metric.update(x, self.cache[batch_idx], processed_outputs) + for sample_idx in range(self.num_samples_per_input): + processed_outputs = model.run_inference(batch) + if self.save_artifacts: + canonical_paths = [] + for processed_output in processed_outputs: + canonical_path = self.artifact_saver.save_artifact(processed_output) + canonical_paths.append(canonical_path) + + batch = move_batch_to_device(batch, self.device) + processed_outputs = move_batch_to_device(processed_outputs, self.device) + (x, gt) = batch + # Non-pairwise (aka single) metrics have regular update. + for stateful_metric in single_stateful_metrics: + stateful_metric.update(x, gt, processed_outputs) + if stateful_metric.create_alias: + for prompt_idx, prompt in enumerate(x): + assert isinstance(self.artifact_saver.export_format, str) + alias_filename = stateful_metric.create_filename( + filename=prompt, idx=sample_idx, file_extension=self.artifact_saver.export_format + ) + self.artifact_saver.create_alias(canonical_paths[prompt_idx], alias_filename) + + # Cache outputs once in the agent for pairwise metrics to save compute time and memory. + if self.task.is_pairwise_evaluation(): + if self.evaluation_for_first_model: + self.cache.append(processed_outputs) + else: + for pairwise_metric in pairwise_metrics: + pairwise_metric.update(x, self.cache[batch_idx], processed_outputs) def compute_stateful_metrics( self, single_stateful_metrics: List[StatefulMetric], pairwise_metrics: List[StatefulMetric] diff --git a/src/pruna/evaluation/metrics/metric_cmmd.py b/src/pruna/evaluation/metrics/metric_cmmd.py index 839cbe11..50521125 100644 --- a/src/pruna/evaluation/metrics/metric_cmmd.py +++ b/src/pruna/evaluation/metrics/metric_cmmd.py @@ -58,6 +58,7 @@ class CMMD(StatefulMetric): default_call_type: str = "gt_y" higher_is_better: bool = False metric_name: str = METRIC_CMMD + modality = ["image"] def __init__( self, diff --git a/src/pruna/evaluation/metrics/metric_pairwise_clip.py b/src/pruna/evaluation/metrics/metric_pairwise_clip.py index 62e436e5..87296377 100644 --- a/src/pruna/evaluation/metrics/metric_pairwise_clip.py +++ b/src/pruna/evaluation/metrics/metric_pairwise_clip.py @@ -47,6 +47,7 @@ class PairwiseClipScore(CLIPScore, StatefulMetric): # type: ignore[misc] higher_is_better: bool = True metric_name: str = "pairwise_clip_score" + modality = ["image"] def __init__(self, **kwargs: Any) -> None: device = kwargs.pop("device", None) diff --git a/src/pruna/evaluation/metrics/metric_sharpness.py b/src/pruna/evaluation/metrics/metric_sharpness.py index b09067de..759f0fca 100644 --- a/src/pruna/evaluation/metrics/metric_sharpness.py +++ b/src/pruna/evaluation/metrics/metric_sharpness.py @@ -64,6 +64,7 @@ class SharpnessMetric(StatefulMetric): higher_is_better: bool = True metric_name: str = METRIC_SHARPNESS runs_on: List[str] = ["cpu", "cuda"] + modality = ["image"] def __init__(self, *args, kernel_size: int = 3, call_type: str = SINGLE, **kwargs) -> None: super().__init__(device=kwargs.pop("device", None)) diff --git a/src/pruna/evaluation/metrics/metric_stateful.py b/src/pruna/evaluation/metrics/metric_stateful.py index 39fddcf6..19a201af 100644 --- a/src/pruna/evaluation/metrics/metric_stateful.py +++ b/src/pruna/evaluation/metrics/metric_stateful.py @@ -45,6 +45,8 @@ class StatefulMetric(ABC): metric_name: str call_type: str runs_on: list[str] = ["cuda", "cpu", "mps"] + create_alias: bool = False + modality: List[str] def __init__(self, device: str | torch.device | None = None, **kwargs) -> None: """Initialize the StatefulMetric class.""" @@ -167,3 +169,21 @@ def is_device_supported(self, device: str | torch.device) -> bool: """ dvc, _ = split_device(device_to_string(device)) return dvc in self.runs_on + + def create_filename(self, filename: str, idx: int, file_extension: str) -> str: + """ + Create a filename for the metric. + + Parameters + ---------- + filename: str + The name of the file. + file_extension: str + The extension of the file. + + Returns + ------- + str + The filename. + """ + return f"{filename}-{idx}.{file_extension}" diff --git a/src/pruna/evaluation/metrics/metric_torch.py b/src/pruna/evaluation/metrics/metric_torch.py index e1dfb0e0..951f1b8d 100644 --- a/src/pruna/evaluation/metrics/metric_torch.py +++ b/src/pruna/evaluation/metrics/metric_torch.py @@ -124,9 +124,7 @@ def arniqa_update(metric: ARNIQA, preds: Any) -> None: def ssim_update( - metric: StructuralSimilarityIndexMeasure | MultiScaleStructuralSimilarityIndexMeasure, - preds: Any, - target: Any + metric: StructuralSimilarityIndexMeasure | MultiScaleStructuralSimilarityIndexMeasure, preds: Any, target: Any ) -> None: """ Update handler for SSIM or MS-SSIM metric. @@ -170,23 +168,24 @@ class TorchMetrics(Enum): The starting value for the enum. """ - fid = (partial(FrechetInceptionDistance), fid_update, "gt_y") - accuracy = (partial(Accuracy), None, "y_gt") - perplexity = (partial(Perplexity), None, "y_gt") - clip_score = (partial(CLIPScore), None, "y_x") - precision = (partial(Precision), None, "y_gt") - recall = (partial(Recall), None, "y_gt") - psnr = (partial(PeakSignalNoiseRatio), None, "pairwise_y_gt") - ssim = (partial(StructuralSimilarityIndexMeasure), ssim_update, "pairwise_y_gt") - msssim = (partial(MultiScaleStructuralSimilarityIndexMeasure), ssim_update, "pairwise_y_gt") - lpips = (partial(LearnedPerceptualImagePatchSimilarity), lpips_update, "pairwise_y_gt") - arniqa = (partial(ARNIQA), arniqa_update, "y") - clipiqa = (partial(CLIPImageQualityAssessment), None, "y") + fid = (partial(FrechetInceptionDistance), fid_update, "gt_y", ["image"]) + accuracy = (partial(Accuracy), None, "y_gt", ["general"]) + perplexity = (partial(Perplexity), None, "y_gt", ["text"]) + clip_score = (partial(CLIPScore), None, "y_x", ["image"]) + precision = (partial(Precision), None, "y_gt", ["general"]) + recall = (partial(Recall), None, "y_gt", ["general"]) + psnr = (partial(PeakSignalNoiseRatio), None, "pairwise_y_gt", ["image"]) + ssim = (partial(StructuralSimilarityIndexMeasure), ssim_update, "pairwise_y_gt", ["image"]) + msssim = (partial(MultiScaleStructuralSimilarityIndexMeasure), ssim_update, "pairwise_y_gt", ["image"]) + lpips = (partial(LearnedPerceptualImagePatchSimilarity), lpips_update, "pairwise_y_gt", ["image"]) + arniqa = (partial(ARNIQA), arniqa_update, "y", ["image"]) + clipiqa = (partial(CLIPImageQualityAssessment), None, "y", ["image"]) def __init__(self, *args, **kwargs) -> None: self.tm = self.value[0] self.update_fn = self.value[1] or default_update self.call_type = self.value[2] + self.modality = self.value[3] def __call__(self, **kwargs) -> Metric: """ @@ -260,6 +259,7 @@ def __init__(self, metric_name: str, call_type: str = "", **kwargs) -> None: # Get the specific update function for the metric, or use the default if not found. self.update_fn = TorchMetrics[metric_name].update_fn + self.modality = TorchMetrics[metric_name].modality except KeyError: raise ValueError(f"Metric {metric_name} is not supported.") diff --git a/src/pruna/evaluation/task.py b/src/pruna/evaluation/task.py index e8c63688..316e5afa 100644 --- a/src/pruna/evaluation/task.py +++ b/src/pruna/evaluation/task.py @@ -70,6 +70,7 @@ def __init__( self.stateful_metric_device = self._get_stateful_metric_device_from_task_device() self.metrics = _safe_build_metrics(request, self.device, self.stateful_metric_device) + self.modality = self.validate_and_get_task_modality() self.datamodule = datamodule self.dataloader = datamodule.test_dataloader() @@ -154,6 +155,30 @@ def _safe_build_metrics( ) raise e + def validate_and_get_task_modality(self) -> str: + """ + Check if the task has a single modality of metrics. + + Inference handling is different for different modalities. + The task should have one consistent modality across metrics. + Stateless metrics and stateful metrics with general modalities are allowed. + + Returns + ------- + str + The modality of the task. + """ + stateful_metrics = [metric for metric in self.metrics if isinstance(metric, StatefulMetric)] + modalities_no_general = [metric.modality for metric in stateful_metrics if metric.modality != ["general"]] + # We should also allow 0 because the user might have only general modality metrics. + if len(modalities_no_general) == 1: + modality = modalities_no_general[0][0] + elif len(modalities_no_general) == 0: + modality = "general" + else: + raise ValueError("The task should have a single modality of quality metrics.") + return modality + def get_metrics( request: str | List[str | BaseMetric | StatefulMetric], inference_device: str, stateful_metric_device: str From b8f3e8d4bab64b8209452b77187a952e731f9436 Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Tue, 7 Oct 2025 15:15:16 +0000 Subject: [PATCH 11/42] test: add tests for the artifact savers --- .../artifactsavers/video_artifactsaver.py | 7 +++ src/pruna/evaluation/evaluation_agent.py | 14 ++++- tests/evaluation/test_artifactsaver.py | 57 +++++++++++++++++++ 3 files changed, 77 insertions(+), 1 deletion(-) create mode 100644 tests/evaluation/test_artifactsaver.py diff --git a/src/pruna/evaluation/artifactsavers/video_artifactsaver.py b/src/pruna/evaluation/artifactsavers/video_artifactsaver.py index de0efcef..b1c2539c 100644 --- a/src/pruna/evaluation/artifactsavers/video_artifactsaver.py +++ b/src/pruna/evaluation/artifactsavers/video_artifactsaver.py @@ -22,6 +22,7 @@ import numpy as np import torch from diffusers.utils import export_to_gif, export_to_video +from PIL import Image from pruna.evaluation.artifactsavers.artifactsaver import ArtifactSaver @@ -62,8 +63,14 @@ def save_artifact(self, data: Any, saving_kwargs: dict = dict()) -> Path: """ canonical_filename = f"{int(time.time())}_{secrets.token_hex(4)}.{self.export_format}" canonical_path = Path(str(self.root)) / "canonical" / canonical_filename + + # all diffusers saving utility functions accept a list of PIL.Images, so we convert to PIL to be safe. if isinstance(data, torch.Tensor): data = np.transpose(data.cpu().numpy(), (0, 2, 3, 1)) + data = np.clip(data * 255, 0, 255).astype(np.uint8) + if isinstance(data, np.ndarray): + data = [Image.fromarray(frame.astype(np.uint8)) for frame in data] + if self.export_format == "mp4": export_to_video(data, canonical_path, **saving_kwargs) elif self.export_format == "gif": diff --git a/src/pruna/evaluation/evaluation_agent.py b/src/pruna/evaluation/evaluation_agent.py index 18ccc694..ff168f3c 100644 --- a/src/pruna/evaluation/evaluation_agent.py +++ b/src/pruna/evaluation/evaluation_agent.py @@ -52,6 +52,18 @@ class EvaluationAgent: device : str | torch.device | None, optional The device to be used, e.g., 'cuda' or 'cpu'. Default is None. If None, the best available device will be used. + save_artifacts : bool, optional + Whether to save the artifacts. Default is False. + root_dir : str | Path | None, optional + The directory to save the artifacts. Default is None. + num_samples_per_input : int, optional + The number of samples to generate per input. Default is 1. + seed_strategy : Literal["per_sample", "no_seed"], optional + The seed strategy to use. Default is "no_seed". + global_seed : int | None, optional + The global seed to use. Default is None. + saving_kwargs : dict, optional + The kwargs to pass to the artifact saver. Default is an empty dict. """ def __init__( @@ -64,7 +76,7 @@ def __init__( save_artifacts: bool = False, root_dir: str | Path | None = None, num_samples_per_input: int = 1, - seed_strategy: Literal["per_evaluation", "per_sample", "no_seed"] = "no_seed", + seed_strategy: Literal["per_sample", "no_seed"] = "no_seed", global_seed: int | None = None, saving_kwargs: dict = dict(), ) -> None: diff --git a/tests/evaluation/test_artifactsaver.py b/tests/evaluation/test_artifactsaver.py new file mode 100644 index 00000000..9655e8dc --- /dev/null +++ b/tests/evaluation/test_artifactsaver.py @@ -0,0 +1,57 @@ +import pytest +import tempfile +from pathlib import Path + +import numpy as np +import torch + +from pruna.evaluation.artifactsavers.video_artifactsaver import VideoArtifactSaver +from pruna.evaluation.artifactsavers.utils import assign_artifact_saver +from pruna.evaluation.metrics.vbench_utils import load_videos +from PIL import Image +import itertools + + +def test_create_alias(): + with tempfile.TemporaryDirectory() as tmp_path: + saver = VideoArtifactSaver(root=tmp_path, export_format="mp4") + dummy_video = np.random.randint(0, 255, (10, 16, 16, 3), dtype=np.uint8) + source_filename = saver.save_artifact(dummy_video, saving_kwargs={"fps": 5}) + + alias = saver.create_alias(source_filename, "alias_filename") + reloaded_alias_video = load_videos(str(alias), return_type = "np") + + assert(reloaded_alias_video.shape == dummy_video.shape) + assert alias.exists() + assert alias.name.endswith(".mp4") + + +def test_assign_artifact_saver_video(tmp_path: Path): + saver = assign_artifact_saver("video", root=tmp_path, export_format="mp4") + assert isinstance(saver, VideoArtifactSaver) + assert saver.export_format == "mp4" + + +def test_assign_artifact_saver_invalid(): + with pytest.raises(ValueError): + assign_artifact_saver("text") + +@pytest.mark.parametrize( + "export_format, save_from_type", + list(itertools.product(["gif", "mp4"], ["np", "pt", "pil"])) +) +def test_video_artifact_saver_tensor(export_format: str, save_from_type: str): + with tempfile.TemporaryDirectory() as tmp_path: + saver = VideoArtifactSaver(root=tmp_path, export_format=export_format) + # create a fake video: + if save_from_type == "pt": + dummy_video = torch.randint(0, 255, (2, 3, 16, 16), dtype=torch.uint8) + elif save_from_type == "np": + dummy_video = np.random.randint(0, 255, (2, 16, 16, 3), dtype=np.uint8) + elif save_from_type == "pil": + dummy_video = np.random.randint(0, 255, (2, 16, 16, 3), dtype=np.uint8) + dummy_video = [Image.fromarray(frame) for frame in dummy_video] + + path = saver.save_artifact(dummy_video) + assert path.exists() + assert path.suffix == f".{export_format}" From d8c7a2897d541207149db27c05b86bb36fcf8cce Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Wed, 8 Oct 2025 12:06:47 +0000 Subject: [PATCH 12/42] test: add artifact related evaluation tests and task modality tests --- src/pruna/evaluation/evaluation_agent.py | 3 +- src/pruna/evaluation/task.py | 6 ++-- tests/algorithms/testers/flash_attn3.py | 2 +- tests/evaluation/test_evalagent.py | 42 ++++++++++++++++++++++++ tests/evaluation/test_task.py | 17 ++++++++++ 5 files changed, 66 insertions(+), 4 deletions(-) create mode 100644 tests/evaluation/test_evalagent.py diff --git a/src/pruna/evaluation/evaluation_agent.py b/src/pruna/evaluation/evaluation_agent.py index ff168f3c..86361e1f 100644 --- a/src/pruna/evaluation/evaluation_agent.py +++ b/src/pruna/evaluation/evaluation_agent.py @@ -78,6 +78,7 @@ def __init__( num_samples_per_input: int = 1, seed_strategy: Literal["per_sample", "no_seed"] = "no_seed", global_seed: int | None = None, + artifact_saver_export_format: str | None = None, saving_kwargs: dict = dict(), ) -> None: if task is not None: @@ -102,7 +103,7 @@ def __init__( self.save_artifacts: bool = save_artifacts if save_artifacts: self.root_dir = root_dir if root_dir is not None else tempfile.mkdtemp(prefix="inference_outputs") - self.artifact_saver = assign_artifact_saver(self.task.modality, self.root_dir) + self.artifact_saver = assign_artifact_saver(self.task.modality, self.root_dir, artifact_saver_export_format) self.saving_kwargs = saving_kwargs def evaluate(self, model: Any) -> List[MetricResult]: diff --git a/src/pruna/evaluation/task.py b/src/pruna/evaluation/task.py index 316e5afa..809ad3ef 100644 --- a/src/pruna/evaluation/task.py +++ b/src/pruna/evaluation/task.py @@ -169,10 +169,12 @@ def validate_and_get_task_modality(self) -> str: The modality of the task. """ stateful_metrics = [metric for metric in self.metrics if isinstance(metric, StatefulMetric)] - modalities_no_general = [metric.modality for metric in stateful_metrics if metric.modality != ["general"]] + modalities_no_general = { + modality for metric in stateful_metrics for modality in metric.modality if modality != "general" + } # We should also allow 0 because the user might have only general modality metrics. if len(modalities_no_general) == 1: - modality = modalities_no_general[0][0] + modality = modalities_no_general.pop() elif len(modalities_no_general) == 0: modality = "general" else: diff --git a/tests/algorithms/testers/flash_attn3.py b/tests/algorithms/testers/flash_attn3.py index 575f8802..90738424 100644 --- a/tests/algorithms/testers/flash_attn3.py +++ b/tests/algorithms/testers/flash_attn3.py @@ -13,4 +13,4 @@ class TestFlashAttn3(AlgorithmTesterBase): reject_models = ["opt_tiny_random"] allow_pickle_files = False algorithm_class = FlashAttn3 - metrics = ["latency"] + metrics = ["background_consistency"] diff --git a/tests/evaluation/test_evalagent.py b/tests/evaluation/test_evalagent.py new file mode 100644 index 00000000..74a223dd --- /dev/null +++ b/tests/evaluation/test_evalagent.py @@ -0,0 +1,42 @@ +import pytest +import torch +from pathlib import Path +import tempfile +from pruna.evaluation.evaluation_agent import EvaluationAgent +from pruna.engine.pruna_model import PrunaModel + + +@pytest.mark.cuda +@pytest.mark.parametrize("model_fixture", +[pytest.param("wan_tiny_random", marks=pytest.mark.cuda)], indirect=["model_fixture"]) +def test_agent_saves_artifacts(model_fixture): + model, smash_config = model_fixture + # Metrics don't work with bfloat16 + model.to(dtype=torch.float16, device="cuda") + # Artifact path + temp_path = tempfile.mkdtemp() + + # Let's limit the number of data points + dm = smash_config.data + data_points = 2 + dm.limit_datasets(data_points) + + + agent = EvaluationAgent( + request=['background_consistency'], + datamodule=dm, + device="cuda", + save_artifacts=True, + root_dir=temp_path, + artifact_saver_export_format="mp4", + saving_kwargs={"fps":4} + ) + + pruna_model = PrunaModel(model, smash_config) + pruna_model.inference_handler.model_args["num_inference_steps"] = 2 + + agent.evaluate(model=pruna_model) + mp4_files = list(Path(temp_path).rglob("*.mp4")) + + # Check that we saved the correct number of files + assert len(mp4_files) == data_points diff --git a/tests/evaluation/test_task.py b/tests/evaluation/test_task.py index 5420a17b..30dc365e 100644 --- a/tests/evaluation/test_task.py +++ b/tests/evaluation/test_task.py @@ -107,3 +107,20 @@ def test_task_from_string_request(): assert isinstance(task.metrics[0], CMMD) assert isinstance(task.metrics[1], PairwiseClipScore) assert isinstance(task.metrics[2], TorchMetricWrapper) + + +@pytest.mark.parametrize("metrics, modality", [ + (["cmmd", "ssim", "lpips", "latency"], "image"), + (["perplexity", "disk_memory"], "text"), + ([MetricRegistry().get_metric("accuracy", task="binary"), MetricRegistry().get_metric("recall", task="binary")], "general"), + (["background_consistency", "dynamic_degree"], "video") +]) +def test_task_modality(metrics, modality): + datamodule = type("dm", (), {"test_dataloader": lambda self: []})() + task = Task(request=metrics, datamodule=datamodule) + assert task.modality == modality + +def test_task_modality_mixed_raises(): + datamodule = type("dm", (), {"test_dataloader": lambda self: []})() + with pytest.raises(ValueError): + Task(request=["cmmd", "background_consistency"], datamodule=datamodule) From fb15eed40cb751351293446e37933e302a977920 Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Wed, 8 Oct 2025 14:03:07 +0000 Subject: [PATCH 13/42] refactor: add some comments --- .../evaluation/artifactsavers/video_artifactsaver.py | 4 ++-- src/pruna/evaluation/evaluation_agent.py | 8 +++++++- tests/evaluation/test_artifactsaver.py | 7 +++++++ tests/evaluation/test_evalagent.py | 1 + tests/evaluation/test_task.py | 2 ++ 5 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/pruna/evaluation/artifactsavers/video_artifactsaver.py b/src/pruna/evaluation/artifactsavers/video_artifactsaver.py index b1c2539c..c8dbbec2 100644 --- a/src/pruna/evaluation/artifactsavers/video_artifactsaver.py +++ b/src/pruna/evaluation/artifactsavers/video_artifactsaver.py @@ -72,9 +72,9 @@ def save_artifact(self, data: Any, saving_kwargs: dict = dict()) -> Path: data = [Image.fromarray(frame.astype(np.uint8)) for frame in data] if self.export_format == "mp4": - export_to_video(data, canonical_path, **saving_kwargs) + export_to_video(data, str(canonical_path), **saving_kwargs.copy()) elif self.export_format == "gif": - export_to_gif(data, canonical_path, **saving_kwargs) + export_to_gif(data, str(canonical_path), **saving_kwargs.copy()) else: raise ValueError(f"Invalid format: {self.export_format}") return canonical_path diff --git a/src/pruna/evaluation/evaluation_agent.py b/src/pruna/evaluation/evaluation_agent.py index 86361e1f..3b3d1d47 100644 --- a/src/pruna/evaluation/evaluation_agent.py +++ b/src/pruna/evaluation/evaluation_agent.py @@ -104,6 +104,7 @@ def __init__( if save_artifacts: self.root_dir = root_dir if root_dir is not None else tempfile.mkdtemp(prefix="inference_outputs") self.artifact_saver = assign_artifact_saver(self.task.modality, self.root_dir, artifact_saver_export_format) + # for miscellaneous saving kwargs like fps, etc. self.saving_kwargs = saving_kwargs def evaluate(self, model: Any) -> List[MetricResult]: @@ -213,6 +214,7 @@ def prepare_model(self, model: Any) -> PrunaModel: # Keeping the device map to move model back to the original device, when the agent is finished. self.device_map = get_device_map(model) model.set_to_eval() + # Setup seeding for inference. model.inference_handler.configure_seed(self.seed_strategy, self.global_seed) return model @@ -248,6 +250,7 @@ def update_stateful_metrics( processed_outputs = model.run_inference(batch) if self.save_artifacts: canonical_paths = [] + # We have to save the artifacts for each sample in the batch. for processed_output in processed_outputs: canonical_path = self.artifact_saver.save_artifact(processed_output) canonical_paths.append(canonical_path) @@ -258,7 +261,8 @@ def update_stateful_metrics( # Non-pairwise (aka single) metrics have regular update. for stateful_metric in single_stateful_metrics: stateful_metric.update(x, gt, processed_outputs) - if stateful_metric.create_alias: + if self.save_artifacts and stateful_metric.create_alias: + # Again, we have to create an alias for each sample in the batch. for prompt_idx, prompt in enumerate(x): assert isinstance(self.artifact_saver.export_format, str) alias_filename = stateful_metric.create_filename( @@ -268,6 +272,8 @@ def update_stateful_metrics( # Cache outputs once in the agent for pairwise metrics to save compute time and memory. if self.task.is_pairwise_evaluation(): + if self.num_samples_per_input > 1: + raise ValueError("Pairwise evaluation with multiple samples per input is not supported.") if self.evaluation_for_first_model: self.cache.append(processed_outputs) else: diff --git a/tests/evaluation/test_artifactsaver.py b/tests/evaluation/test_artifactsaver.py index 9655e8dc..24775ca5 100644 --- a/tests/evaluation/test_artifactsaver.py +++ b/tests/evaluation/test_artifactsaver.py @@ -13,12 +13,16 @@ def test_create_alias(): + """ Test that we can create an alias for an existing video.""" with tempfile.TemporaryDirectory() as tmp_path: + # First, we create a random video and save it. saver = VideoArtifactSaver(root=tmp_path, export_format="mp4") dummy_video = np.random.randint(0, 255, (10, 16, 16, 3), dtype=np.uint8) source_filename = saver.save_artifact(dummy_video, saving_kwargs={"fps": 5}) + # Then, we create an alias for the video. alias = saver.create_alias(source_filename, "alias_filename") + # Finally, we reload the alias and check that it is the same as the original video. reloaded_alias_video = load_videos(str(alias), return_type = "np") assert(reloaded_alias_video.shape == dummy_video.shape) @@ -27,12 +31,14 @@ def test_create_alias(): def test_assign_artifact_saver_video(tmp_path: Path): + """ Test the artifact save is assigned correctly.""" saver = assign_artifact_saver("video", root=tmp_path, export_format="mp4") assert isinstance(saver, VideoArtifactSaver) assert saver.export_format == "mp4" def test_assign_artifact_saver_invalid(): + """ Test that we raise an error if the artifact saver is assigned incorrectly.""" with pytest.raises(ValueError): assign_artifact_saver("text") @@ -41,6 +47,7 @@ def test_assign_artifact_saver_invalid(): list(itertools.product(["gif", "mp4"], ["np", "pt", "pil"])) ) def test_video_artifact_saver_tensor(export_format: str, save_from_type: str): + """ Test that we can save a video from numpy, torch and PIL in mp4 and gif formats.""" with tempfile.TemporaryDirectory() as tmp_path: saver = VideoArtifactSaver(root=tmp_path, export_format=export_format) # create a fake video: diff --git a/tests/evaluation/test_evalagent.py b/tests/evaluation/test_evalagent.py index 74a223dd..7c64de7e 100644 --- a/tests/evaluation/test_evalagent.py +++ b/tests/evaluation/test_evalagent.py @@ -10,6 +10,7 @@ @pytest.mark.parametrize("model_fixture", [pytest.param("wan_tiny_random", marks=pytest.mark.cuda)], indirect=["model_fixture"]) def test_agent_saves_artifacts(model_fixture): + """ Test that the agent runs inference and saves the inference output artifacts correctly.""" model, smash_config = model_fixture # Metrics don't work with bfloat16 model.to(dtype=torch.float16, device="cuda") diff --git a/tests/evaluation/test_task.py b/tests/evaluation/test_task.py index 30dc365e..218edd8b 100644 --- a/tests/evaluation/test_task.py +++ b/tests/evaluation/test_task.py @@ -116,11 +116,13 @@ def test_task_from_string_request(): (["background_consistency", "dynamic_degree"], "video") ]) def test_task_modality(metrics, modality): + """ Test that the task modality is assigned correctly for image, text, general and video metrics.""" datamodule = type("dm", (), {"test_dataloader": lambda self: []})() task = Task(request=metrics, datamodule=datamodule) assert task.modality == modality def test_task_modality_mixed_raises(): + """ Test that we raise an error if the task modality is mixed.""" datamodule = type("dm", (), {"test_dataloader": lambda self: []})() with pytest.raises(ValueError): Task(request=["cmmd", "background_consistency"], datamodule=datamodule) From c2c1d273171d13a100faef24b8d31c337b640b6d Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Thu, 23 Oct 2025 11:44:03 +0000 Subject: [PATCH 14/42] refactor: better initialization for artifact savers --- pyproject.toml | 1 + src/pruna/evaluation/artifactsavers/utils.py | 6 +++--- src/pruna/evaluation/evaluation_agent.py | 9 +++++++-- tests/evaluation/test_artifactsaver.py | 4 ++-- tests/evaluation/test_evalagent.py | 19 +++++++++---------- 5 files changed, 22 insertions(+), 17 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bed2e6ae..9c8f9d74 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ possibly-unbound-attribute = "ignore" possibly-missing-import = "ignore" # mypy is more permissive with imports no-matching-overload = "ignore" # mypy is more permissive with overloads unresolved-reference = "ignore" # mypy is more permissive with references +missing-argument = "ignore" possibly-unbound-import = "ignore" missing-argument = "ignore" diff --git a/src/pruna/evaluation/artifactsavers/utils.py b/src/pruna/evaluation/artifactsavers/utils.py index bf9787f9..6954dfc9 100644 --- a/src/pruna/evaluation/artifactsavers/utils.py +++ b/src/pruna/evaluation/artifactsavers/utils.py @@ -40,7 +40,7 @@ def assign_artifact_saver( ArtifactSaver The appropriate artifact saver. """ - if modality != "video": - raise ValueError(f"Modality {modality} is not supported") - else: + if modality == "video": return VideoArtifactSaver(root=root, export_format=export_format) + else: + raise ValueError(f"Modality {modality} is not supported") diff --git a/src/pruna/evaluation/evaluation_agent.py b/src/pruna/evaluation/evaluation_agent.py index 3b3d1d47..56cfbf5e 100644 --- a/src/pruna/evaluation/evaluation_agent.py +++ b/src/pruna/evaluation/evaluation_agent.py @@ -36,6 +36,8 @@ from pruna.evaluation.task import Task from pruna.logging.logger import pruna_logger +OUTPUT_DIR = tempfile.mkdtemp(prefix="inference_outputs") + class EvaluationAgent: """ @@ -102,7 +104,7 @@ def __init__( self.evaluation_for_first_model: bool = True self.save_artifacts: bool = save_artifacts if save_artifacts: - self.root_dir = root_dir if root_dir is not None else tempfile.mkdtemp(prefix="inference_outputs") + self.root_dir = root_dir if root_dir is not None else OUTPUT_DIR self.artifact_saver = assign_artifact_saver(self.task.modality, self.root_dir, artifact_saver_export_format) # for miscellaneous saving kwargs like fps, etc. self.saving_kwargs = saving_kwargs @@ -264,7 +266,10 @@ def update_stateful_metrics( if self.save_artifacts and stateful_metric.create_alias: # Again, we have to create an alias for each sample in the batch. for prompt_idx, prompt in enumerate(x): - assert isinstance(self.artifact_saver.export_format, str) + if self.artifact_saver.export_format is None: + raise ValueError( + "Export format is not set.Please set the export format for the artifact saver." + ) alias_filename = stateful_metric.create_filename( filename=prompt, idx=sample_idx, file_extension=self.artifact_saver.export_format ) diff --git a/tests/evaluation/test_artifactsaver.py b/tests/evaluation/test_artifactsaver.py index 24775ca5..2f2846e5 100644 --- a/tests/evaluation/test_artifactsaver.py +++ b/tests/evaluation/test_artifactsaver.py @@ -7,7 +7,7 @@ from pruna.evaluation.artifactsavers.video_artifactsaver import VideoArtifactSaver from pruna.evaluation.artifactsavers.utils import assign_artifact_saver -from pruna.evaluation.metrics.vbench_utils import load_videos +from pruna.evaluation.metrics.vbench_utils import load_video from PIL import Image import itertools @@ -23,7 +23,7 @@ def test_create_alias(): # Then, we create an alias for the video. alias = saver.create_alias(source_filename, "alias_filename") # Finally, we reload the alias and check that it is the same as the original video. - reloaded_alias_video = load_videos(str(alias), return_type = "np") + reloaded_alias_video = load_video(str(alias), return_type = "np") assert(reloaded_alias_video.shape == dummy_video.shape) assert alias.exists() diff --git a/tests/evaluation/test_evalagent.py b/tests/evaluation/test_evalagent.py index 7c64de7e..e3dc33ba 100644 --- a/tests/evaluation/test_evalagent.py +++ b/tests/evaluation/test_evalagent.py @@ -7,13 +7,12 @@ @pytest.mark.cuda -@pytest.mark.parametrize("model_fixture", -[pytest.param("wan_tiny_random", marks=pytest.mark.cuda)], indirect=["model_fixture"]) -def test_agent_saves_artifacts(model_fixture): +@pytest.mark.parametrize("model_fixture, export_format", +[pytest.param("wan_tiny_random", "mp4", marks=pytest.mark.cuda), +pytest.param("wan_tiny_random", "gif", marks=pytest.mark.cuda)], indirect=["model_fixture"]) +def test_agent_saves_artifacts(model_fixture, export_format): """ Test that the agent runs inference and saves the inference output artifacts correctly.""" model, smash_config = model_fixture - # Metrics don't work with bfloat16 - model.to(dtype=torch.float16, device="cuda") # Artifact path temp_path = tempfile.mkdtemp() @@ -29,15 +28,15 @@ def test_agent_saves_artifacts(model_fixture): device="cuda", save_artifacts=True, root_dir=temp_path, - artifact_saver_export_format="mp4", - saving_kwargs={"fps":4} + saving_kwargs={"fps":4}, + artifact_saver_export_format=export_format, ) pruna_model = PrunaModel(model, smash_config) - pruna_model.inference_handler.model_args["num_inference_steps"] = 2 + pruna_model.inference_handler.model_args["num_inference_steps"] = 1 agent.evaluate(model=pruna_model) - mp4_files = list(Path(temp_path).rglob("*.mp4")) + files = list(Path(temp_path).rglob(f"*.{export_format}")) # Check that we saved the correct number of files - assert len(mp4_files) == data_points + assert len(files) == data_points From ceead2274872e1c76a7db966b94b820dd3f689e7 Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Fri, 24 Oct 2025 10:09:15 +0000 Subject: [PATCH 15/42] test: add more dtype tests for artifact saver --- tests/evaluation/test_artifactsaver.py | 46 ++++++++++++++++++++------ 1 file changed, 36 insertions(+), 10 deletions(-) diff --git a/tests/evaluation/test_artifactsaver.py b/tests/evaluation/test_artifactsaver.py index 2f2846e5..ca46dd99 100644 --- a/tests/evaluation/test_artifactsaver.py +++ b/tests/evaluation/test_artifactsaver.py @@ -9,7 +9,6 @@ from pruna.evaluation.artifactsavers.utils import assign_artifact_saver from pruna.evaluation.metrics.vbench_utils import load_video from PIL import Image -import itertools def test_create_alias(): @@ -43,22 +42,49 @@ def test_assign_artifact_saver_invalid(): assign_artifact_saver("text") @pytest.mark.parametrize( - "export_format, save_from_type", - list(itertools.product(["gif", "mp4"], ["np", "pt", "pil"])) + "export_format, save_from_type, save_from_dtype", + [pytest.param("gif", "np", "uint8"), + pytest.param("gif", "np", "float32"), + # Numpy doesn't have half precision, so we do not test for float16 + pytest.param("gif", "pt", "float32"), + pytest.param("gif", "pt", "float16"), + pytest.param("gif", "pt", "uint8"), + # PIL doesnot support creating images from float numpy arrays, so we only test uint8. + pytest.param("gif", "pil", "uint8"), + pytest.param("mp4", "np", "uint8"), + pytest.param("mp4", "np", "float32"), + pytest.param("mp4", "pt", "float32"), + pytest.param("mp4", "pt", "float16"), + pytest.param("mp4", "pt", "uint8"), + # PIL doesnot support creating images from float numpy arrays, so we only test uint8. + pytest.param("mp4", "pil", "uint8"),] ) -def test_video_artifact_saver_tensor(export_format: str, save_from_type: str): - """ Test that we can save a video from numpy, torch and PIL in mp4 and gif formats.""" +def test_video_artifact_saver_tensor(export_format: str, save_from_type: str, save_from_dtype: str): + """ Test that we can save a video from numpy, torch and PIL in mp4 and gif formats. """ with tempfile.TemporaryDirectory() as tmp_path: saver = VideoArtifactSaver(root=tmp_path, export_format=export_format) # create a fake video: if save_from_type == "pt": - dummy_video = torch.randint(0, 255, (2, 3, 16, 16), dtype=torch.uint8) + # Unfortunately, neither torch nor numpy have one random generator function that can support all dtypes. + # Therefore, we need to use different functions for int and float dtypes. + if save_from_dtype == "uint8": + dtype = getattr(torch, save_from_dtype) + dummy_video = torch.randint(0, 256, (2, 3, 16, 16), dtype=dtype) + else: + dtype = getattr(torch, save_from_dtype) + dummy_video = torch.randn(2, 3, 16, 16, dtype=dtype) elif save_from_type == "np": - dummy_video = np.random.randint(0, 255, (2, 16, 16, 3), dtype=np.uint8) + if save_from_dtype == "uint8": + dtype = getattr(np, save_from_dtype) + dummy_video = np.random.randint(0, 256, (2, 16, 16, 3), dtype=dtype) + else: + rng = np.random.default_rng() + dtype = getattr(np, save_from_dtype) + dummy_video = rng.random((2, 16, 16, 3), dtype=dtype) elif save_from_type == "pil": - dummy_video = np.random.randint(0, 255, (2, 16, 16, 3), dtype=np.uint8) - dummy_video = [Image.fromarray(frame) for frame in dummy_video] - + dtype = getattr(np, save_from_dtype) + dummy_video = np.random.randint(0, 256, (2, 16, 16, 3), dtype=dtype) + dummy_video = [Image.fromarray(frame.astype(np.uint8)) for frame in dummy_video] path = saver.save_artifact(dummy_video) assert path.exists() assert path.suffix == f".{export_format}" From 2329abbb769041494159d22586ff2f0ca7370d67 Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Fri, 24 Oct 2025 12:25:41 +0000 Subject: [PATCH 16/42] feat: metric modalities as sets --- .pre-commit-config.yaml | 4 +++ src/pruna/evaluation/evaluation_agent.py | 2 +- src/pruna/evaluation/metrics/metric_cmmd.py | 4 +-- .../metrics/metric_pairwise_clip.py | 4 +-- .../evaluation/metrics/metric_sharpness.py | 4 +-- .../evaluation/metrics/metric_stateful.py | 2 +- src/pruna/evaluation/metrics/metric_torch.py | 27 ++++++++++--------- src/pruna/evaluation/metrics/utils.py | 5 ++++ src/pruna/evaluation/task.py | 21 +++++++-------- 9 files changed, 41 insertions(+), 32 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b899e7cb..8dfbe206 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -39,7 +39,11 @@ repos: grep -v "^D" | cut -f2- | while IFS= read -r file; do +<<<<<<< HEAD if [ -f "$file" ] && [ "$file" != ".pre-commit-config.yaml" ] && grep -q "pruna_pro" "$file"; then +======= + if [ -f "$file" ] && ["$file" != ".pre-commit-config.yaml"] && grep -q "pruna_pro" "$file"; then +>>>>>>> 306050c (feat: metric modalities as sets) echo "Error: pruna_pro found in staged file $file" exit 1 fi diff --git a/src/pruna/evaluation/evaluation_agent.py b/src/pruna/evaluation/evaluation_agent.py index 56cfbf5e..c84889a1 100644 --- a/src/pruna/evaluation/evaluation_agent.py +++ b/src/pruna/evaluation/evaluation_agent.py @@ -27,7 +27,7 @@ from pruna.data.pruna_datamodule import PrunaDataModule from pruna.data.utils import move_batch_to_device from pruna.engine.pruna_model import PrunaModel -from pruna.engine.utils import get_device, move_to_device, get_device, safe_memory_cleanup, set_to_best_available_device +from pruna.engine.utils import get_device, move_to_device, safe_memory_cleanup, set_to_best_available_device from pruna.evaluation.artifactsavers.utils import assign_artifact_saver from pruna.evaluation.metrics.metric_base import BaseMetric from pruna.evaluation.metrics.metric_stateful import StatefulMetric diff --git a/src/pruna/evaluation/metrics/metric_cmmd.py b/src/pruna/evaluation/metrics/metric_cmmd.py index 50521125..9ed5eb9d 100644 --- a/src/pruna/evaluation/metrics/metric_cmmd.py +++ b/src/pruna/evaluation/metrics/metric_cmmd.py @@ -25,7 +25,7 @@ from pruna.evaluation.metrics.metric_stateful import StatefulMetric from pruna.evaluation.metrics.registry import MetricRegistry from pruna.evaluation.metrics.result import MetricResult -from pruna.evaluation.metrics.utils import SINGLE, get_call_type_for_single_metric, metric_data_processor +from pruna.evaluation.metrics.utils import IMAGE, SINGLE, get_call_type_for_single_metric, metric_data_processor from pruna.logging.logger import pruna_logger METRIC_CMMD = "cmmd" @@ -58,7 +58,7 @@ class CMMD(StatefulMetric): default_call_type: str = "gt_y" higher_is_better: bool = False metric_name: str = METRIC_CMMD - modality = ["image"] + modality = {IMAGE} def __init__( self, diff --git a/src/pruna/evaluation/metrics/metric_pairwise_clip.py b/src/pruna/evaluation/metrics/metric_pairwise_clip.py index 87296377..e541fa33 100644 --- a/src/pruna/evaluation/metrics/metric_pairwise_clip.py +++ b/src/pruna/evaluation/metrics/metric_pairwise_clip.py @@ -26,7 +26,7 @@ from pruna.evaluation.metrics.metric_stateful import StatefulMetric from pruna.evaluation.metrics.registry import MetricRegistry from pruna.evaluation.metrics.result import MetricResult -from pruna.evaluation.metrics.utils import PAIRWISE, metric_data_processor +from pruna.evaluation.metrics.utils import IMAGE, PAIRWISE, metric_data_processor from pruna.logging.logger import pruna_logger @@ -47,7 +47,7 @@ class PairwiseClipScore(CLIPScore, StatefulMetric): # type: ignore[misc] higher_is_better: bool = True metric_name: str = "pairwise_clip_score" - modality = ["image"] + modality = {IMAGE} def __init__(self, **kwargs: Any) -> None: device = kwargs.pop("device", None) diff --git a/src/pruna/evaluation/metrics/metric_sharpness.py b/src/pruna/evaluation/metrics/metric_sharpness.py index 759f0fca..c0abaeac 100644 --- a/src/pruna/evaluation/metrics/metric_sharpness.py +++ b/src/pruna/evaluation/metrics/metric_sharpness.py @@ -24,7 +24,7 @@ from pruna.evaluation.metrics.metric_stateful import StatefulMetric from pruna.evaluation.metrics.registry import MetricRegistry from pruna.evaluation.metrics.result import MetricResult -from pruna.evaluation.metrics.utils import SINGLE, get_call_type_for_single_metric, metric_data_processor +from pruna.evaluation.metrics.utils import IMAGE, SINGLE, get_call_type_for_single_metric, metric_data_processor from pruna.logging.logger import pruna_logger METRIC_SHARPNESS = "sharpness" @@ -64,7 +64,7 @@ class SharpnessMetric(StatefulMetric): higher_is_better: bool = True metric_name: str = METRIC_SHARPNESS runs_on: List[str] = ["cpu", "cuda"] - modality = ["image"] + modality = {IMAGE} def __init__(self, *args, kernel_size: int = 3, call_type: str = SINGLE, **kwargs) -> None: super().__init__(device=kwargs.pop("device", None)) diff --git a/src/pruna/evaluation/metrics/metric_stateful.py b/src/pruna/evaluation/metrics/metric_stateful.py index 19a201af..8407e96c 100644 --- a/src/pruna/evaluation/metrics/metric_stateful.py +++ b/src/pruna/evaluation/metrics/metric_stateful.py @@ -46,7 +46,7 @@ class StatefulMetric(ABC): call_type: str runs_on: list[str] = ["cuda", "cpu", "mps"] create_alias: bool = False - modality: List[str] + modality: set[str] def __init__(self, device: str | torch.device | None = None, **kwargs) -> None: """Initialize the StatefulMetric class.""" diff --git a/src/pruna/evaluation/metrics/metric_torch.py b/src/pruna/evaluation/metrics/metric_torch.py index 951f1b8d..e79d8a71 100644 --- a/src/pruna/evaluation/metrics/metric_torch.py +++ b/src/pruna/evaluation/metrics/metric_torch.py @@ -42,8 +42,11 @@ from pruna.evaluation.metrics.result import MetricResult from pruna.evaluation.metrics.utils import ( CALL_TYPES, + IMAGE, + MODALITIES, PAIRWISE, SINGLE, + TEXT, get_pairwise_pairing, get_single_pairing, metric_data_processor, @@ -168,18 +171,18 @@ class TorchMetrics(Enum): The starting value for the enum. """ - fid = (partial(FrechetInceptionDistance), fid_update, "gt_y", ["image"]) - accuracy = (partial(Accuracy), None, "y_gt", ["general"]) - perplexity = (partial(Perplexity), None, "y_gt", ["text"]) - clip_score = (partial(CLIPScore), None, "y_x", ["image"]) - precision = (partial(Precision), None, "y_gt", ["general"]) - recall = (partial(Recall), None, "y_gt", ["general"]) - psnr = (partial(PeakSignalNoiseRatio), None, "pairwise_y_gt", ["image"]) - ssim = (partial(StructuralSimilarityIndexMeasure), ssim_update, "pairwise_y_gt", ["image"]) - msssim = (partial(MultiScaleStructuralSimilarityIndexMeasure), ssim_update, "pairwise_y_gt", ["image"]) - lpips = (partial(LearnedPerceptualImagePatchSimilarity), lpips_update, "pairwise_y_gt", ["image"]) - arniqa = (partial(ARNIQA), arniqa_update, "y", ["image"]) - clipiqa = (partial(CLIPImageQualityAssessment), None, "y", ["image"]) + fid = (partial(FrechetInceptionDistance), fid_update, "gt_y", {IMAGE}) + accuracy = (partial(Accuracy), None, "y_gt", MODALITIES) + perplexity = (partial(Perplexity), None, "y_gt", {TEXT}) + clip_score = (partial(CLIPScore), None, "y_x", {IMAGE}) + precision = (partial(Precision), None, "y_gt", MODALITIES) + recall = (partial(Recall), None, "y_gt", MODALITIES) + psnr = (partial(PeakSignalNoiseRatio), None, "pairwise_y_gt", {IMAGE}) + ssim = (partial(StructuralSimilarityIndexMeasure), ssim_update, "pairwise_y_gt", {IMAGE}) + msssim = (partial(MultiScaleStructuralSimilarityIndexMeasure), ssim_update, "pairwise_y_gt", {IMAGE}) + lpips = (partial(LearnedPerceptualImagePatchSimilarity), lpips_update, "pairwise_y_gt", {IMAGE}) + arniqa = (partial(ARNIQA), arniqa_update, "y", {IMAGE}) + clipiqa = (partial(CLIPImageQualityAssessment), None, "y", {IMAGE}) def __init__(self, *args, **kwargs) -> None: self.tm = self.value[0] diff --git a/src/pruna/evaluation/metrics/utils.py b/src/pruna/evaluation/metrics/utils.py index 29342701..2232b198 100644 --- a/src/pruna/evaluation/metrics/utils.py +++ b/src/pruna/evaluation/metrics/utils.py @@ -41,6 +41,11 @@ SINGLE = "single" PAIRWISE = "pairwise" CALL_TYPES = (SINGLE, PAIRWISE) +IMAGE = "image" +VIDEO = "video" +TEXT = "text" + +MODALITIES = {IMAGE, VIDEO, TEXT} def metric_data_processor( diff --git a/src/pruna/evaluation/task.py b/src/pruna/evaluation/task.py index 809ad3ef..672f02ae 100644 --- a/src/pruna/evaluation/task.py +++ b/src/pruna/evaluation/task.py @@ -168,18 +168,15 @@ def validate_and_get_task_modality(self) -> str: str The modality of the task. """ - stateful_metrics = [metric for metric in self.metrics if isinstance(metric, StatefulMetric)] - modalities_no_general = { - modality for metric in stateful_metrics for modality in metric.modality if modality != "general" - } - # We should also allow 0 because the user might have only general modality metrics. - if len(modalities_no_general) == 1: - modality = modalities_no_general.pop() - elif len(modalities_no_general) == 0: - modality = "general" - else: - raise ValueError("The task should have a single modality of quality metrics.") - return modality + modality_intersection = set.intersection( + *[metric.modality for metric in self.metrics if isinstance(metric, StatefulMetric)] + ) + if len(modality_intersection) == 1: + return modality_intersection.pop() + elif len(modality_intersection) == 0: + raise ValueError("The task should have a single modality across all quality metrics.") + else: # More than one modality, fine for evaluation, can't save artifacts (for now). + return "general" def get_metrics( From 34a9e82213276f9133bcc5b5d17aef9dd2ce5c68 Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Tue, 28 Oct 2025 11:11:26 +0000 Subject: [PATCH 17/42] refactor: comments tests task modality --- src/pruna/evaluation/artifactsavers/video_artifactsaver.py | 2 ++ src/pruna/evaluation/evaluation_agent.py | 6 ++++-- src/pruna/evaluation/task.py | 2 ++ tests/evaluation/test_artifactsaver.py | 2 +- 4 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/pruna/evaluation/artifactsavers/video_artifactsaver.py b/src/pruna/evaluation/artifactsavers/video_artifactsaver.py index c8dbbec2..4acae51c 100644 --- a/src/pruna/evaluation/artifactsavers/video_artifactsaver.py +++ b/src/pruna/evaluation/artifactsavers/video_artifactsaver.py @@ -55,6 +55,8 @@ def save_artifact(self, data: Any, saving_kwargs: dict = dict()) -> Path: ---------- data: Any The data to save. + saving_kwargs: dict + The additional kwargs to pass to the saving utility function. Returns ------- diff --git a/src/pruna/evaluation/evaluation_agent.py b/src/pruna/evaluation/evaluation_agent.py index c84889a1..f705f6ec 100644 --- a/src/pruna/evaluation/evaluation_agent.py +++ b/src/pruna/evaluation/evaluation_agent.py @@ -264,11 +264,13 @@ def update_stateful_metrics( for stateful_metric in single_stateful_metrics: stateful_metric.update(x, gt, processed_outputs) if self.save_artifacts and stateful_metric.create_alias: - # Again, we have to create an alias for each sample in the batch. + # The evaluation agent saves the artifacts with a canonical filenaming convention. + # If the user wants to save the artifact with a different filename, + # here we give them the option to create an alias for the file. for prompt_idx, prompt in enumerate(x): if self.artifact_saver.export_format is None: raise ValueError( - "Export format is not set.Please set the export format for the artifact saver." + "Export format is not set. Please set the export format for the artifact saver." ) alias_filename = stateful_metric.create_filename( filename=prompt, idx=sample_idx, file_extension=self.artifact_saver.export_format diff --git a/src/pruna/evaluation/task.py b/src/pruna/evaluation/task.py index 672f02ae..0a1f9860 100644 --- a/src/pruna/evaluation/task.py +++ b/src/pruna/evaluation/task.py @@ -168,6 +168,8 @@ def validate_and_get_task_modality(self) -> str: str The modality of the task. """ + if not self.get_single_stateful_metrics() and not self.get_pairwise_stateful_metrics(): + return "general" modality_intersection = set.intersection( *[metric.modality for metric in self.metrics if isinstance(metric, StatefulMetric)] ) diff --git a/tests/evaluation/test_artifactsaver.py b/tests/evaluation/test_artifactsaver.py index ca46dd99..41b09de4 100644 --- a/tests/evaluation/test_artifactsaver.py +++ b/tests/evaluation/test_artifactsaver.py @@ -39,7 +39,7 @@ def test_assign_artifact_saver_video(tmp_path: Path): def test_assign_artifact_saver_invalid(): """ Test that we raise an error if the artifact saver is assigned incorrectly.""" with pytest.raises(ValueError): - assign_artifact_saver("text") + assign_artifact_saver("nonexistent_modality") @pytest.mark.parametrize( "export_format, save_from_type, save_from_dtype", From 6e3eb7688a12a0e80a5baf50fb16360c7e8c2175 Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Mon, 6 Oct 2025 08:11:07 +0000 Subject: [PATCH 18/42] feat: add video inference support and seeding strategies to inference handler --- src/pruna/engine/handler/handler_diffuser.py | 62 ++++++++++++--- src/pruna/engine/handler/handler_inference.py | 76 ++++++++++++++++++- src/pruna/engine/pruna_model.py | 7 +- 3 files changed, 131 insertions(+), 14 deletions(-) diff --git a/src/pruna/engine/handler/handler_diffuser.py b/src/pruna/engine/handler/handler_diffuser.py index fd1fd239..39ff94fc 100644 --- a/src/pruna/engine/handler/handler_diffuser.py +++ b/src/pruna/engine/handler/handler_diffuser.py @@ -15,12 +15,11 @@ from __future__ import annotations import inspect -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Literal, Optional, Tuple import torch -from torchvision import transforms -from pruna.engine.handler.handler_inference import InferenceHandler +from pruna.engine.handler.handler_inference import InferenceHandler, validate_seed_strategy from pruna.logging.logger import pruna_logger @@ -41,11 +40,9 @@ class DiffuserHandler(InferenceHandler): """ def __init__(self, call_signature: inspect.Signature, model_args: Optional[Dict[str, Any]] = None) -> None: - default_args = {"generator": torch.Generator("cpu").manual_seed(42)} self.call_signature = call_signature - if model_args: - default_args.update(model_args) - self.model_args = default_args + self.model_args = model_args if model_args else {} + self.model_args["output_type"] = "pt" def prepare_inputs( self, batch: List[str] | torch.Tensor | Tuple[List[str] | torch.Tensor | dict[str, Any], ...] | dict[str, Any] @@ -69,6 +66,13 @@ def prepare_inputs( else: # Unconditional generation models return None + def apply_per_sample_seed(self) -> None: + """Generate and apply a new random seed derived from global_seed (only valid if seed_strategy="per_sample").""" + if self.seed_strategy != "per_sample": + raise ValueError("Seed strategy must be 'per_sample' to apply per sample seed.") + seed = int(torch.randint(0, 2**31, (1,), generator=self.generator).item()) + self.model_args["generator"] = torch.Generator("cpu").manual_seed(seed) + def process_output(self, output: Any) -> torch.Tensor: """ Handle the output of the model. @@ -83,13 +87,47 @@ def process_output(self, output: Any) -> torch.Tensor: torch.Tensor The processed images. """ - generated = output.images - return torch.stack([transforms.PILToTensor()(g) for g in generated]) + if hasattr(output, "images"): + generated = output.images + elif hasattr(output, "frames"): + generated = output.frames + else: + # Maybe the user is calling the pipeline with return_dict = False, + # which then directly returns the generated image / video. + generated = output + return generated def log_model_info(self) -> None: """Log information about the inference handler.""" pruna_logger.info( - "Detected diffusers model. Using DiffuserHandler with fixed seed.\n" - "- The first element of the batch is passed as input.\n" - "- The generated outputs are expected to have .images attribute." + "Detected diffusers model. Using DiffuserHandler.\n- The first element of the batch is passed as input.\n" ) + + def configure_seed( + self, seed_strategy: Literal["per_evaluation", "per_sample", "no_seed"], global_seed: int | None + ) -> None: + """ + Set the random seed according to the chosen strategy. + + - If `seed_strategy="per_evaluation"`, the same `global_seed` is applied + once and reused for the entire generation run. + - If `seed_strategy="per_sample"`, the `global_seed` is used as a base to derive a different seed for each sample + This ensures reproducibility while still producing variation across samples, + making it the preferred option for benchmarking. + - If `seed_strategy="no_seed"`, no seed is set internally. + The user is responsible for managing seeds if reproducibility is required. + + Parameters + ---------- + seed_strategy : Literal["per_evaluation", "per_sample", "no_seed"] + The seeding strategy to apply. + global_seed : int | None + The base seed value to use (if applicable). + """ + self.seed_strategy = seed_strategy + validate_seed_strategy(seed_strategy, global_seed) + if global_seed is not None: + self.global_seed = global_seed + self.generator = torch.Generator("cpu").manual_seed(global_seed) + # We also set the generator for the per_evaluation seed strategy already here. + self.model_args["generator"] = torch.Generator("cpu").manual_seed(global_seed) diff --git a/src/pruna/engine/handler/handler_inference.py b/src/pruna/engine/handler/handler_inference.py index 900e0e25..580fbbd8 100644 --- a/src/pruna/engine/handler/handler_inference.py +++ b/src/pruna/engine/handler/handler_inference.py @@ -14,9 +14,11 @@ from __future__ import annotations +import random from abc import ABC, abstractmethod -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Literal, Tuple +import numpy as np import torch from pruna.data.utils import move_batch_to_device @@ -98,3 +100,75 @@ def move_inputs_to_device( return move_batch_to_device(inputs, device) except torch.cuda.OutOfMemoryError as e: raise e + + def apply_per_sample_seed(self) -> None: + """Generate and apply a new random seed derived from global_seed (only valid if seed_strategy="per_sample").""" + if self.seed_strategy != "per_sample": + raise ValueError("Seed strategy must be 'per_sample' to apply per sample seed.") + per_sample_seed = self.global_seed + random.randint(0, 1_000_000) + set_seed(per_sample_seed) + + def configure_seed( + self, seed_strategy: Literal["per_evaluation", "per_sample", "no_seed"], global_seed: int | None + ) -> None: + """ + Set the random seed according to the chosen strategy. + + - If `seed_strategy="per_evaluation"`, the same `global_seed` is applied once and reused + for the entire generation run. + - If `seed_strategy="per_sample"`, the `global_seed` is used as a base to derive a different seed for each sample + This ensures reproducibility while still producing variation across samples, making it the preferred option for + benchmarking. + - If `seed_strategy="no_seed"`, no seed is set internally. The user is responsible for managing seeds if + reproducibility is required. + + Parameters + ---------- + seed_strategy : Literal["per_evaluation", "per_sample", "no_seed"] + The seeding strategy to apply. + global_seed : int | None + The base seed value to use (if applicable). + """ + self.seed_strategy = seed_strategy + validate_seed_strategy(seed_strategy, global_seed) + if global_seed is not None: + self.global_seed = global_seed + set_seed(global_seed) + + +def set_seed(seed: int) -> None: + """ + Set the random seed for the current process. + + Parameters + ---------- + seed : int + The seed to set. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def validate_seed_strategy( + seed_strategy: Literal["per_evaluation", "per_sample", "no_seed"], global_seed: int | None +) -> None: + """ + Check the consistency of the seed strategy and the global seed. + + If the seed strategy is "no_seed", the global seed must be None. + If the seed strategy is "per_evaluation" or "per_sample", the user must provide a global seed. + + Parameters + ---------- + seed_strategy : Literal["per_evaluation", "per_sample", "no_seed"] + The seeding strategy to apply. + global_seed : int | None + The base seed value to use (if applicable). + """ + if seed_strategy != "no_seed" and global_seed is None: + raise ValueError("Global seed must be provided if seed strategy is not 'no_seed'.") + elif global_seed is not None and seed_strategy == "no_seed": + raise ValueError("Seed strategy cannot be 'no_seed' if global seed is provided.") diff --git a/src/pruna/engine/pruna_model.py b/src/pruna/engine/pruna_model.py index fb4d6f91..e9554540 100644 --- a/src/pruna/engine/pruna_model.py +++ b/src/pruna/engine/pruna_model.py @@ -24,7 +24,7 @@ from pruna.config.smash_config import SmashConfig from pruna.engine.handler.handler_utils import register_inference_handler -from pruna.engine.load import load_pruna_model, load_pruna_model_from_pretrained +from pruna.engine.load import filter_load_kwargs, load_pruna_model, load_pruna_model_from_pretrained from pruna.engine.save import save_pruna_model, save_pruna_model_to_hub from pruna.engine.utils import get_device, get_nn_modules, set_to_eval from pruna.logging.filter import apply_warning_filter @@ -108,6 +108,11 @@ def run_inference(self, batch: Any) -> Any: ) inference_function = getattr(self, inference_function_name) + if hasattr(self.inference_handler, "seed_strategy") and self.inference_handler.seed_strategy == "per_sample": + self.inference_handler.apply_per_sample_seed() + + self.inference_handler.model_args = filter_load_kwargs(self.model.__call__, self.inference_handler.model_args) + if prepared_inputs is None: outputs = inference_function(**self.inference_handler.model_args) elif isinstance(prepared_inputs, dict): From a3f7cbc2f48f102cda82ae8e93e917da0b3aefe7 Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Mon, 6 Oct 2025 12:38:44 +0000 Subject: [PATCH 19/42] feat: remove per evaluation seed and add tests --- src/pruna/engine/handler/handler_diffuser.py | 32 +++--- src/pruna/engine/handler/handler_inference.py | 44 ++++--- tests/engine/test_handler.py | 108 ++++++++++++++++++ 3 files changed, 143 insertions(+), 41 deletions(-) create mode 100644 tests/engine/test_handler.py diff --git a/src/pruna/engine/handler/handler_diffuser.py b/src/pruna/engine/handler/handler_diffuser.py index 39ff94fc..d6737a3a 100644 --- a/src/pruna/engine/handler/handler_diffuser.py +++ b/src/pruna/engine/handler/handler_diffuser.py @@ -39,10 +39,17 @@ class DiffuserHandler(InferenceHandler): The arguments to pass to the model. """ - def __init__(self, call_signature: inspect.Signature, model_args: Optional[Dict[str, Any]] = None) -> None: + def __init__( + self, + call_signature: inspect.Signature, + model_args: Optional[Dict[str, Any]] = None, + seed_strategy: Literal["per_sample", "no_seed"] = "no_seed", + global_seed: int | None = None, + ) -> None: self.call_signature = call_signature self.model_args = model_args if model_args else {} self.model_args["output_type"] = "pt" + self.configure_seed(seed_strategy, global_seed) def prepare_inputs( self, batch: List[str] | torch.Tensor | Tuple[List[str] | torch.Tensor | dict[str, Any], ...] | dict[str, Any] @@ -66,13 +73,6 @@ def prepare_inputs( else: # Unconditional generation models return None - def apply_per_sample_seed(self) -> None: - """Generate and apply a new random seed derived from global_seed (only valid if seed_strategy="per_sample").""" - if self.seed_strategy != "per_sample": - raise ValueError("Seed strategy must be 'per_sample' to apply per sample seed.") - seed = int(torch.randint(0, 2**31, (1,), generator=self.generator).item()) - self.model_args["generator"] = torch.Generator("cpu").manual_seed(seed) - def process_output(self, output: Any) -> torch.Tensor: """ Handle the output of the model. @@ -103,23 +103,19 @@ def log_model_info(self) -> None: "Detected diffusers model. Using DiffuserHandler.\n- The first element of the batch is passed as input.\n" ) - def configure_seed( - self, seed_strategy: Literal["per_evaluation", "per_sample", "no_seed"], global_seed: int | None - ) -> None: + def configure_seed(self, seed_strategy: Literal["per_sample", "no_seed"], global_seed: int | None) -> None: """ Set the random seed according to the chosen strategy. - - If `seed_strategy="per_evaluation"`, the same `global_seed` is applied - once and reused for the entire generation run. - - If `seed_strategy="per_sample"`, the `global_seed` is used as a base to derive a different seed for each sample - This ensures reproducibility while still producing variation across samples, + - If `seed_strategy="per_sample"`,the `global_seed` is used as a base to derive a different seed for each + sample. This ensures reproducibility while still producing variation across samples, making it the preferred option for benchmarking. - If `seed_strategy="no_seed"`, no seed is set internally. The user is responsible for managing seeds if reproducibility is required. Parameters ---------- - seed_strategy : Literal["per_evaluation", "per_sample", "no_seed"] + seed_strategy : Literal["per_sample", "no_seed"] The seeding strategy to apply. global_seed : int | None The base seed value to use (if applicable). @@ -128,6 +124,6 @@ def configure_seed( validate_seed_strategy(seed_strategy, global_seed) if global_seed is not None: self.global_seed = global_seed - self.generator = torch.Generator("cpu").manual_seed(global_seed) - # We also set the generator for the per_evaluation seed strategy already here. self.model_args["generator"] = torch.Generator("cpu").manual_seed(global_seed) + else: + self.model_args["generator"] = None # Remove the seed. diff --git a/src/pruna/engine/handler/handler_inference.py b/src/pruna/engine/handler/handler_inference.py index 580fbbd8..d6e04da2 100644 --- a/src/pruna/engine/handler/handler_inference.py +++ b/src/pruna/engine/handler/handler_inference.py @@ -101,30 +101,19 @@ def move_inputs_to_device( except torch.cuda.OutOfMemoryError as e: raise e - def apply_per_sample_seed(self) -> None: - """Generate and apply a new random seed derived from global_seed (only valid if seed_strategy="per_sample").""" - if self.seed_strategy != "per_sample": - raise ValueError("Seed strategy must be 'per_sample' to apply per sample seed.") - per_sample_seed = self.global_seed + random.randint(0, 1_000_000) - set_seed(per_sample_seed) - - def configure_seed( - self, seed_strategy: Literal["per_evaluation", "per_sample", "no_seed"], global_seed: int | None - ) -> None: + def configure_seed(self, seed_strategy: Literal["per_sample", "no_seed"], global_seed: int | None) -> None: """ Set the random seed according to the chosen strategy. - - If `seed_strategy="per_evaluation"`, the same `global_seed` is applied once and reused - for the entire generation run. - - If `seed_strategy="per_sample"`, the `global_seed` is used as a base to derive a different seed for each sample - This ensures reproducibility while still producing variation across samples, making it the preferred option for - benchmarking. - - If `seed_strategy="no_seed"`, no seed is set internally. The user is responsible for managing seeds if - reproducibility is required. + - If `seed_strategy="per_sample"`,the `global_seed` is used as a base to derive a different seed for each + sample. This ensures reproducibility while still producing variation across samples, + making it the preferred option for benchmarking. + - If `seed_strategy="no_seed"`, no seed is set internally. + The user is responsible for managing seeds if reproducibility is required. Parameters ---------- - seed_strategy : Literal["per_evaluation", "per_sample", "no_seed"] + seed_strategy : Literal["per_sample", "no_seed"] The seeding strategy to apply. global_seed : int | None The base seed value to use (if applicable). @@ -134,6 +123,8 @@ def configure_seed( if global_seed is not None: self.global_seed = global_seed set_seed(global_seed) + else: + remove_seed() def set_seed(seed: int) -> None: @@ -152,18 +143,25 @@ def set_seed(seed: int) -> None: torch.cuda.manual_seed_all(seed) -def validate_seed_strategy( - seed_strategy: Literal["per_evaluation", "per_sample", "no_seed"], global_seed: int | None -) -> None: +def remove_seed() -> None: + """Remove the seed from the current process.""" + random.seed(None) + np.random.seed(None) + torch.manual_seed(torch.seed()) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(torch.seed()) + + +def validate_seed_strategy(seed_strategy: Literal["per_sample", "no_seed"], global_seed: int | None) -> None: """ Check the consistency of the seed strategy and the global seed. If the seed strategy is "no_seed", the global seed must be None. - If the seed strategy is "per_evaluation" or "per_sample", the user must provide a global seed. + If the seed strategy is or "per_sample", the user must provide a global seed. Parameters ---------- - seed_strategy : Literal["per_evaluation", "per_sample", "no_seed"] + seed_strategy : Literal["per_sample", "no_seed"] The seeding strategy to apply. global_seed : int | None The base seed value to use (if applicable). diff --git a/tests/engine/test_handler.py b/tests/engine/test_handler.py new file mode 100644 index 00000000..dc9e3e0e --- /dev/null +++ b/tests/engine/test_handler.py @@ -0,0 +1,108 @@ +import types +import pytest +import torch +from pruna.engine.handler.handler_inference import ( + set_seed, + validate_seed_strategy, +) +from pruna.engine.handler.handler_diffuser import DiffuserHandler +from pruna.engine.pruna_model import PrunaModel +from pruna.engine.utils import move_to_device +# Default handler tests, mainly for checking seeding. + +def test_validate_seed_strategy_valid(): + '''Test to see validate_seed_strategy is valid for valid strategies''' + validate_seed_strategy("per_sample", 42) + validate_seed_strategy("no_seed", None) + + +@pytest.mark.parametrize("strategy,seed", [ + ("per_sample", None), + ("no_seed", 42), +]) +def test_validate_seed_strategy_invalid(strategy, seed): + '''Test to see validate_seed_strategy raises an error for invalid strategies''' + with pytest.raises(ValueError): + validate_seed_strategy(strategy, seed) + + +def test_set_seed_reproducibility(): + '''Test to see set_seed is reproducible''' + set_seed(42) + a = torch.randn(3) + set_seed(42) + b = torch.randn(3) + assert torch.equal(a, b) + + +# Diffuser handler tests, checking output processing and seeding. +@pytest.mark.parametrize("model_fixture", + [ + pytest.param("flux_tiny_random", marks=pytest.mark.cuda), + ], + indirect=["model_fixture"], + ) +def test_assignment_of_diffuser_handler(model_fixture): + """Check if a diffusion model is assigned to the DiffuserHandler""" + model, smash_config = model_fixture + + pruna_model = PrunaModel(model, smash_config=smash_config) + assert isinstance(pruna_model.inference_handler, DiffuserHandler) + +@pytest.mark.parametrize("model_fixture, seed, output_attr, return_dict, device", + [ + pytest.param("flux_tiny_random", 42, "images", True, "cpu", marks=pytest.mark.cpu), + pytest.param("wan_tiny_random", 42, "frames" ,True, "cuda", marks=pytest.mark.cuda), + pytest.param("flux_tiny_random", 42, "none", False, "cpu", marks=pytest.mark.cpu), + ], + indirect=["model_fixture"],) +def test_process_output_images(model_fixture, seed, output_attr, return_dict, device): + """Check if the output of the model is processed correctly""" + input_text = "a photo of a cute prune" + + model, smash_config = model_fixture + smash_config.device = device + move_to_device(model, device) + pruna_model = PrunaModel(model, smash_config=smash_config) + pruna_model.inference_handler.configure_seed("per_sample", global_seed=seed) + result = pruna_model.run_inference(input_text) + + pipe_output = model(input_text, output_type="pt", generator=torch.Generator("cpu").manual_seed(seed), return_dict=return_dict) + if output_attr != "none": + pipe_output = getattr(pipe_output, output_attr) + + assert (result == pipe_output[0]).all().item() + + +@pytest.mark.parametrize("model_fixture", + [ + pytest.param("flux_tiny_random", marks=pytest.mark.cpu), + ], + indirect=["model_fixture"],) +def test_per_sample_seed_is_applied(model_fixture): + """Check if samples change per inference run when per_sample seed is applied""" + model, smash_config = model_fixture + smash_config.device = "cpu" + move_to_device(model, "cpu") + input_text = "a photo of a cute prune" + pruna_model = PrunaModel(model, smash_config=smash_config) + pruna_model.inference_handler.configure_seed("per_sample", global_seed=42) + first_result = pruna_model.run_inference(input_text) + second_result = pruna_model.run_inference(input_text) + assert not torch.equal(first_result, second_result) + +@pytest.mark.parametrize("model_fixture", + [ + pytest.param("flux_tiny_random", marks=pytest.mark.cpu), + ], + indirect=["model_fixture"],) +def test_seed_is_removed(model_fixture): + """ Check if seed is removed when no_seed seed is applied""" + model, smash_config = model_fixture + smash_config.device = "cpu" + move_to_device(model, "cpu") + pruna_model = PrunaModel(model, smash_config=smash_config) + pruna_model.inference_handler.configure_seed("per_sample", global_seed=42) + assert pruna_model.inference_handler.model_args["generator"] is not None + pruna_model.inference_handler.configure_seed("no_seed", None) + assert pruna_model.inference_handler.model_args["generator"] is None From c2259cc41ffbf0b4cb92e380af2bcd200748cc27 Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Mon, 6 Oct 2025 14:16:36 +0000 Subject: [PATCH 20/42] chore: add comments --- src/pruna/engine/handler/handler_diffuser.py | 2 ++ src/pruna/engine/handler/handler_inference.py | 4 ++++ tests/engine/test_handler.py | 5 +++++ 3 files changed, 11 insertions(+) diff --git a/src/pruna/engine/handler/handler_diffuser.py b/src/pruna/engine/handler/handler_diffuser.py index d6737a3a..0eec2adc 100644 --- a/src/pruna/engine/handler/handler_diffuser.py +++ b/src/pruna/engine/handler/handler_diffuser.py @@ -48,6 +48,7 @@ def __init__( ) -> None: self.call_signature = call_signature self.model_args = model_args if model_args else {} + # We want the default output type to be pytorch tensors. self.model_args["output_type"] = "pt" self.configure_seed(seed_strategy, global_seed) @@ -89,6 +90,7 @@ def process_output(self, output: Any) -> torch.Tensor: """ if hasattr(output, "images"): generated = output.images + # For video models. elif hasattr(output, "frames"): generated = output.frames else: diff --git a/src/pruna/engine/handler/handler_inference.py b/src/pruna/engine/handler/handler_inference.py index d6e04da2..9ed0e593 100644 --- a/src/pruna/engine/handler/handler_inference.py +++ b/src/pruna/engine/handler/handler_inference.py @@ -136,6 +136,8 @@ def set_seed(seed: int) -> None: seed : int The seed to set. """ + # With the default handler, we can't assume anything about the model, + # so we are setting the seed for all RNGs available. random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) @@ -147,6 +149,8 @@ def remove_seed() -> None: """Remove the seed from the current process.""" random.seed(None) np.random.seed(None) + # We can't really remove the seed from the PyTorch RNG, so we are reseeding with torch.seed(). + # torch.seed() creates a non-deterministic random number. torch.manual_seed(torch.seed()) if torch.cuda.is_available(): torch.cuda.manual_seed_all(torch.seed()) diff --git a/tests/engine/test_handler.py b/tests/engine/test_handler.py index dc9e3e0e..abc20efe 100644 --- a/tests/engine/test_handler.py +++ b/tests/engine/test_handler.py @@ -60,6 +60,7 @@ def test_process_output_images(model_fixture, seed, output_attr, return_dict, de """Check if the output of the model is processed correctly""" input_text = "a photo of a cute prune" + # Get the output from PrunaModel model, smash_config = model_fixture smash_config.device = device move_to_device(model, device) @@ -67,6 +68,7 @@ def test_process_output_images(model_fixture, seed, output_attr, return_dict, de pruna_model.inference_handler.configure_seed("per_sample", global_seed=seed) result = pruna_model.run_inference(input_text) + # Get the output from the pipeline. pipe_output = model(input_text, output_type="pt", generator=torch.Generator("cpu").manual_seed(seed), return_dict=return_dict) if output_attr != "none": pipe_output = getattr(pipe_output, output_attr) @@ -89,6 +91,7 @@ def test_per_sample_seed_is_applied(model_fixture): pruna_model.inference_handler.configure_seed("per_sample", global_seed=42) first_result = pruna_model.run_inference(input_text) second_result = pruna_model.run_inference(input_text) + # If seeding is successfull, each sample should create a different output. assert not torch.equal(first_result, second_result) @pytest.mark.parametrize("model_fixture", @@ -103,6 +106,8 @@ def test_seed_is_removed(model_fixture): move_to_device(model, "cpu") pruna_model = PrunaModel(model, smash_config=smash_config) pruna_model.inference_handler.configure_seed("per_sample", global_seed=42) + # First check if the seed is set. assert pruna_model.inference_handler.model_args["generator"] is not None pruna_model.inference_handler.configure_seed("no_seed", None) + # Then check if the seed generator is removed. assert pruna_model.inference_handler.model_args["generator"] is None From 8a340bb13a893792684b47d5ca11ece69c2d1f83 Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Tue, 14 Oct 2025 12:00:02 +0000 Subject: [PATCH 21/42] fix: bfloats cannot be moved to cpu error in cmmd metric --- src/pruna/engine/handler/handler_diffuser.py | 8 ++------ src/pruna/evaluation/metrics/metric_cmmd.py | 2 ++ 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/pruna/engine/handler/handler_diffuser.py b/src/pruna/engine/handler/handler_diffuser.py index 0eec2adc..c09d77da 100644 --- a/src/pruna/engine/handler/handler_diffuser.py +++ b/src/pruna/engine/handler/handler_diffuser.py @@ -27,10 +27,6 @@ class DiffuserHandler(InferenceHandler): """ Handle inference arguments, inputs and outputs for diffusers models. - A generator with a fixed seed (42) is passed as an argument to the model for reproducibility. - The first element of the batch is passed as input to the model. - The generated outputs are expected to have .images attribute. - Parameters ---------- call_signature : inspect.Signature @@ -95,8 +91,8 @@ def process_output(self, output: Any) -> torch.Tensor: generated = output.frames else: # Maybe the user is calling the pipeline with return_dict = False, - # which then directly returns the generated image / video. - generated = output + # which then returns the generated image / video in a tuple + generated = output[0] return generated def log_model_info(self) -> None: diff --git a/src/pruna/evaluation/metrics/metric_cmmd.py b/src/pruna/evaluation/metrics/metric_cmmd.py index 9ed5eb9d..e93442ad 100644 --- a/src/pruna/evaluation/metrics/metric_cmmd.py +++ b/src/pruna/evaluation/metrics/metric_cmmd.py @@ -100,6 +100,8 @@ def update(self, x: List[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.T The output images. """ inputs = metric_data_processor(x, gt, outputs, self.call_type, self.device) + if inputs[1].dtype == torch.bfloat16: + inputs[1] = inputs[1].to(torch.float16) gt_embeddings = self._get_embeddings(inputs[0]) output_embeddings = self._get_embeddings(inputs[1]) From 2441e9d5cc510653cecf0a13b7c73efc3854ab26 Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Tue, 14 Oct 2025 12:50:50 +0000 Subject: [PATCH 22/42] fix: pre commit file fix --- .pre-commit-config.yaml | 4 ---- src/pruna/engine/pruna_model.py | 3 --- tests/engine/test_handler.py | 3 ++- 3 files changed, 2 insertions(+), 8 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8dfbe206..3a4072ad 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -39,11 +39,7 @@ repos: grep -v "^D" | cut -f2- | while IFS= read -r file; do -<<<<<<< HEAD - if [ -f "$file" ] && [ "$file" != ".pre-commit-config.yaml" ] && grep -q "pruna_pro" "$file"; then -======= if [ -f "$file" ] && ["$file" != ".pre-commit-config.yaml"] && grep -q "pruna_pro" "$file"; then ->>>>>>> 306050c (feat: metric modalities as sets) echo "Error: pruna_pro found in staged file $file" exit 1 fi diff --git a/src/pruna/engine/pruna_model.py b/src/pruna/engine/pruna_model.py index e9554540..a3bdec14 100644 --- a/src/pruna/engine/pruna_model.py +++ b/src/pruna/engine/pruna_model.py @@ -108,9 +108,6 @@ def run_inference(self, batch: Any) -> Any: ) inference_function = getattr(self, inference_function_name) - if hasattr(self.inference_handler, "seed_strategy") and self.inference_handler.seed_strategy == "per_sample": - self.inference_handler.apply_per_sample_seed() - self.inference_handler.model_args = filter_load_kwargs(self.model.__call__, self.inference_handler.model_args) if prepared_inputs is None: diff --git a/tests/engine/test_handler.py b/tests/engine/test_handler.py index abc20efe..01d8e48d 100644 --- a/tests/engine/test_handler.py +++ b/tests/engine/test_handler.py @@ -72,8 +72,9 @@ def test_process_output_images(model_fixture, seed, output_attr, return_dict, de pipe_output = model(input_text, output_type="pt", generator=torch.Generator("cpu").manual_seed(seed), return_dict=return_dict) if output_attr != "none": pipe_output = getattr(pipe_output, output_attr) + pipe_output = pipe_output[0] - assert (result == pipe_output[0]).all().item() + assert (result == pipe_output).all().item() @pytest.mark.parametrize("model_fixture", From 52133e13aa471454e77e5b6e4354f10dec9334e4 Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Tue, 28 Oct 2025 13:30:39 +0000 Subject: [PATCH 23/42] refactor: configure seeding and tests --- src/pruna/engine/handler/handler_diffuser.py | 34 +++++------ src/pruna/engine/handler/handler_inference.py | 56 +++++++++---------- tests/engine/test_handler.py | 20 ++++--- 3 files changed, 51 insertions(+), 59 deletions(-) diff --git a/src/pruna/engine/handler/handler_diffuser.py b/src/pruna/engine/handler/handler_diffuser.py index c09d77da..2d221085 100644 --- a/src/pruna/engine/handler/handler_diffuser.py +++ b/src/pruna/engine/handler/handler_diffuser.py @@ -19,7 +19,7 @@ import torch -from pruna.engine.handler.handler_inference import InferenceHandler, validate_seed_strategy +from pruna.engine.handler.handler_inference import InferenceHandler from pruna.logging.logger import pruna_logger @@ -93,35 +93,27 @@ def process_output(self, output: Any) -> torch.Tensor: # Maybe the user is calling the pipeline with return_dict = False, # which then returns the generated image / video in a tuple generated = output[0] - return generated + return generated.float() def log_model_info(self) -> None: """Log information about the inference handler.""" pruna_logger.info( "Detected diffusers model. Using DiffuserHandler.\n- The first element of the batch is passed as input.\n" + "Inference outputs are expected to have either have an `images` attribute or a `frames` attribute." + "Or be a tuple with the generated image / video as the first element." ) - def configure_seed(self, seed_strategy: Literal["per_sample", "no_seed"], global_seed: int | None) -> None: + def set_seed(self, seed: int) -> None: """ - Set the random seed according to the chosen strategy. - - - If `seed_strategy="per_sample"`,the `global_seed` is used as a base to derive a different seed for each - sample. This ensures reproducibility while still producing variation across samples, - making it the preferred option for benchmarking. - - If `seed_strategy="no_seed"`, no seed is set internally. - The user is responsible for managing seeds if reproducibility is required. + Set the random seed for the current process. Parameters ---------- - seed_strategy : Literal["per_sample", "no_seed"] - The seeding strategy to apply. - global_seed : int | None - The base seed value to use (if applicable). + seed : int + The seed to set. """ - self.seed_strategy = seed_strategy - validate_seed_strategy(seed_strategy, global_seed) - if global_seed is not None: - self.global_seed = global_seed - self.model_args["generator"] = torch.Generator("cpu").manual_seed(global_seed) - else: - self.model_args["generator"] = None # Remove the seed. + self.model_args["generator"] = torch.Generator("cpu").manual_seed(seed) + + def remove_seed(self) -> None: + """Remove the seed from the current process.""" + self.model_args["generator"] = None diff --git a/src/pruna/engine/handler/handler_inference.py b/src/pruna/engine/handler/handler_inference.py index 9ed0e593..2c1ae290 100644 --- a/src/pruna/engine/handler/handler_inference.py +++ b/src/pruna/engine/handler/handler_inference.py @@ -122,38 +122,36 @@ def configure_seed(self, seed_strategy: Literal["per_sample", "no_seed"], global validate_seed_strategy(seed_strategy, global_seed) if global_seed is not None: self.global_seed = global_seed - set_seed(global_seed) + self.set_seed(global_seed) else: - remove_seed() + self.remove_seed() + def set_seed(self, seed: int) -> None: + """ + Set the random seed for the current process. -def set_seed(seed: int) -> None: - """ - Set the random seed for the current process. - - Parameters - ---------- - seed : int - The seed to set. - """ - # With the default handler, we can't assume anything about the model, - # so we are setting the seed for all RNGs available. - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) - - -def remove_seed() -> None: - """Remove the seed from the current process.""" - random.seed(None) - np.random.seed(None) - # We can't really remove the seed from the PyTorch RNG, so we are reseeding with torch.seed(). - # torch.seed() creates a non-deterministic random number. - torch.manual_seed(torch.seed()) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(torch.seed()) + Parameters + ---------- + seed : int + The seed to set. + """ + # With the default handler, we can't assume anything about the model, + # so we are setting the seed for all RNGs available. + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + def remove_seed(self) -> None: + """Remove the seed from the current process.""" + random.seed(None) + np.random.seed(None) + # We can't really remove the seed from the PyTorch RNG, so we are reseeding with torch.seed(). + # torch.seed() creates a non-deterministic random number. + torch.manual_seed(torch.seed()) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(torch.seed()) def validate_seed_strategy(seed_strategy: Literal["per_sample", "no_seed"], global_seed: int | None) -> None: diff --git a/tests/engine/test_handler.py b/tests/engine/test_handler.py index 01d8e48d..5501dc80 100644 --- a/tests/engine/test_handler.py +++ b/tests/engine/test_handler.py @@ -1,11 +1,11 @@ -import types +import numpy as np import pytest import torch from pruna.engine.handler.handler_inference import ( - set_seed, validate_seed_strategy, ) from pruna.engine.handler.handler_diffuser import DiffuserHandler +from pruna.engine.handler.handler_standard import StandardHandler from pruna.engine.pruna_model import PrunaModel from pruna.engine.utils import move_to_device # Default handler tests, mainly for checking seeding. @@ -25,14 +25,16 @@ def test_validate_seed_strategy_invalid(strategy, seed): with pytest.raises(ValueError): validate_seed_strategy(strategy, seed) - def test_set_seed_reproducibility(): - '''Test to see set_seed is reproducible''' - set_seed(42) - a = torch.randn(3) - set_seed(42) - b = torch.randn(3) - assert torch.equal(a, b) + inference_handler = StandardHandler() + inference_handler.set_seed(42) + torch_random_tensor = torch.randn(3) + numpy_random_tensor = np.random.randn(3) + inference_handler.set_seed(42) + torch_expected = torch.randn(3) + numpy_expected = np.random.randn(3) + assert torch.equal(torch_random_tensor, torch_expected) + assert np.array_equal(numpy_random_tensor, numpy_expected) # Diffuser handler tests, checking output processing and seeding. From 6c1d74185a45e26b359c70b31ae59b23165c26db Mon Sep 17 00:00:00 2001 From: johannaSommer Date: Wed, 5 Nov 2025 14:28:42 +0100 Subject: [PATCH 24/42] refactor: algorithm compatibility (#401) * feat: remove algorithm groups from algorithms folder * feat: simply new algorithm registration to smash space * refactor: add new smash config interface * refactor: remove unused tokenizer name function * refactor: adjust order implementation * feat: add new graph-based path finding for algorithm execution order * tests: add first version of pre-smash-routines tests * tests: narrow down pre-smash routine tests * refactor: rename PRUNA_ALGORITHMS * refactor: enhance algorithm tags * refactor: remove `incompatible` specification * feat: add `smash_config` utility * style: initial fix all linting complaints * tests: adjust test structure to new refactoring * style: address PR comments * fix: conditionally register algorithms * fix: adjust smash config access in algorithms * fix: support older smash configs * fix: handle target module exception * fix: deprecated save/load imports * tests: update to fit recent interface changes * fix: add `global_utils` exception to algorithm registry * fix: extending compatible methods * fix: deprecate old hyperparameter interface properly * tests: add symmetry checks for algorithm order * style: address PR comments * feat: add utility to register custom algorithm * fix: insufficient docstring descriptions * fix: test references to HQQ * style: fix remaining linting errors * style: fix typing error w.r.t. compatibility setter * style: import sorting * fix: return type of registry function * fix: model context docstring * fix: some final bugs * fix: duplicate pyproject.toml key * fix: address cursorbot slander * style: move inline comments * fix: unify registry logic * feat: additional check in algorithm order overwrite * fix: documentation wording * fix: device function patching in tests --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9c8f9d74..9070a7a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,6 @@ no-matching-overload = "ignore" # mypy is more permissive with overloads unresolved-reference = "ignore" # mypy is more permissive with references missing-argument = "ignore" possibly-unbound-import = "ignore" -missing-argument = "ignore" [tool.coverage.run] source = ["src/pruna"] From 566ebb960b849acc100a32e18159ca04bda42e0b Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Thu, 27 Nov 2025 14:03:48 +0000 Subject: [PATCH 25/42] feat:stratification to vbench datasets --- src/pruna/evaluation/metrics/vbench_utils.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/pruna/evaluation/metrics/vbench_utils.py b/src/pruna/evaluation/metrics/vbench_utils.py index 180f4c4d..9a130cdf 100644 --- a/src/pruna/evaluation/metrics/vbench_utils.py +++ b/src/pruna/evaluation/metrics/vbench_utils.py @@ -26,6 +26,7 @@ from torchvision.transforms import ToTensor from pruna.data.pruna_datamodule import PrunaDataModule +from pruna.data.utils import define_sample_size_for_dataset, stratify_dataset from pruna.engine.utils import safe_memory_cleanup, set_to_best_available_device from pruna.evaluation.metrics.metric_stateful import StatefulMetric from pruna.evaluation.metrics.result import MetricResult @@ -204,7 +205,7 @@ def _normalize_save_format(save_format: str) -> tuple[str, Callable]: def _normalize_prompts( - prompts: str | List[str] | PrunaDataModule, split: str = "test", batch_size: int = 1 + prompts: str | List[str] | PrunaDataModule, split: str = "test", batch_size: int = 1, num_samples: int | None = None, fraction: float = 1.0 ) -> Iterable[str]: """ Normalize prompts to an iterable format to be used in the generate_videos function. @@ -224,6 +225,9 @@ def _normalize_prompts( if isinstance(prompts, str): return [prompts] elif isinstance(prompts, PrunaDataModule): + target_dataset = getattr(prompts, f"{split}_dataset") + sample_size = define_sample_size_for_dataset(target_dataset, fraction, num_samples) + setattr(prompts, f"{split}_dataset", stratify_dataset(target_dataset, sample_size)) return getattr(prompts, f"{split}_dataloader")(batch_size=batch_size) else: # list of prompts, already iterable return prompts @@ -342,6 +346,8 @@ def sampler(*, prompt: str, seeder: Any, device: str | torch.device, **kwargs: A def generate_videos( model: Any, prompts: str | List[str] | PrunaDataModule, + num_samples: int | None = None, + samples_fraction: float = 1.0, split: str = "test", unique_sample_per_video_count: int = 1, global_seed: int = 42, @@ -396,7 +402,8 @@ def generate_videos( device = set_to_best_available_device(device) - prompt_iterable = _normalize_prompts(prompts, split, batch_size=1) + prompt_iterable = _normalize_prompts(prompts, split, batch_size=1, num_samples=num_samples, fraction=samples_fraction) + save_dir = Path(save_dir) _ensure_dir(save_dir) From b581c3184976133a9716d29390490b0e4f3af959 Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Thu, 27 Nov 2025 16:09:44 +0000 Subject: [PATCH 26/42] feat: data stratification by indexing --- src/pruna/data/utils.py | 15 ++++-- src/pruna/evaluation/metrics/vbench_utils.py | 56 ++++++++++++++++---- 2 files changed, 57 insertions(+), 14 deletions(-) diff --git a/src/pruna/data/utils.py b/src/pruna/data/utils.py index fb258afa..45fb8517 100644 --- a/src/pruna/data/utils.py +++ b/src/pruna/data/utils.py @@ -184,7 +184,7 @@ def recover_text_from_dataloader(dataloader: DataLoader, tokenizer: Any) -> list return texts -def stratify_dataset(dataset: Dataset, sample_size: int, seed: int = 42) -> Dataset: +def stratify_dataset(dataset: Dataset, sample_size: int, seed: int = 42, partition_strategy: str = "random", partition_index: int = 0) -> Dataset: """ Stratify the dataset into a specific size. @@ -196,6 +196,10 @@ def stratify_dataset(dataset: Dataset, sample_size: int, seed: int = 42) -> Data The size to stratify. seed : int The seed to use for sampling the dataset. + partition_strategy : str + The strategy to use for partitioning the dataset. Can be "indexed" or "random". + partition_index : int + The index to use for partitioning the dataset. Returns ------- @@ -211,8 +215,13 @@ def stratify_dataset(dataset: Dataset, sample_size: int, seed: int = 42) -> Data return dataset indices = list(range(dataset_length)) - random.Random(seed).shuffle(indices) - selected_indices = indices[:sample_size] + if partition_strategy == "indexed": + selected_indices = indices[sample_size*partition_index:sample_size*(partition_index+1)] + elif partition_strategy == "random": + random.Random(seed).shuffle(indices) + selected_indices = indices[:sample_size] + else: + raise ValueError(f"Invalid partition strategy: {partition_strategy}") dataset = dataset.select(selected_indices) return dataset diff --git a/src/pruna/evaluation/metrics/vbench_utils.py b/src/pruna/evaluation/metrics/vbench_utils.py index 9a130cdf..2fb039e7 100644 --- a/src/pruna/evaluation/metrics/vbench_utils.py +++ b/src/pruna/evaluation/metrics/vbench_utils.py @@ -17,6 +17,7 @@ import re from pathlib import Path from typing import Any, Callable, Iterable, List +import hashlib import numpy as np import torch @@ -82,6 +83,19 @@ def validate_batch(self, batch: torch.Tensor) -> torch.Tensor: raise ValueError(f"Batch must be 4 or 5 dimensional video tensor with B,T,C,H,W, got {batch.ndim}") return batch +def get_sample_seed(experiment_name: str, prompt: str, index: int) -> int: + """ + Get a sample seed for a given experiment name, prompt, and index. + """ + key = f"{experiment_name}_{prompt}_{index}".encode('utf-8') + + return int(hashlib.sha256(key).hexdigest(), 16) % (2**32) + +def is_file_exists(path: str | Path, filename: str) -> bool: + folder = Path(path) + full_path = folder / filename + + return full_path.is_file() def load_video(path: str | Path, return_type: str = "pt") -> List[Image] | np.ndarray | torch.Tensor: """ @@ -205,7 +219,7 @@ def _normalize_save_format(save_format: str) -> tuple[str, Callable]: def _normalize_prompts( - prompts: str | List[str] | PrunaDataModule, split: str = "test", batch_size: int = 1, num_samples: int | None = None, fraction: float = 1.0 + prompts: str | List[str] | PrunaDataModule, split: str = "test", batch_size: int = 1, num_samples: int | None = None, fraction: float = 1.0, data_partition_strategy: str = "indexed", partition_index: int = 0, seed: int = 42 ) -> Iterable[str]: """ Normalize prompts to an iterable format to be used in the generate_videos function. @@ -216,7 +230,18 @@ def _normalize_prompts( The prompts to normalize. split : str The dataset split to sample from. - + batch_size : int + The batch size to sample from. + num_samples : int | None + The number of samples to sample from. + fraction : float + The fraction of the dataset to sample from. + data_partition_strategy : str + The strategy to use for partitioning the dataset. Can be "indexed" or "random". + partition_index : int + The index to use for partitioning the dataset. + seed : int + The seed to use for partitioning the dataset. Returns ------- Iterable[str] @@ -227,7 +252,7 @@ def _normalize_prompts( elif isinstance(prompts, PrunaDataModule): target_dataset = getattr(prompts, f"{split}_dataset") sample_size = define_sample_size_for_dataset(target_dataset, fraction, num_samples) - setattr(prompts, f"{split}_dataset", stratify_dataset(target_dataset, sample_size)) + setattr(prompts, f"{split}_dataset", stratify_dataset(target_dataset, sample_size, seed, data_partition_strategy, partition_index)) return getattr(prompts, f"{split}_dataloader")(batch_size=batch_size) else: # list of prompts, already iterable return prompts @@ -337,8 +362,8 @@ def _wrap_sampler(model: Any, sampling_fn: Callable[..., Any]) -> Callable[..., ) # The sampling function may expect the model as "pipeline" so we pass it as an arg and not a kwarg. - def sampler(*, prompt: str, seeder: Any, device: str | torch.device, **kwargs: Any) -> Any: - return sampling_fn(model, prompt=prompt, seeder=seeder, device=device, **kwargs) + def sampler(*, prompt: str, seeder: Any, **kwargs: Any) -> Any: + return sampling_fn(model, prompt=prompt, seeder=seeder, **kwargs) return sampler @@ -348,6 +373,8 @@ def generate_videos( prompts: str | List[str] | PrunaDataModule, num_samples: int | None = None, samples_fraction: float = 1.0, + data_partition_strategy: str = "indexed", + partition_index: int = 0, split: str = "test", unique_sample_per_video_count: int = 1, global_seed: int = 42, @@ -358,6 +385,8 @@ def generate_videos( filename_fn: Callable = create_vbench_file_name, special_str: str = "", device: str | torch.device = None, + experiment_name: str = "", + sampling_seed_fn: Callable[..., Any] = get_sample_seed, **model_kwargs, ) -> None: """ @@ -402,15 +431,16 @@ def generate_videos( device = set_to_best_available_device(device) - prompt_iterable = _normalize_prompts(prompts, split, batch_size=1, num_samples=num_samples, fraction=samples_fraction) + prompt_iterable = _normalize_prompts(prompts, split, batch_size=1, num_samples=num_samples, fraction=samples_fraction, data_partition_strategy=data_partition_strategy, partition_index=partition_index) save_dir = Path(save_dir) _ensure_dir(save_dir) # set a run-level seed (VBench suggests this) (important for reproducibility) - seed_rng = torch.Generator().manual_seed(global_seed) + seed_rng = lambda x: torch.Generator("cpu").manual_seed(x) sampler = _wrap_sampler(model=model, sampling_fn=sampling_fn) + for batch in prompt_iterable: prompt = prepare_batch(batch) @@ -418,11 +448,15 @@ def generate_videos( file_name = filename_fn(sanitize_prompt(prompt), idx, special_str, file_extension) out_path = save_dir / file_name - vid = sampler(prompt=prompt, seeder=seed_rng, device=device, **model_kwargs) - save_fn(vid, out_path, fps=fps) + if is_file_exists(save_dir, file_name): + continue + else: + seed = sampling_seed_fn(experiment_name, prompt, idx) + vid = sampler(prompt=prompt, seeder=seed_rng(seed), **model_kwargs) + save_fn(vid, out_path, fps=fps) - del vid - safe_memory_cleanup() + del vid + safe_memory_cleanup() def evaluate_videos( From f2059b9548277c902424e816429aaeca9a6264f0 Mon Sep 17 00:00:00 2001 From: Marius Graml Date: Wed, 3 Dec 2025 15:13:01 +0100 Subject: [PATCH 27/42] Add image artifactsaver and modified utils to use it for algo sweeper in prime intellect --- .../artifactsavers/image_artifactsaver.py | 84 +++++++++++++++++++ src/pruna/evaluation/artifactsavers/utils.py | 3 + 2 files changed, 87 insertions(+) create mode 100644 src/pruna/evaluation/artifactsavers/image_artifactsaver.py diff --git a/src/pruna/evaluation/artifactsavers/image_artifactsaver.py b/src/pruna/evaluation/artifactsavers/image_artifactsaver.py new file mode 100644 index 00000000..4d1c47f8 --- /dev/null +++ b/src/pruna/evaluation/artifactsavers/image_artifactsaver.py @@ -0,0 +1,84 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import secrets +import time +from pathlib import Path +from typing import Any + +import numpy as np +import torch +from PIL import Image + +from pruna.evaluation.artifactsavers.artifactsaver import ArtifactSaver + + +class ImageArtifactSaver(ArtifactSaver): + """ + Save image artifacts. + + Parameters + ---------- + root: Path | str | None = None + The root directory to save the artifacts. + export_format: str | None = "png" + The format to save the artifacts (e.g. "png", "jpg", "jpeg", "webp"). + """ + + export_format: str | None + root: Path | str | None + + def __init__(self, root: Path | str | None = None, export_format: str | None = "png") -> None: + self.root = Path(root) if root else Path.cwd() + (self.root / "canonical").mkdir(parents=True, exist_ok=True) + self.export_format = export_format if export_format else "png" + if self.export_format not in ["png", "jpg", "jpeg", "webp"]: + raise ValueError(f"Invalid format: {self.export_format}. Valid formats are: png, jpg, jpeg, webp.") + + def save_artifact(self, data: Any, saving_kwargs: dict = dict()) -> Path: + """ + Save the image artifact. + + Parameters + ---------- + data: Any + The data to save. + saving_kwargs: dict + The additional kwargs to pass to the saving utility function. + + Returns + ------- + Path + The path to the saved artifact. + """ + canonical_filename = f"{int(time.time())}_{secrets.token_hex(4)}.{self.export_format}" + canonical_path = Path(str(self.root)) / "canonical" / canonical_filename + + # We save the image as a PIL Image, so we need to convert the data to a PIL Image. + # Usually, the data is already a PIL.Image, so we don't need to convert it. + if isinstance(data, torch.Tensor): + data = np.transpose(data.cpu().numpy(), (1, 2, 0)) + data = np.clip(data * 255, 0, 255).astype(np.uint8) + if isinstance(data, np.ndarray): + data = Image.fromarray(data.astype(np.uint8)) + # Now data must be a PIL Image + if not isinstance(data, Image.Image): + raise ValueError("Model outputs must be torch.Tensor, numpy.ndarray, or PIL.Image.") + + # Save the image (export format is determined by the file extension) + data.save(canonical_path, **saving_kwargs.copy()) + + return canonical_path diff --git a/src/pruna/evaluation/artifactsavers/utils.py b/src/pruna/evaluation/artifactsavers/utils.py index 6954dfc9..ef903b93 100644 --- a/src/pruna/evaluation/artifactsavers/utils.py +++ b/src/pruna/evaluation/artifactsavers/utils.py @@ -17,6 +17,7 @@ from pathlib import Path from pruna.evaluation.artifactsavers.artifactsaver import ArtifactSaver +from pruna.evaluation.artifactsavers.image_artifactsaver import ImageArtifactSaver from pruna.evaluation.artifactsavers.video_artifactsaver import VideoArtifactSaver @@ -42,5 +43,7 @@ def assign_artifact_saver( """ if modality == "video": return VideoArtifactSaver(root=root, export_format=export_format) + if modality == "image": + return ImageArtifactSaver(root=root, export_format=export_format) else: raise ValueError(f"Modality {modality} is not supported") From 7d92f789da361e82a2d35f5c45fe96c5dae9ecb9 Mon Sep 17 00:00:00 2001 From: Marius Graml Date: Wed, 3 Dec 2025 15:13:18 +0100 Subject: [PATCH 28/42] Changes to vbench-utils via ruff --- src/pruna/evaluation/metrics/vbench_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/pruna/evaluation/metrics/vbench_utils.py b/src/pruna/evaluation/metrics/vbench_utils.py index 2fb039e7..ca3d7f6d 100644 --- a/src/pruna/evaluation/metrics/vbench_utils.py +++ b/src/pruna/evaluation/metrics/vbench_utils.py @@ -83,6 +83,7 @@ def validate_batch(self, batch: torch.Tensor) -> torch.Tensor: raise ValueError(f"Batch must be 4 or 5 dimensional video tensor with B,T,C,H,W, got {batch.ndim}") return batch + def get_sample_seed(experiment_name: str, prompt: str, index: int) -> int: """ Get a sample seed for a given experiment name, prompt, and index. @@ -91,12 +92,14 @@ def get_sample_seed(experiment_name: str, prompt: str, index: int) -> int: return int(hashlib.sha256(key).hexdigest(), 16) % (2**32) + def is_file_exists(path: str | Path, filename: str) -> bool: folder = Path(path) full_path = folder / filename - + return full_path.is_file() + def load_video(path: str | Path, return_type: str = "pt") -> List[Image] | np.ndarray | torch.Tensor: """ Load videos from a path. @@ -242,6 +245,7 @@ def _normalize_prompts( The index to use for partitioning the dataset. seed : int The seed to use for partitioning the dataset. + Returns ------- Iterable[str] @@ -433,14 +437,12 @@ def generate_videos( prompt_iterable = _normalize_prompts(prompts, split, batch_size=1, num_samples=num_samples, fraction=samples_fraction, data_partition_strategy=data_partition_strategy, partition_index=partition_index) - save_dir = Path(save_dir) _ensure_dir(save_dir) # set a run-level seed (VBench suggests this) (important for reproducibility) seed_rng = lambda x: torch.Generator("cpu").manual_seed(x) sampler = _wrap_sampler(model=model, sampling_fn=sampling_fn) - for batch in prompt_iterable: prompt = prepare_batch(batch) From 6a58933c5c21cc82d90fd9756270dcc6d1557b9e Mon Sep 17 00:00:00 2001 From: Marius Graml Date: Wed, 10 Dec 2025 16:00:27 +0100 Subject: [PATCH 29/42] =?UTF-8?q?Add=20filename=20sanitizer=20which=20modi?= =?UTF-8?q?fies=20invalid=20filename=C2=B4aliases?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../artifactsavers/artifactsaver.py | 43 ++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/src/pruna/evaluation/artifactsavers/artifactsaver.py b/src/pruna/evaluation/artifactsavers/artifactsaver.py index 51bbb3d5..baaf1ae9 100644 --- a/src/pruna/evaluation/artifactsavers/artifactsaver.py +++ b/src/pruna/evaluation/artifactsavers/artifactsaver.py @@ -14,6 +14,9 @@ from __future__ import annotations +import re +import unicodedata + from abc import ABC, abstractmethod from pathlib import Path from typing import Any @@ -55,7 +58,7 @@ def save_artifact(self, data: Any) -> Path: """ pass - def create_alias(self, source_path: Path | str, filename: str) -> Path: + def create_alias(self, source_path: Path | str, filename: str, sanitize: bool = True) -> Path: """ Create an alias for the artifact. @@ -82,6 +85,8 @@ def create_alias(self, source_path: Path | str, filename: str) -> Path: Path The full path to the alias. """ + if sanitize: + filename = sanitize_filename(filename) alias = Path(str(self.root)) / f"{filename}.{self.export_format}" alias.parent.mkdir(parents=True, exist_ok=True) try: @@ -96,3 +101,39 @@ def create_alias(self, source_path: Path | str, filename: str) -> Path: except Exception as e: raise e return alias + + +def sanitize_filename(name: str, max_length: int = 128) -> str: + + """ Sanitize a filename to make it safe for the filesystem. Works for every OS. + + Parameters + ---------- + name: str + The name to sanitize. + max_length: int + The maximum length of the sanitized name. If it is exceeded, the name is truncated to max_length. + Returns + ------- + str + The sanitized name. If the name is empty, "untitled" is returned. + + """ + name = str(name) + name = unicodedata.normalize('NFKD', name) + # Forbidden characters + name = re.sub(r'[<>:"/\\|?*]', '_', name) + # Whitespace -> underscore + name = re.sub(r'\s+', '_', name) + # Control chars weg + name = re.sub(r'[\x00-\x1f\x7f]', "", name) + # Collapse multiple underscores into one + name = re.sub(r'_+', '_', name) + # remove leading/trailing dots/spaces/underscores + name = name.strip(" ._") + # limit length + if len(name) > max_length: + name = name[:max_length].rstrip("._ ") + if name == "": + name = "untitled" + return name From e821e473156f293581d891a1878587d700efe8d9 Mon Sep 17 00:00:00 2001 From: Marius Graml Date: Wed, 10 Dec 2025 16:04:34 +0100 Subject: [PATCH 30/42] Change file format via ruff --- src/pruna/evaluation/artifactsavers/artifactsaver.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/pruna/evaluation/artifactsavers/artifactsaver.py b/src/pruna/evaluation/artifactsavers/artifactsaver.py index baaf1ae9..3f116447 100644 --- a/src/pruna/evaluation/artifactsavers/artifactsaver.py +++ b/src/pruna/evaluation/artifactsavers/artifactsaver.py @@ -16,7 +16,6 @@ import re import unicodedata - from abc import ABC, abstractmethod from pathlib import Path from typing import Any @@ -104,15 +103,15 @@ def create_alias(self, source_path: Path | str, filename: str, sanitize: bool = def sanitize_filename(name: str, max_length: int = 128) -> str: + """Sanitize a filename to make it safe for the filesystem. Works for every OS. - """ Sanitize a filename to make it safe for the filesystem. Works for every OS. - Parameters ---------- name: str The name to sanitize. max_length: int The maximum length of the sanitized name. If it is exceeded, the name is truncated to max_length. + Returns ------- str From 88791959212668952be4c9b5e703e4a8f4e553c5 Mon Sep 17 00:00:00 2001 From: Marius Graml Date: Wed, 10 Dec 2025 16:06:02 +0100 Subject: [PATCH 31/42] Add helper function for creating aliases as prompt names for outputs --- src/pruna/evaluation/evaluation_agent.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/pruna/evaluation/evaluation_agent.py b/src/pruna/evaluation/evaluation_agent.py index f705f6ec..e76c99da 100644 --- a/src/pruna/evaluation/evaluation_agent.py +++ b/src/pruna/evaluation/evaluation_agent.py @@ -256,6 +256,9 @@ def update_stateful_metrics( for processed_output in processed_outputs: canonical_path = self.artifact_saver.save_artifact(processed_output) canonical_paths.append(canonical_path) + # Create aliases for the prompts if the user wants to save the artifacts with the prompt name. + # For doing that, the user needs to set the saving_kwargs["save_as_prompt_name"] to True. + self._maybe_create_prompt_aliases(batch, canonical_paths, sample_idx) batch = move_batch_to_device(batch, self.device) processed_outputs = move_batch_to_device(processed_outputs, self.device) @@ -344,3 +347,12 @@ def compute_stateless_metrics(self, model: PrunaModel, stateless_metrics: List[A for metric in children_of_base: results.append(metric.compute(model, self.task.dataloader)) return results + + def _maybe_create_prompt_aliases(self, batch, canonical_paths, sample_idx): + if not self.saving_kwargs.get("save_as_prompt_name", False): + return + + (x, _) = batch + for prompt_idx, prompt in enumerate(x): + alias_name = f"{prompt}-{sample_idx}" + self.artifact_saver.create_alias(canonical_paths[prompt_idx], alias_name) From 40d09bdb65420b8801ccb62e59e23552cedc5f1d Mon Sep 17 00:00:00 2001 From: Parag Ekbote Date: Tue, 9 Dec 2025 15:35:02 +0530 Subject: [PATCH 32/42] Enable TruffleHog in pre-commit (#439) * update pre-commit * rm redudant filters. * fix nits and whitespacing issues. * Update versions --- .pre-commit-config.yaml | 46 ++++++++++++++++++++++++++++------------- 1 file changed, 32 insertions(+), 14 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3a4072ad..8daac431 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v5.0.0 + rev: v6.0.0 hooks: - id: check-added-large-files - id: check-case-conflict @@ -13,21 +13,45 @@ repos: - id: trailing-whitespace exclude: \.md$ - id: no-commit-to-branch + - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.12 + rev: v0.14.6 hooks: - id: ruff-check args: [--fix] + files: \.py$ - id: ruff-format + files: \.py$ + + - repo: https://github.com/trufflesecurity/trufflehog + rev: v3.91.1 + hooks: + - id: trufflehog + name: TruffleHog Secrets Scanner + entry: trufflehog + language: golang + types_or: [python, yaml, json, text] + args: + [ + "filesystem", + "src", + "tests", + ".github/workflows", + "--results=verified,unknown", + "--exclude-paths=.venv", + "--fail" + ] + stages: ["pre-commit", "pre-push"] - repo: local hooks: - id: ty - name: ty check + name: type checking using ty entry: uvx ty check . language: system types: [python] pass_filenames: false + files: \.py$ - repo: local hooks: @@ -39,19 +63,13 @@ repos: grep -v "^D" | cut -f2- | while IFS= read -r file; do - if [ -f "$file" ] && ["$file" != ".pre-commit-config.yaml"] && grep -q "pruna_pro" "$file"; then - echo "Error: pruna_pro found in staged file $file" - exit 1 - fi + if [ -f "$file" ] && [ "$file" != ".pre-commit-config.yaml" ] && grep -q "pruna_pro" "$file"; then + echo "Error: pruna_pro found in staged file $file" + exit 1 + fi done ' language: system stages: [pre-commit] types: [python] - exclude: "^docs/" - - id: trufflehog - name: TruffleHog - description: Detect secrets in your data. - entry: bash -c 'git diff --cached --name-only | xargs -I {} trufflehog filesystem {} --fail --no-update' - language: system - stages: ["pre-commit", "pre-push"] + files: \.py$ From 607edefd4fc2ed5807599a04bb11807837174a44 Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Thu, 9 Oct 2025 12:54:51 +0000 Subject: [PATCH 33/42] feat: 0 vbench dimensions and vbench dependencies --- .pre-commit-config.yaml | 8 +- pyproject.toml | 1 + .../metric_vbench_background_consistency.py | 42 +-- .../metrics/metric_vbench_dynamic_degree.py | 77 ++--- src/pruna/evaluation/metrics/vbench_utils.py | 283 +++++------------- 5 files changed, 127 insertions(+), 284 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8daac431..ba01405d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -63,10 +63,10 @@ repos: grep -v "^D" | cut -f2- | while IFS= read -r file; do - if [ -f "$file" ] && [ "$file" != ".pre-commit-config.yaml" ] && grep -q "pruna_pro" "$file"; then - echo "Error: pruna_pro found in staged file $file" - exit 1 - fi + if [ -f "$file" ] && [ "$file" != ".pre-commit-config.yaml" ] && grep -q "pruna_pro" "$file"; then + echo "Error: pruna_pro found in staged file $file" + exit 1 + fi done ' language: system diff --git a/pyproject.toml b/pyproject.toml index 9070a7a0..0702c75e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -190,6 +190,7 @@ dev = [ ] cpu = [] + [build-system] requires = ["hatchling"] build-backend = "hatchling.build" diff --git a/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py b/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py index 87e6ff2b..acf3df56 100644 --- a/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py +++ b/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py @@ -19,9 +19,9 @@ import clip import torch import torch.nn.functional as F # noqa: N812 -from torchvision.transforms.functional import convert_image_dtype from vbench.utils import clip_transform, init_submodules +from pruna.engine.utils import set_to_best_available_device from pruna.evaluation.metrics.metric_stateful import StatefulMetric from pruna.evaluation.metrics.registry import MetricRegistry from pruna.evaluation.metrics.result import MetricResult @@ -41,6 +41,8 @@ class VBenchBackgroundConsistency(StatefulMetric, VBenchMixin): ---------- *args : Any The arguments to pass to the metric. + device : str | None + The device to run the metric on. call_type : str The call type to use for the metric. **kwargs : Any @@ -50,7 +52,9 @@ class VBenchBackgroundConsistency(StatefulMetric, VBenchMixin): metric_name: str = METRIC_VBENCH_BACKGROUND_CONSISTENCY default_call_type: str = "y" # We just need the outputs higher_is_better: bool = True - runs_on: List[str] = ["cuda", "cpu"] + # https://github.com/Vchitect/VBench/blob/dc62783c0fb4fd333249c0b669027fe102696682/evaluate.py#L111 + # explicitly sets the device to cuda. We respect this here. + runs_on: List[str] = ["cuda"] modality: List[str] = ["video"] # state similarity_scores: torch.Tensor @@ -59,10 +63,15 @@ class VBenchBackgroundConsistency(StatefulMetric, VBenchMixin): def __init__( self, *args: Any, + device: str | None = None, call_type: str = SINGLE, **kwargs: Any, ) -> None: - super().__init__(kwargs.pop("device", None)) + super().__init__(*args, **kwargs) + + if device is not None and str(device).split(":")[0] not in self.runs_on: + pruna_logger.error(f"Unsupported device {device}; supported: {self.runs_on}") + raise ValueError() if call_type == PAIRWISE: # VBench itself does not support pairwise. @@ -73,13 +82,13 @@ def __init__( submodules_dict = init_submodules([METRIC_VBENCH_BACKGROUND_CONSISTENCY]) model_path = submodules_dict[METRIC_VBENCH_BACKGROUND_CONSISTENCY][0] + self.device = set_to_best_available_device(device) self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) self.clip_model, self.preprocessor = clip.load(model_path, device=self.device) - # Cropping for the CLIP encoder. self.video_transform = clip_transform(224) - self.add_state("similarity_scores_cumsum", torch.tensor(0.0)) + self.add_state("similarity_scores", torch.tensor(0.0)) self.add_state("n_samples", torch.tensor(0)) def update(self, x: List[str], gt: Any, outputs: Any) -> None: @@ -98,23 +107,19 @@ def update(self, x: List[str], gt: Any, outputs: Any) -> None: outputs = metric_data_processor(x, gt, outputs, self.call_type, device=self.device) # Background consistency metric only supports a batch size of 1. # To support larger batch sizes, we stack the outputs. - outputs = super().validate_batch(outputs[0]) - # This metric depends on the outputs being uint8. - outputs = torch.stack([convert_image_dtype(output, dtype=torch.uint8) for output in outputs]) - outputs = torch.stack([self.video_transform(output) for output in outputs]) + outputs = torch.stack([self.video_transform(output) for output in outputs[0]]) features = torch.stack([self.clip_model.encode_image(output) for output in outputs]) - features = torch.stack([F.normalize(feature, dim=-1, p=2) for feature in features]) + features = F.normalize(features, dim=-1, p=2) - # We vectorize the calculation to avoid for loops. - first_feature = features[:, 0, ...].unsqueeze(1).repeat(1, features.shape[1] - 1, 1) + first_feature = features[0].unsqueeze(0) - similarity_to_first = F.cosine_similarity(first_feature, features[:, 1:, ...], dim=-1).clamp(min=0.0) - similarity_to_prev = F.cosine_similarity(features[:, :-1, ...], features[:, 1:, ...], dim=-1).clamp(min=0.0) + similarity_to_first = F.cosine_similarity(first_feature, features[1:]).clamp(min=0.0) + similarity_to_prev = F.cosine_similarity(features[:-1], features[1:]).clamp(min=0.0) similarities = (similarity_to_first + similarity_to_prev) / 2 # Update stats - self.similarity_scores_cumsum += similarities.sum().item() + self.similarity_scores += similarities.sum().item() self.n_samples += similarities.numel() def compute(self) -> MetricResult: @@ -126,13 +131,10 @@ def compute(self) -> MetricResult: MetricResult The final score. """ - if self.n_samples == 0: - return MetricResult(self.metric_name, self.__dict__, 0.0) - score = self.similarity_scores_cumsum / self.n_samples + score = self.similarity_scores / self.n_samples return MetricResult(self.metric_name, self.__dict__, score) def reset(self) -> None: """Reset the metric states.""" - super().reset() - self.similarity_scores_cumsum = torch.tensor(0.0) + self.similarity_scores = torch.tensor(0.0) self.n_samples = torch.tensor(0) diff --git a/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py b/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py index 83a505f4..dd3f1218 100644 --- a/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py +++ b/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py @@ -23,6 +23,7 @@ from vbench.third_party.RAFT.core.utils_core.utils import InputPadder from vbench.utils import init_submodules +from pruna.engine.utils import set_to_best_available_device from pruna.evaluation.metrics.metric_stateful import StatefulMetric from pruna.evaluation.metrics.registry import MetricRegistry from pruna.evaluation.metrics.result import MetricResult @@ -34,18 +35,9 @@ class PrunaDynamicDegree(DynamicDegree): - """ - Helper class to compute Dynamic Degree score for a given video. - - Parameters - ---------- - args : EasyDict - The arguments to pass to the RAFT model. - device : str | torch.device - The device to use for the model. - """ + """Helper class to compute Dynamic Degree score for a given video.""" - def infer(self, frames: torch.Tensor, interval: int) -> bool: + def infer(self, frames: torch.Tensor) -> bool: """ Compute Dynamic Degree score for a given video. @@ -53,28 +45,19 @@ def infer(self, frames: torch.Tensor, interval: int) -> bool: Parameters ---------- - frames : torch.Tensor + frames: torch.Tensor The video frames to compute the Dynamic Degree score for. - interval : int - The interval to skip frames. It's possible for each consecutive frame to not have extreme motion, - even though the video itself contains large dynamic changes. - Therefore it's important to set the inteval to skip frames correctly. Returns ------- bool Whether the video contains large motions. """ - frames = [fr.unsqueeze(0) for fr in frames] - - frames = self.extract_frame(frames, interval=max(1, interval)) self.set_params(frame=frames[0], count=len(frames)) - static_score = [] for image1, image2 in zip(frames[:-1], frames[1:]): padder = InputPadder(image1.shape) image1, image2 = padder.pad(image1, image2) - # 20 iterations as the original DynamicDegree implementation. _, flow_up = self.model(image1, image2, iters=20, test_mode=True) max_rad = self.get_score(image1, flow_up) static_score.append(max_rad) @@ -94,19 +77,11 @@ class VBenchDynamicDegree(StatefulMetric, VBenchMixin): Parameters ---------- - *args : Any - The arguments to be passed to the DynamicDegree class. - call_type : str, default="y" + device: str | None, optional + The device to be used, e.g., 'cuda' or 'cpu'. Default is None. + If None, the best available device will be used. + call_type: str, default="y" The call type to be used, e.g., 'y' or 'y_gt'. Default is "y". - interval : int, default=3 - The interval to be used to extract frames from the video. - The default Vbench dimension loads videos from file and preprocesses them to have 8 frames per second. - For instance, if the video is 24fps, Vbench will only get every 3rd frame. - Here, we deal directly with the model outputs, so we initialize the interval to be 3, - which is a reasonable skip interval. - Feel free to change this to your needs. - **kwargs : Any - The keyword arguments to be passed to the DynamicDegree class. """ metric_name: str = METRIC_VBENCH_DYNAMIC_DEGREE @@ -117,16 +92,20 @@ class VBenchDynamicDegree(StatefulMetric, VBenchMixin): runs_on: List[str] = ["cuda"] modality: List[str] = ["video"] # state - scores: List[bool] + scores: List[float] def __init__( self, *args: Any, + device: str | None = None, call_type: str = SINGLE, - interval: int = 3, **kwargs: Any, ) -> None: - super().__init__(device=kwargs.pop("device", None)) + super().__init__(*args, **kwargs) + + if device is not None and str(device).split(":")[0] not in self.runs_on: + pruna_logger.error(f"Unsupported device {device}; supported: {self.runs_on}") + raise ValueError() if call_type == PAIRWISE: # VBench itself does not support pairwise. @@ -134,15 +113,15 @@ def __init__( pruna_logger.error("VBench does not support pairwise metrics. Please use single mode.") raise ValueError() - self.interval = interval submodules_dict = init_submodules([METRIC_VBENCH_DYNAMIC_DEGREE]) model_path = submodules_dict[METRIC_VBENCH_DYNAMIC_DEGREE]["model"] + self.device = set_to_best_available_device(device) self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) # RAFT models expect arguments to be passed as an object with attributes. # So we need to convert the arguments to an EasyDict. args_new = EasyDict({"model": model_path, "small": False, "mixed_precision": False, "alternate_corr": False}) - self.DynamicDegree = PrunaDynamicDegree(args_new, self.device) + self.DynamicDegree = PrunaDynamicDegree(args_new, device) self.add_state("scores", []) @torch.no_grad() @@ -150,6 +129,8 @@ def update(self, x: List[str], gt: Any, outputs: Any) -> None: """ Calculate the dynamic degree score for the given video. + The video is preprocessed to have approx. 8 frames per second. + Then passed to the RAFT model to calculate the dynamic degree score. Each video is ranked as a 1 or 0 based on whether it contains large motions. @@ -157,21 +138,16 @@ def update(self, x: List[str], gt: Any, outputs: Any) -> None: Parameters ---------- - x : List[str] + x: List[str] The list of input videos. - gt : Any + gt: Any The ground truth videos. - outputs : Any - The generated videos. Should be a tensor of shape (T, C, H, W) or (B, T, C, H, W). - where B is the batch size, T is the number of frames, C is the number of channels, H is the height, - and W is the width. + outputs: Any + The generated videos. """ outputs = metric_data_processor(x, gt, outputs, self.call_type, device=self.device) - videos = super().validate_batch(outputs[0]) - - for video in videos: - score = self.DynamicDegree.infer(video, self.interval) - self.scores.append(score) + score = self.DynamicDegree.infer(outputs) + self.scores.append(score) def compute(self) -> MetricResult: """ @@ -184,9 +160,6 @@ def compute(self) -> MetricResult: MetricResult The dynamic degree score. """ - if len(self.scores) == 0: - pruna_logger.warning("No scores have been computed. Returning 0.0.") - return MetricResult(name=self.metric_name, params=self.__dict__, result=0.0) final_score = np.mean(self.scores) return MetricResult(name=self.metric_name, params=self.__dict__, result=final_score) diff --git a/src/pruna/evaluation/metrics/vbench_utils.py b/src/pruna/evaluation/metrics/vbench_utils.py index ca3d7f6d..c64a9bba 100644 --- a/src/pruna/evaluation/metrics/vbench_utils.py +++ b/src/pruna/evaluation/metrics/vbench_utils.py @@ -17,20 +17,16 @@ import re from pathlib import Path from typing import Any, Callable, Iterable, List -import hashlib import numpy as np import torch -from diffusers.utils import export_to_gif, export_to_video -from diffusers.utils import load_video as diffusers_load_video +from diffusers.utils import export_to_gif, export_to_video, load_video from PIL.Image import Image from torchvision.transforms import ToTensor +from vbench import VBench from pruna.data.pruna_datamodule import PrunaDataModule -from pruna.data.utils import define_sample_size_for_dataset, stratify_dataset from pruna.engine.utils import safe_memory_cleanup, set_to_best_available_device -from pruna.evaluation.metrics.metric_stateful import StatefulMetric -from pruna.evaluation.metrics.result import MetricResult from pruna.logging.logger import pruna_logger @@ -39,21 +35,34 @@ class VBenchMixin: Mixin class for VBench metrics. Handles benchmark specific initilizations and artifact saving conventions. + + Parameters + ---------- + *args: Any + The arguments to pass to the metric. + **kwargs: Any + The keyword arguments to pass to the metric. """ + def __init__(self, *args: Any, **kwargs: Any) -> None: + if VBench is None: + # Additional debug info to debug installation issues + pruna_logger.debug("Initialization failed: VBench is None") + raise ImportError("VBench is not installed. Please check your pruna installation.") + def create_filename(self, prompt: str, idx: int, file_extension: str, special_str: str = "") -> str: """ Create filename according to VBench formatting conventions. Parameters ---------- - prompt : str + prompt: str The prompt to create the filename from. - idx : int + idx: int The index of the video. Vbench uses 5 seeds for each prompt. - file_extension : str + file_extension: str The file extension to use. Vbench supports mp4 and gif. - special_str : str + special_str: str A special string to add to the filename if you wish to add a specific identifier. Returns @@ -63,60 +72,24 @@ def create_filename(self, prompt: str, idx: int, file_extension: str, special_st """ return create_vbench_file_name(sanitize_prompt(prompt), idx, special_str, file_extension) - def validate_batch(self, batch: torch.Tensor) -> torch.Tensor: - """ - Make sure that the video tensor has correct dimensions. - Parameters - ---------- - batch : torch.Tensor - The video tensor. - - Returns - ------- - torch.Tensor - The video tensor. - """ - if batch.ndim == 4: - return batch.unsqueeze(0) - elif batch.ndim != 5: - raise ValueError(f"Batch must be 4 or 5 dimensional video tensor with B,T,C,H,W, got {batch.ndim}") - return batch - - -def get_sample_seed(experiment_name: str, prompt: str, index: int) -> int: - """ - Get a sample seed for a given experiment name, prompt, and index. - """ - key = f"{experiment_name}_{prompt}_{index}".encode('utf-8') - - return int(hashlib.sha256(key).hexdigest(), 16) % (2**32) - - -def is_file_exists(path: str | Path, filename: str) -> bool: - folder = Path(path) - full_path = folder / filename - - return full_path.is_file() - - -def load_video(path: str | Path, return_type: str = "pt") -> List[Image] | np.ndarray | torch.Tensor: +def load_videos(path: str | Path, return_type: str = "pt") -> List[Image] | np.ndarray | torch.Tensor: """ Load videos from a path. - Parameters + Parameters: ---------- - path : str | Path + path: str | Path The path to the videos. - return_type : str + return_type: str The type to return the videos as. Can be "pt", "np", "pil". - Returns + Returns: ------- List[torch.Tensor] The videos. """ - video = diffusers_load_video(str(path)) + video = load_video(str(path)) if return_type == "pt": return torch.stack([ToTensor()(frame) for frame in video]) elif return_type == "np": @@ -127,25 +100,6 @@ def load_video(path: str | Path, return_type: str = "pt") -> List[Image] | np.nd raise ValueError(f"Invalid return_type: {return_type}. Use 'pt', 'np', or 'pil'.") -def load_videos_from_path(path: str | Path) -> torch.Tensor: - """ - Load entire directory of mp4 videos as a single tensor ready to be passed to evaluation. - - Parameters - ---------- - path : str | Path - The path to the directory of videos. - - Returns - ------- - torch.Tensor - The videos. - """ - path = Path(str(path)) - videos = torch.stack([load_video(p) for p in path.glob("*.mp4")]) - return videos - - def sanitize_prompt(prompt: str) -> str: """ Return a filesystem-safe version of a prompt. @@ -153,12 +107,12 @@ def sanitize_prompt(prompt: str) -> str: Replaces characters illegal in filenames and collapses whitespace so that generated files are portable across file systems. - Parameters + Parameters: ---------- prompt : str The prompt to sanitize. - Returns + Returns: ------- str The sanitized prompt. @@ -176,19 +130,20 @@ def prepare_batch(batch: str | tuple[str | List[str], Any]) -> str: Pruna datamodules are expected to yield tuples where the first element is a sequence of inputs; this utility enforces batch_size == 1 for simplicity. - Parameters + + Parameters: ---------- - batch : str | tuple[str | List[str], Any] + batch: str | tuple[str | List[str], Any] The batch to prepare. - Returns + Returns: ------- str The prompt string. """ if isinstance(batch, str): return batch - # for pruna datamodule. always returns a tuple where the first element is the input (list of prompts) to the model. + # for pruna datamodule. always returns a tuple where the first element is the input to the model. elif isinstance(batch, tuple): if not hasattr(batch[0], "__len__"): raise ValueError(f"Batch[0] is not a sequence (got {type(batch[0])})") @@ -203,12 +158,12 @@ def _normalize_save_format(save_format: str) -> tuple[str, Callable]: """ Normalize the save format to be used in the generate_videos function. - Parameters + Parameters: ---------- save_format : str The format to save the videos in. VBench supports mp4 and gif. - Returns + Returns: ------- tuple[str, Callable] The normalized save format and the save function. @@ -222,31 +177,17 @@ def _normalize_save_format(save_format: str) -> tuple[str, Callable]: def _normalize_prompts( - prompts: str | List[str] | PrunaDataModule, split: str = "test", batch_size: int = 1, num_samples: int | None = None, fraction: float = 1.0, data_partition_strategy: str = "indexed", partition_index: int = 0, seed: int = 42 + prompts: str | List[str] | PrunaDataModule, split: str = "test", batch_size: int = 1 ) -> Iterable[str]: """ Normalize prompts to an iterable format to be used in the generate_videos function. - Parameters + Parameters: ---------- prompts : str | List[str] | PrunaDataModule The prompts to normalize. - split : str - The dataset split to sample from. - batch_size : int - The batch size to sample from. - num_samples : int | None - The number of samples to sample from. - fraction : float - The fraction of the dataset to sample from. - data_partition_strategy : str - The strategy to use for partitioning the dataset. Can be "indexed" or "random". - partition_index : int - The index to use for partitioning the dataset. - seed : int - The seed to use for partitioning the dataset. - - Returns + + Returns: ------- Iterable[str] The normalized prompts. @@ -254,9 +195,6 @@ def _normalize_prompts( if isinstance(prompts, str): return [prompts] elif isinstance(prompts, PrunaDataModule): - target_dataset = getattr(prompts, f"{split}_dataset") - sample_size = define_sample_size_for_dataset(target_dataset, fraction, num_samples) - setattr(prompts, f"{split}_dataset", stratify_dataset(target_dataset, sample_size, seed, data_partition_strategy, partition_index)) return getattr(prompts, f"{split}_dataloader")(batch_size=batch_size) else: # list of prompts, already iterable return prompts @@ -266,7 +204,7 @@ def _ensure_dir(p: Path) -> None: """ Ensure the directory exists. - Parameters + Parameters: ---------- p : Path The path to ensure the directory exists. @@ -274,66 +212,51 @@ def _ensure_dir(p: Path) -> None: p.mkdir(parents=True, exist_ok=True) -def create_vbench_file_name( - prompt: str, idx: int, special_str: str = "", save_format: str = ".mp4", max_filename_length: int = 255 -) -> str: +def create_vbench_file_name(prompt: str, idx: int, special_str: str = "", postfix: str = ".mp4") -> str: """ Create a file name for the video in accordance with the VBench format. - Parameters + Parameters: ---------- - prompt : str + prompt: str The prompt to create the file name from. - idx : int + idx: int The index of the video. Vbench uses 5 seeds for each prompt. - special_str : str + special_str: str A special string to add to the file name if you wish to add a specific identifier. - save_format : str + postfix: str The format of the video file. Vbench supports mp4 and gif. - max_filename_length : int - The maximum length allowed for the file name. - Returns + Returns: ------- str The file name for the video. """ - filename = f"{prompt}{special_str}-{str(idx)}{save_format}" - if len(filename) > max_filename_length: - pruna_logger.debug( - f"File name {filename} is too long. Maximum length is {max_filename_length} characters. Truncating filename." - ) - filename = filename[:max_filename_length] - return filename + return f"{prompt}{special_str}-{str(idx)}{postfix}" -def sample_video_from_pipelines(pipeline: Any, seeder: Any, prompt: str, **kwargs): +def sample_video_from_pipelines(pipeline: Any, seeder: Any, prompt: str, device: str | torch.device = None, **kwargs): """ Sample a video from diffusers pipeline. - Parameters + Parameters: ---------- - pipeline : Any + pipeline: Any The pipeline to sample from. - seeder : Any - The seeding generator. - prompt : str + prompt: str The prompt to sample from. - **kwargs : Any + seeder: Any + The seeding generator. + **kwargs: Any Additional keyword arguments to pass to the pipeline. - Returns + Returns: ------- torch.Tensor The video tensor. """ - is_return_dict = kwargs.pop("return_dict", True) with torch.inference_mode(): - if is_return_dict: - out = pipeline(prompt=prompt, generator=seeder, **kwargs).frames[0] - else: - # If return_dict is False, the pipeline returns a tuple of (frames, metadata). - out = pipeline(prompt=prompt, generator=seeder, **kwargs)[0] + out = pipeline(prompt=prompt, generator=seeder, **kwargs).frames[0] return out @@ -343,31 +266,27 @@ def _wrap_sampler(model: Any, sampling_fn: Callable[..., Any]) -> Callable[..., Wrap a user-provided sampling function into a uniform callable. The returned callable has a keyword-only signature: - sampler(*, prompt: str, seeder: Any, device: str|torch.device, **kwargs) + sampler(*, prompt: str, seed: int, device: str|torch.device, **kwargs) This wrapper always passes `model` as the first positional argument, so custom functions can name their first parameter `model` or `pipeline`, etc. - Parameters + Parameters: ---------- - model : Any + model: Any The model to sample from. - sampling_fn : Callable[..., Any] + sampling_fn: Callable[..., Any] The sampling function to wrap. - Returns - ------- - Callable[..., Any] - The wrapped sampling function. """ if sampling_fn != sample_video_from_pipelines: pruna_logger.info( - "Using custom sampling function. Ensure it accepts (model, *, prompt, seeder, device, **kwargs)." + "Using custom sampling function. Ensure it accepts (model, *, prompt, seed, device, **kwargs)." ) # The sampling function may expect the model as "pipeline" so we pass it as an arg and not a kwarg. - def sampler(*, prompt: str, seeder: Any, **kwargs: Any) -> Any: - return sampling_fn(model, prompt=prompt, seeder=seeder, **kwargs) + def sampler(*, prompt: str, seeder: Any, device: str | torch.device, **kwargs: Any) -> Any: + return sampling_fn(model, prompt=prompt, seeder=seeder, device=device, **kwargs) return sampler @@ -375,33 +294,26 @@ def sampler(*, prompt: str, seeder: Any, **kwargs: Any) -> Any: def generate_videos( model: Any, prompts: str | List[str] | PrunaDataModule, - num_samples: int | None = None, - samples_fraction: float = 1.0, - data_partition_strategy: str = "indexed", - partition_index: int = 0, split: str = "test", - unique_sample_per_video_count: int = 1, + unique_sample_per_video_count: int = 5, global_seed: int = 42, sampling_fn: Callable[..., Any] = sample_video_from_pipelines, fps: int = 16, save_dir: str | Path = "./saved_videos", save_format: str = "mp4", - filename_fn: Callable = create_vbench_file_name, special_str: str = "", device: str | torch.device = None, - experiment_name: str = "", - sampling_seed_fn: Callable[..., Any] = get_sample_seed, **model_kwargs, ) -> None: """ Generate N samples per prompt and save them to disk with seed tracking. This function: - - Normalizes prompts (string, list, or datamodule). - - Uses an RNG seeded with `global_seed` for reproducibility across runs. - - Saves videos as MP4 or GIF. + 1) Normalizes prompts (string, list, or datamodule). + 2) Uses an RNG seeded with `global_seed` for reproducibility across runs. + 3) Saves videos as MP4 or GIF. - Parameters + Parameters: ---------- model : Any The model to sample from. @@ -422,12 +334,8 @@ def generate_videos( The directory to save the videos to. save_format : str The format to save the videos in. VBench supports mp4 and gif. - filename_fn : Callable - The function to create the file name. special_str : str A special string to add to the file name if you wish to add a specific identifier. - device : str | torch.device | None - The device to sample on. If None, the best available device will be used. **model_kwargs : Any Additional keyword arguments to pass to the sampling function. """ @@ -435,64 +343,23 @@ def generate_videos( device = set_to_best_available_device(device) - prompt_iterable = _normalize_prompts(prompts, split, batch_size=1, num_samples=num_samples, fraction=samples_fraction, data_partition_strategy=data_partition_strategy, partition_index=partition_index) + prompt_iterable = _normalize_prompts(prompts, split, batch_size=1) save_dir = Path(save_dir) _ensure_dir(save_dir) # set a run-level seed (VBench suggests this) (important for reproducibility) - seed_rng = lambda x: torch.Generator("cpu").manual_seed(x) + seed_rng = torch.Generator().manual_seed(global_seed) sampler = _wrap_sampler(model=model, sampling_fn=sampling_fn) for batch in prompt_iterable: prompt = prepare_batch(batch) for idx in range(unique_sample_per_video_count): - file_name = filename_fn(sanitize_prompt(prompt), idx, special_str, file_extension) + file_name = create_vbench_file_name(sanitize_prompt(prompt), idx, special_str, file_extension) out_path = save_dir / file_name - if is_file_exists(save_dir, file_name): - continue - else: - seed = sampling_seed_fn(experiment_name, prompt, idx) - vid = sampler(prompt=prompt, seeder=seed_rng(seed), **model_kwargs) - save_fn(vid, out_path, fps=fps) - - del vid - safe_memory_cleanup() + vid = sampler(prompt=prompt, seeder=seed_rng, device=device, **model_kwargs) + save_fn(vid, out_path, fps=fps) - -def evaluate_videos( - data: Any, metrics: StatefulMetric | List[StatefulMetric], prompts: Any | None = None -) -> List[MetricResult]: - """ - Evaluation loop helper. - - Parameters - ---------- - data : Any - The data to evaluate. - metrics : StatefulMetric | List[StatefulMetric] - The metrics to evaluate. - prompts : Any | None - The prompts to evaluate. - - Returns - ------- - List[MetricResult] - The results of the evaluation. - """ - results = [] - if isinstance(metrics, StatefulMetric): - metrics = [metrics] - if any(metric.call_type != "y" for metric in metrics) and prompts is None: - raise ValueError( - "You are trying to evaluate metrics that require more than the outputs, but didn't provide prompts." - ) - for metric in metrics: - for batch in data: - if prompts is None: - prompts = batch - metric.update(prompts, batch, batch) - prompts = None - results.append(metric.compute()) - return results + del vid + safe_memory_cleanup() From 828750101f3d74896bf1bd4c77dc0102a1c2af6e Mon Sep 17 00:00:00 2001 From: Marius Graml Date: Mon, 22 Dec 2025 09:25:45 +0100 Subject: [PATCH 34/42] Undo limiting prompt length as name for image logging --- src/pruna/evaluation/artifactsavers/artifactsaver.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/pruna/evaluation/artifactsavers/artifactsaver.py b/src/pruna/evaluation/artifactsavers/artifactsaver.py index 3f116447..e1dac50f 100644 --- a/src/pruna/evaluation/artifactsavers/artifactsaver.py +++ b/src/pruna/evaluation/artifactsavers/artifactsaver.py @@ -102,7 +102,7 @@ def create_alias(self, source_path: Path | str, filename: str, sanitize: bool = return alias -def sanitize_filename(name: str, max_length: int = 128) -> str: +def sanitize_filename(name: str) -> str: """Sanitize a filename to make it safe for the filesystem. Works for every OS. Parameters @@ -124,15 +124,12 @@ def sanitize_filename(name: str, max_length: int = 128) -> str: name = re.sub(r'[<>:"/\\|?*]', '_', name) # Whitespace -> underscore name = re.sub(r'\s+', '_', name) - # Control chars weg + # Control chars removed name = re.sub(r'[\x00-\x1f\x7f]', "", name) # Collapse multiple underscores into one name = re.sub(r'_+', '_', name) # remove leading/trailing dots/spaces/underscores name = name.strip(" ._") - # limit length - if len(name) > max_length: - name = name[:max_length].rstrip("._ ") if name == "": name = "untitled" return name From c019dd674ba433ea06ddae297f79d91c5f2c843e Mon Sep 17 00:00:00 2001 From: Marius Graml Date: Mon, 22 Dec 2025 09:29:42 +0100 Subject: [PATCH 35/42] Integrate uncommented method from task.py --- src/pruna/evaluation/task.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/pruna/evaluation/task.py b/src/pruna/evaluation/task.py index 0a1f9860..dd6f10de 100644 --- a/src/pruna/evaluation/task.py +++ b/src/pruna/evaluation/task.py @@ -141,20 +141,6 @@ def _get_stateful_metric_device_from_task_device(self) -> str: else: return self.device # for when we pass a specific cuda device, or cpu or mps. - -def _safe_build_metrics( - request: str | List[str | BaseMetric | StatefulMetric], inference_device: str, stateful_metric_device: str -): - try: - return get_metrics(request, inference_device, stateful_metric_device) - except torch.cuda.OutOfMemoryError as e: - if stateful_metric_device == "cuda": - pruna_logger.error( - "Not enough GPU memory for metrics on %s. Please try initializing task with `low_memory=True`.", - stateful_metric_device, - ) - raise e - def validate_and_get_task_modality(self) -> str: """ Check if the task has a single modality of metrics. @@ -181,6 +167,20 @@ def validate_and_get_task_modality(self) -> str: return "general" +def _safe_build_metrics( + request: str | List[str | BaseMetric | StatefulMetric], inference_device: str, stateful_metric_device: str +): + try: + return get_metrics(request, inference_device, stateful_metric_device) + except torch.cuda.OutOfMemoryError as e: + if stateful_metric_device == "cuda": + pruna_logger.error( + "Not enough GPU memory for metrics on %s. Please try initializing task with `low_memory=True`.", + stateful_metric_device, + ) + raise e + + def get_metrics( request: str | List[str | BaseMetric | StatefulMetric], inference_device: str, stateful_metric_device: str ) -> List[BaseMetric | StatefulMetric]: From 4a968afbf485b1d0359e82e9ac9bad0de1065d8a Mon Sep 17 00:00:00 2001 From: Marius Graml Date: Mon, 22 Dec 2025 09:30:24 +0100 Subject: [PATCH 36/42] Update function in evaluation agent for generating metadata-json file --- src/pruna/evaluation/evaluation_agent.py | 46 ++++++++++++++++++++---- 1 file changed, 39 insertions(+), 7 deletions(-) diff --git a/src/pruna/evaluation/evaluation_agent.py b/src/pruna/evaluation/evaluation_agent.py index e76c99da..104cac62 100644 --- a/src/pruna/evaluation/evaluation_agent.py +++ b/src/pruna/evaluation/evaluation_agent.py @@ -14,6 +14,7 @@ from __future__ import annotations +import json import tempfile from pathlib import Path from typing import Any, List, Literal @@ -258,7 +259,7 @@ def update_stateful_metrics( canonical_paths.append(canonical_path) # Create aliases for the prompts if the user wants to save the artifacts with the prompt name. # For doing that, the user needs to set the saving_kwargs["save_as_prompt_name"] to True. - self._maybe_create_prompt_aliases(batch, canonical_paths, sample_idx) + self._maybe_create_prompt_metadata(batch, canonical_paths, sample_idx) batch = move_batch_to_device(batch, self.device) processed_outputs = move_batch_to_device(processed_outputs, self.device) @@ -348,11 +349,42 @@ def compute_stateless_metrics(self, model: PrunaModel, stateless_metrics: List[A results.append(metric.compute(model, self.task.dataloader)) return results - def _maybe_create_prompt_aliases(self, batch, canonical_paths, sample_idx): - if not self.saving_kwargs.get("save_as_prompt_name", False): + def _maybe_create_prompt_metadata(self, batch, canonical_paths, sample_idx): + """ + Write prompt-level metadata for saved artifacts. + + If ``save_prompt_metadata`` is enabled, this function appends one JSONL + record per prompt to ``metadata.jsonl`` in the run output directory. + The canonical filename is used as the stable identifier. + + Args: + batch: Batch tuple where the first element contains the prompts. + canonical_paths: List of canonical file paths corresponding to each + prompt in the batch. + sample_idx: Index of the current sample within the evaluation run. + + Returns: + ------- + None + """ + if not self.saving_kwargs.get("save_prompt_metadata", False): return - (x, _) = batch - for prompt_idx, prompt in enumerate(x): - alias_name = f"{prompt}-{sample_idx}" - self.artifact_saver.create_alias(canonical_paths[prompt_idx], alias_name) + (x, _) = batch # x = prompts + + metadata_path = Path(self.root_dir) / "metadata.jsonl" + metadata_path.parent.mkdir(parents=True, exist_ok=True) + + with metadata_path.open("a", encoding="utf-8") as f: + for prompt_idx, prompt in enumerate(x): + record = { + # stable ID (file actually on disk) + "file": Path(canonical_paths[prompt_idx]).name, + # full path + "canonical_path": str(canonical_paths[prompt_idx]), + # original prompt + "prompt": str(prompt), + "sample_idx": sample_idx, + "prompt_idx": prompt_idx, + } + f.write(json.dumps(record, ensure_ascii=False) + "\n") From fc3d21c7a33cd54a19384121d55d9b39d4c363b1 Mon Sep 17 00:00:00 2001 From: Marius Graml Date: Wed, 24 Dec 2025 16:39:29 +0000 Subject: [PATCH 37/42] Add evaluation-agent parameter for optional JSON metadata creation; include batch_idx in metadata --- src/pruna/evaluation/evaluation_agent.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/pruna/evaluation/evaluation_agent.py b/src/pruna/evaluation/evaluation_agent.py index 104cac62..5bc5217d 100644 --- a/src/pruna/evaluation/evaluation_agent.py +++ b/src/pruna/evaluation/evaluation_agent.py @@ -17,6 +17,7 @@ import json import tempfile from pathlib import Path +from traceback import print_tb from typing import Any, List, Literal import torch @@ -82,6 +83,7 @@ def __init__( seed_strategy: Literal["per_sample", "no_seed"] = "no_seed", global_seed: int | None = None, artifact_saver_export_format: str | None = None, + save_in_out_metadata: bool = False, saving_kwargs: dict = dict(), ) -> None: if task is not None: @@ -103,6 +105,7 @@ def __init__( self.device = set_to_best_available_device(self.task.device) self.cache: List[Tensor] = [] self.evaluation_for_first_model: bool = True + self.save_in_out_metadata: bool = save_in_out_metadata self.save_artifacts: bool = save_artifacts if save_artifacts: self.root_dir = root_dir if root_dir is not None else OUTPUT_DIR @@ -259,7 +262,7 @@ def update_stateful_metrics( canonical_paths.append(canonical_path) # Create aliases for the prompts if the user wants to save the artifacts with the prompt name. # For doing that, the user needs to set the saving_kwargs["save_as_prompt_name"] to True. - self._maybe_create_prompt_metadata(batch, canonical_paths, sample_idx) + self._maybe_create_input_output_metadata(batch, canonical_paths, sample_idx, batch_idx) batch = move_batch_to_device(batch, self.device) processed_outputs = move_batch_to_device(processed_outputs, self.device) @@ -349,7 +352,7 @@ def compute_stateless_metrics(self, model: PrunaModel, stateless_metrics: List[A results.append(metric.compute(model, self.task.dataloader)) return results - def _maybe_create_prompt_metadata(self, batch, canonical_paths, sample_idx): + def _maybe_create_input_output_metadata(self, batch, canonical_paths, sample_idx, batch_idx): """ Write prompt-level metadata for saved artifacts. @@ -367,7 +370,7 @@ def _maybe_create_prompt_metadata(self, batch, canonical_paths, sample_idx): ------- None """ - if not self.saving_kwargs.get("save_prompt_metadata", False): + if not self.save_in_out_metadata: return (x, _) = batch # x = prompts @@ -385,6 +388,7 @@ def _maybe_create_prompt_metadata(self, batch, canonical_paths, sample_idx): # original prompt "prompt": str(prompt), "sample_idx": sample_idx, + "batch_idx": batch_idx, "prompt_idx": prompt_idx, } f.write(json.dumps(record, ensure_ascii=False) + "\n") From e3b19ed81295f76647e185725f5027d8c93b29f7 Mon Sep 17 00:00:00 2001 From: Marius Graml Date: Wed, 24 Dec 2025 16:49:07 +0000 Subject: [PATCH 38/42] Adjust doc string for _maybe_create_input_output_metadat() in evaluation_agent.py --- src/pruna/evaluation/evaluation_agent.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/pruna/evaluation/evaluation_agent.py b/src/pruna/evaluation/evaluation_agent.py index 5bc5217d..66906b4c 100644 --- a/src/pruna/evaluation/evaluation_agent.py +++ b/src/pruna/evaluation/evaluation_agent.py @@ -365,6 +365,7 @@ def _maybe_create_input_output_metadata(self, batch, canonical_paths, sample_idx canonical_paths: List of canonical file paths corresponding to each prompt in the batch. sample_idx: Index of the current sample within the evaluation run. + batch_idx: Index of the batch within the evaluation dataloader loop. Returns: ------- From f1773131068da1215da97526ebc15b81fbb8203f Mon Sep 17 00:00:00 2001 From: Marius Graml Date: Mon, 29 Dec 2025 16:32:31 +0000 Subject: [PATCH 39/42] Add model role in metadata json when applying pairwise metrics --- src/pruna/evaluation/evaluation_agent.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/pruna/evaluation/evaluation_agent.py b/src/pruna/evaluation/evaluation_agent.py index 66906b4c..93465a70 100644 --- a/src/pruna/evaluation/evaluation_agent.py +++ b/src/pruna/evaluation/evaluation_agent.py @@ -379,9 +379,13 @@ def _maybe_create_input_output_metadata(self, batch, canonical_paths, sample_idx metadata_path = Path(self.root_dir) / "metadata.jsonl" metadata_path.parent.mkdir(parents=True, exist_ok=True) + model_role = "reference" if self.evaluation_for_first_model else "candidate" + with metadata_path.open("a", encoding="utf-8") as f: for prompt_idx, prompt in enumerate(x): record = { + # Model role: reference or candidate + "model_role": model_role, # stable ID (file actually on disk) "file": Path(canonical_paths[prompt_idx]).name, # full path From f936d5f25cac39d3675acafc2e65f0c2e48b1b09 Mon Sep 17 00:00:00 2001 From: Marius Graml Date: Tue, 13 Jan 2026 08:24:36 +0000 Subject: [PATCH 40/42] Change name and if logic of json file creation function in evaluation_agent.py --- src/pruna/evaluation/evaluation_agent.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/pruna/evaluation/evaluation_agent.py b/src/pruna/evaluation/evaluation_agent.py index 93465a70..35ae6dcf 100644 --- a/src/pruna/evaluation/evaluation_agent.py +++ b/src/pruna/evaluation/evaluation_agent.py @@ -262,7 +262,8 @@ def update_stateful_metrics( canonical_paths.append(canonical_path) # Create aliases for the prompts if the user wants to save the artifacts with the prompt name. # For doing that, the user needs to set the saving_kwargs["save_as_prompt_name"] to True. - self._maybe_create_input_output_metadata(batch, canonical_paths, sample_idx, batch_idx) + if self.save_in_out_metadata: + self._create_input_output_metadata(batch, canonical_paths, sample_idx, batch_idx) batch = move_batch_to_device(batch, self.device) processed_outputs = move_batch_to_device(processed_outputs, self.device) @@ -352,7 +353,7 @@ def compute_stateless_metrics(self, model: PrunaModel, stateless_metrics: List[A results.append(metric.compute(model, self.task.dataloader)) return results - def _maybe_create_input_output_metadata(self, batch, canonical_paths, sample_idx, batch_idx): + def _create_input_output_metadata(self, batch, canonical_paths, sample_idx, batch_idx): """ Write prompt-level metadata for saved artifacts. @@ -371,9 +372,6 @@ def _maybe_create_input_output_metadata(self, batch, canonical_paths, sample_idx ------- None """ - if not self.save_in_out_metadata: - return - (x, _) = batch # x = prompts metadata_path = Path(self.root_dir) / "metadata.jsonl" From 5a5d3e7d7b3a144a76ab6bccf2c7801c3ac121cb Mon Sep 17 00:00:00 2001 From: Marius Graml Date: Tue, 13 Jan 2026 08:47:27 +0000 Subject: [PATCH 41/42] Copy the tests from image-artifactsaver to feat/img-saver-extended to have one branch for optimization agent and image artifact saver --- tests/evaluation/test_artifactsaver.py | 102 +++++++++++++++++++++++-- 1 file changed, 96 insertions(+), 6 deletions(-) diff --git a/tests/evaluation/test_artifactsaver.py b/tests/evaluation/test_artifactsaver.py index 41b09de4..7a765237 100644 --- a/tests/evaluation/test_artifactsaver.py +++ b/tests/evaluation/test_artifactsaver.py @@ -6,35 +6,49 @@ import torch from pruna.evaluation.artifactsavers.video_artifactsaver import VideoArtifactSaver +from pruna.evaluation.artifactsavers.image_artifactsaver import ImageArtifactSaver from pruna.evaluation.artifactsavers.utils import assign_artifact_saver from pruna.evaluation.metrics.vbench_utils import load_video from PIL import Image def test_create_alias(): - """ Test that we can create an alias for an existing video.""" + """ Test that we can create an alias for an existing video and image.""" with tempfile.TemporaryDirectory() as tmp_path: + + # --- Test video artifact saver --- # First, we create a random video and save it. saver = VideoArtifactSaver(root=tmp_path, export_format="mp4") dummy_video = np.random.randint(0, 255, (10, 16, 16, 3), dtype=np.uint8) source_filename = saver.save_artifact(dummy_video, saving_kwargs={"fps": 5}) - # Then, we create an alias for the video. alias = saver.create_alias(source_filename, "alias_filename") # Finally, we reload the alias and check that it is the same as the original video. reloaded_alias_video = load_video(str(alias), return_type = "np") - assert(reloaded_alias_video.shape == dummy_video.shape) assert alias.exists() assert alias.name.endswith(".mp4") + # --- Test image artifact saver --- + saver = ImageArtifactSaver(root=tmp_path, export_format="png") + dummy_image = np.random.randint(0, 255, (16, 16, 3), dtype=np.uint8) + source_filename = saver.save_artifact(dummy_image, saving_kwargs={"quality": 95}) + # Then, we create an alias for the image. + alias = saver.create_alias(source_filename, "alias_filename") + # Finally, we reload the alias and check that it is the same as the original image. + reloaded_alias_image = np.array(Image.open(str(alias))) + assert(reloaded_alias_image.shape == dummy_image.shape) + assert alias.exists() + assert alias.name.endswith(".png") -def test_assign_artifact_saver_video(tmp_path: Path): - """ Test the artifact save is assigned correctly.""" +def test_assign_all_artifact_savers(tmp_path: Path): + """ Test each artifact saver is assigned correctly.""" saver = assign_artifact_saver("video", root=tmp_path, export_format="mp4") assert isinstance(saver, VideoArtifactSaver) assert saver.export_format == "mp4" - + saver = assign_artifact_saver("image", root=tmp_path, export_format="png") + assert isinstance(saver, ImageArtifactSaver) + assert saver.export_format == "png" def test_assign_artifact_saver_invalid(): """ Test that we raise an error if the artifact saver is assigned incorrectly.""" @@ -88,3 +102,79 @@ def test_video_artifact_saver_tensor(export_format: str, save_from_type: str, sa path = saver.save_artifact(dummy_video) assert path.exists() assert path.suffix == f".{export_format}" + +@pytest.mark.parametrize( + "export_format, save_from_type, save_from_dtype", + [ + # --- Test png format --- + # numpy + pytest.param("png", "np", "uint8"), + pytest.param("png", "np", "float32"), + # torch + pytest.param("png", "pt", "float32"), + pytest.param("png", "pt", "float16"), + pytest.param("png", "pt", "uint8"), + # PIL + pytest.param("png", "pil", "uint8"), + # --- Test jpg format --- + # numpy + pytest.param("jpg", "np", "uint8"), + pytest.param("jpg", "np", "float32"), + # torch + pytest.param("jpg", "pt", "float32"), + pytest.param("jpg", "pt", "float16"), + pytest.param("jpg", "pt", "uint8"), + # PIL + pytest.param("jpg", "pil", "uint8"), + # --- Test webp format --- + # numpy + pytest.param("webp", "np", "uint8"), + pytest.param("webp", "np", "float32"), + # torch + pytest.param("webp", "pt", "float32"), + pytest.param("webp", "pt", "float16"), + pytest.param("webp", "pt", "uint8"), + # PIL + pytest.param("webp", "pil", "uint8"), + # --- Test jpeg format --- + # numpy + pytest.param("jpeg", "np", "uint8"), + pytest.param("jpeg", "np", "float32"), + # torch + pytest.param("jpeg", "pt", "float32"), + pytest.param("jpeg", "pt", "float16"), + pytest.param("jpeg", "pt", "uint8"), + # PIL + pytest.param("jpeg", "pil", "uint8"), + ] + ) +def test_image_artifact_saver_tensor(export_format: str, save_from_type: str, save_from_dtype: str): + """ Test that we can save an image from a tensor.""" + with tempfile.TemporaryDirectory() as tmp_path: + saver = ImageArtifactSaver(root=tmp_path, export_format=export_format) + # Create fake image: + if save_from_type == "pt": + # Note: torch convention is (C, H, W) + if save_from_dtype == "uint8": + dtype = getattr(torch, save_from_dtype) + dummy_image = torch.randint(0, 256, (3, 16, 16), dtype=dtype) + else: + dtype = getattr(torch, save_from_dtype) + dummy_image = torch.randn(3, 16, 16, dtype=dtype) + elif save_from_type == "np": + # Note: Numpy arrays as images follow the convention (H, W, C) + if save_from_dtype == "uint8": + dtype = getattr(np, save_from_dtype) + dummy_image = np.random.randint(0, 256, (16, 16, 3), dtype=dtype) + else: + rng = np.random.default_rng() + dtype = getattr(np, save_from_dtype) + dummy_image = rng.random((16, 16, 3), dtype=dtype) + elif save_from_type == "pil": + # Note: PIL images by default have shape (H, W, C) and are usually uint8 (standard for ".jpg", etc.) + dtype = getattr(np, save_from_dtype) + dummy_image = np.random.randint(0, 256, (16, 16, 3), dtype=dtype) + dummy_image = Image.fromarray(dummy_image.astype(np.uint8)) + path = saver.save_artifact(dummy_image) + assert path.exists() + assert path.suffix == f".{export_format}" \ No newline at end of file From 3ae00ee3f8d5e8b65280b935fea86b448efc2785 Mon Sep 17 00:00:00 2001 From: Marius Graml Date: Tue, 13 Jan 2026 10:11:04 +0000 Subject: [PATCH 42/42] Remove unused import --- src/pruna/evaluation/artifactsavers/image_artifactsaver.py | 2 +- src/pruna/evaluation/evaluation_agent.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/pruna/evaluation/artifactsavers/image_artifactsaver.py b/src/pruna/evaluation/artifactsavers/image_artifactsaver.py index 4d1c47f8..f5d47f4f 100644 --- a/src/pruna/evaluation/artifactsavers/image_artifactsaver.py +++ b/src/pruna/evaluation/artifactsavers/image_artifactsaver.py @@ -81,4 +81,4 @@ def save_artifact(self, data: Any, saving_kwargs: dict = dict()) -> Path: # Save the image (export format is determined by the file extension) data.save(canonical_path, **saving_kwargs.copy()) - return canonical_path + return canonical_path \ No newline at end of file diff --git a/src/pruna/evaluation/evaluation_agent.py b/src/pruna/evaluation/evaluation_agent.py index 35ae6dcf..395b40fc 100644 --- a/src/pruna/evaluation/evaluation_agent.py +++ b/src/pruna/evaluation/evaluation_agent.py @@ -17,7 +17,6 @@ import json import tempfile from pathlib import Path -from traceback import print_tb from typing import Any, List, Literal import torch