diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3a4072ad..ba01405d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v5.0.0 + rev: v6.0.0 hooks: - id: check-added-large-files - id: check-case-conflict @@ -13,21 +13,45 @@ repos: - id: trailing-whitespace exclude: \.md$ - id: no-commit-to-branch + - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.12 + rev: v0.14.6 hooks: - id: ruff-check args: [--fix] + files: \.py$ - id: ruff-format + files: \.py$ + + - repo: https://github.com/trufflesecurity/trufflehog + rev: v3.91.1 + hooks: + - id: trufflehog + name: TruffleHog Secrets Scanner + entry: trufflehog + language: golang + types_or: [python, yaml, json, text] + args: + [ + "filesystem", + "src", + "tests", + ".github/workflows", + "--results=verified,unknown", + "--exclude-paths=.venv", + "--fail" + ] + stages: ["pre-commit", "pre-push"] - repo: local hooks: - id: ty - name: ty check + name: type checking using ty entry: uvx ty check . language: system types: [python] pass_filenames: false + files: \.py$ - repo: local hooks: @@ -39,7 +63,7 @@ repos: grep -v "^D" | cut -f2- | while IFS= read -r file; do - if [ -f "$file" ] && ["$file" != ".pre-commit-config.yaml"] && grep -q "pruna_pro" "$file"; then + if [ -f "$file" ] && [ "$file" != ".pre-commit-config.yaml" ] && grep -q "pruna_pro" "$file"; then echo "Error: pruna_pro found in staged file $file" exit 1 fi @@ -48,10 +72,4 @@ repos: language: system stages: [pre-commit] types: [python] - exclude: "^docs/" - - id: trufflehog - name: TruffleHog - description: Detect secrets in your data. - entry: bash -c 'git diff --cached --name-only | xargs -I {} trufflehog filesystem {} --fail --no-update' - language: system - stages: ["pre-commit", "pre-push"] + files: \.py$ diff --git a/pyproject.toml b/pyproject.toml index 3aca3631..0702c75e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,10 +31,13 @@ unsupported-operator = "ignore" # mypy supports | syntax with from __future__ im invalid-argument-type = "ignore" # mypy is more permissive with argument types invalid-return-type = "ignore" # mypy is more permissive with return types invalid-parameter-default = "ignore" # mypy is more permissive with parameter defaults +possibly-missing-attribute = "ignore" # mypy is more permissive with attribute access +possibly-unbound-attribute = "ignore" +possibly-missing-import = "ignore" # mypy is more permissive with imports no-matching-overload = "ignore" # mypy is more permissive with overloads unresolved-reference = "ignore" # mypy is more permissive with references -possibly-unbound-import = "ignore" missing-argument = "ignore" +possibly-unbound-import = "ignore" [tool.coverage.run] source = ["src/pruna"] @@ -75,6 +78,7 @@ gptqmodel = [ { index = "pruna_internal", marker = "sys_platform != 'darwin' or platform_machine != 'arm64'"}, { index = "pypi", marker = "sys_platform == 'darwin' and platform_machine == 'arm64'"}, ] +clip = {git = "https://github.com/openai/CLIP.git", rev = "dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1"} [project] name = "pruna" @@ -186,6 +190,7 @@ dev = [ ] cpu = [] + [build-system] requires = ["hatchling"] build-backend = "hatchling.build" diff --git a/src/pruna/data/utils.py b/src/pruna/data/utils.py index fb258afa..45fb8517 100644 --- a/src/pruna/data/utils.py +++ b/src/pruna/data/utils.py @@ -184,7 +184,7 @@ def recover_text_from_dataloader(dataloader: DataLoader, tokenizer: Any) -> list return 
texts -def stratify_dataset(dataset: Dataset, sample_size: int, seed: int = 42) -> Dataset: +def stratify_dataset(dataset: Dataset, sample_size: int, seed: int = 42, partition_strategy: str = "random", partition_index: int = 0) -> Dataset: """ Stratify the dataset into a specific size. @@ -196,6 +196,10 @@ def stratify_dataset(dataset: Dataset, sample_size: int, seed: int = 42) -> Data The size to stratify. seed : int The seed to use for sampling the dataset. + partition_strategy : str + The strategy to use for partitioning the dataset. Can be "indexed" or "random". + partition_index : int + The index to use for partitioning the dataset. Returns ------- @@ -211,8 +215,13 @@ def stratify_dataset(dataset: Dataset, sample_size: int, seed: int = 42) -> Data return dataset indices = list(range(dataset_length)) - random.Random(seed).shuffle(indices) - selected_indices = indices[:sample_size] + if partition_strategy == "indexed": + selected_indices = indices[sample_size*partition_index:sample_size*(partition_index+1)] + elif partition_strategy == "random": + random.Random(seed).shuffle(indices) + selected_indices = indices[:sample_size] + else: + raise ValueError(f"Invalid partition strategy: {partition_strategy}") dataset = dataset.select(selected_indices) return dataset diff --git a/src/pruna/engine/handler/handler_diffuser.py b/src/pruna/engine/handler/handler_diffuser.py index fd1fd239..2d221085 100644 --- a/src/pruna/engine/handler/handler_diffuser.py +++ b/src/pruna/engine/handler/handler_diffuser.py @@ -15,10 +15,9 @@ from __future__ import annotations import inspect -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Literal, Optional, Tuple import torch -from torchvision import transforms from pruna.engine.handler.handler_inference import InferenceHandler from pruna.logging.logger import pruna_logger @@ -28,10 +27,6 @@ class DiffuserHandler(InferenceHandler): """ Handle inference arguments, inputs and outputs for diffusers models. - A generator with a fixed seed (42) is passed as an argument to the model for reproducibility. - The first element of the batch is passed as input to the model. - The generated outputs are expected to have .images attribute. - Parameters ---------- call_signature : inspect.Signature @@ -40,12 +35,18 @@ class DiffuserHandler(InferenceHandler): The arguments to pass to the model. """ - def __init__(self, call_signature: inspect.Signature, model_args: Optional[Dict[str, Any]] = None) -> None: - default_args = {"generator": torch.Generator("cpu").manual_seed(42)} + def __init__( + self, + call_signature: inspect.Signature, + model_args: Optional[Dict[str, Any]] = None, + seed_strategy: Literal["per_sample", "no_seed"] = "no_seed", + global_seed: int | None = None, + ) -> None: self.call_signature = call_signature - if model_args: - default_args.update(model_args) - self.model_args = default_args + self.model_args = model_args if model_args else {} + # We want the default output type to be pytorch tensors. + self.model_args["output_type"] = "pt" + self.configure_seed(seed_strategy, global_seed) def prepare_inputs( self, batch: List[str] | torch.Tensor | Tuple[List[str] | torch.Tensor | dict[str, Any], ...] | dict[str, Any] @@ -83,13 +84,36 @@ def process_output(self, output: Any) -> torch.Tensor: torch.Tensor The processed images. """ - generated = output.images - return torch.stack([transforms.PILToTensor()(g) for g in generated]) + if hasattr(output, "images"): + generated = output.images + # For video models. 
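+        # With `output_type="pt"` set in __init__, video pipelines are expected to return `frames` as a tensor.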
+ elif hasattr(output, "frames"): + generated = output.frames + else: + # Maybe the user is calling the pipeline with return_dict = False, + # which then returns the generated image / video in a tuple + generated = output[0] + return generated.float() def log_model_info(self) -> None: """Log information about the inference handler.""" pruna_logger.info( - "Detected diffusers model. Using DiffuserHandler with fixed seed.\n" - "- The first element of the batch is passed as input.\n" - "- The generated outputs are expected to have .images attribute." + "Detected diffusers model. Using DiffuserHandler.\n- The first element of the batch is passed as input.\n" + "Inference outputs are expected to have either have an `images` attribute or a `frames` attribute." + "Or be a tuple with the generated image / video as the first element." ) + + def set_seed(self, seed: int) -> None: + """ + Set the random seed for the current process. + + Parameters + ---------- + seed : int + The seed to set. + """ + self.model_args["generator"] = torch.Generator("cpu").manual_seed(seed) + + def remove_seed(self) -> None: + """Remove the seed from the current process.""" + self.model_args["generator"] = None diff --git a/src/pruna/engine/handler/handler_inference.py b/src/pruna/engine/handler/handler_inference.py index 900e0e25..2c1ae290 100644 --- a/src/pruna/engine/handler/handler_inference.py +++ b/src/pruna/engine/handler/handler_inference.py @@ -14,9 +14,11 @@ from __future__ import annotations +import random from abc import ABC, abstractmethod -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Literal, Tuple +import numpy as np import torch from pruna.data.utils import move_batch_to_device @@ -98,3 +100,75 @@ def move_inputs_to_device( return move_batch_to_device(inputs, device) except torch.cuda.OutOfMemoryError as e: raise e + + def configure_seed(self, seed_strategy: Literal["per_sample", "no_seed"], global_seed: int | None) -> None: + """ + Set the random seed according to the chosen strategy. + + - If `seed_strategy="per_sample"`,the `global_seed` is used as a base to derive a different seed for each + sample. This ensures reproducibility while still producing variation across samples, + making it the preferred option for benchmarking. + - If `seed_strategy="no_seed"`, no seed is set internally. + The user is responsible for managing seeds if reproducibility is required. + + Parameters + ---------- + seed_strategy : Literal["per_sample", "no_seed"] + The seeding strategy to apply. + global_seed : int | None + The base seed value to use (if applicable). + """ + self.seed_strategy = seed_strategy + validate_seed_strategy(seed_strategy, global_seed) + if global_seed is not None: + self.global_seed = global_seed + self.set_seed(global_seed) + else: + self.remove_seed() + + def set_seed(self, seed: int) -> None: + """ + Set the random seed for the current process. + + Parameters + ---------- + seed : int + The seed to set. + """ + # With the default handler, we can't assume anything about the model, + # so we are setting the seed for all RNGs available. + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + def remove_seed(self) -> None: + """Remove the seed from the current process.""" + random.seed(None) + np.random.seed(None) + # We can't really remove the seed from the PyTorch RNG, so we are reseeding with torch.seed(). + # torch.seed() creates a non-deterministic random number. 
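+        # Note: torch.seed() already reseeds torch's default generator(s) as a side effect of producing the value;
+        # the explicit calls below keep the reseeding intent readable.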
+ torch.manual_seed(torch.seed()) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(torch.seed()) + + +def validate_seed_strategy(seed_strategy: Literal["per_sample", "no_seed"], global_seed: int | None) -> None: + """ + Check the consistency of the seed strategy and the global seed. + + If the seed strategy is "no_seed", the global seed must be None. + If the seed strategy is or "per_sample", the user must provide a global seed. + + Parameters + ---------- + seed_strategy : Literal["per_sample", "no_seed"] + The seeding strategy to apply. + global_seed : int | None + The base seed value to use (if applicable). + """ + if seed_strategy != "no_seed" and global_seed is None: + raise ValueError("Global seed must be provided if seed strategy is not 'no_seed'.") + elif global_seed is not None and seed_strategy == "no_seed": + raise ValueError("Seed strategy cannot be 'no_seed' if global seed is provided.") diff --git a/src/pruna/engine/pruna_model.py b/src/pruna/engine/pruna_model.py index fb4d6f91..a3bdec14 100644 --- a/src/pruna/engine/pruna_model.py +++ b/src/pruna/engine/pruna_model.py @@ -24,7 +24,7 @@ from pruna.config.smash_config import SmashConfig from pruna.engine.handler.handler_utils import register_inference_handler -from pruna.engine.load import load_pruna_model, load_pruna_model_from_pretrained +from pruna.engine.load import filter_load_kwargs, load_pruna_model, load_pruna_model_from_pretrained from pruna.engine.save import save_pruna_model, save_pruna_model_to_hub from pruna.engine.utils import get_device, get_nn_modules, set_to_eval from pruna.logging.filter import apply_warning_filter @@ -108,6 +108,8 @@ def run_inference(self, batch: Any) -> Any: ) inference_function = getattr(self, inference_function_name) + self.inference_handler.model_args = filter_load_kwargs(self.model.__call__, self.inference_handler.model_args) + if prepared_inputs is None: outputs = inference_function(**self.inference_handler.model_args) elif isinstance(prepared_inputs, dict): diff --git a/src/pruna/evaluation/artifactsavers/artifactsaver.py b/src/pruna/evaluation/artifactsavers/artifactsaver.py new file mode 100644 index 00000000..e1dac50f --- /dev/null +++ b/src/pruna/evaluation/artifactsavers/artifactsaver.py @@ -0,0 +1,135 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import re +import unicodedata +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any + + +class ArtifactSaver(ABC): + """ + Abstract class for artifact savers. + + The artifact saver is responsible for saving the inference outputs during evaluation. + + There needs to be a subclass for each metric modality (e.g. video, image, text, etc.). + + Parameters + ---------- + export_format: str | None + The format to export the artifacts in. + root: Path | str | None + The root directory to save the artifacts in. 
+ """ + + export_format: str | None = None + root: Path | str | None = None + + @abstractmethod + def save_artifact(self, data: Any) -> Path: + """ + Implement this method to save the artifact. + + Parameters + ---------- + data: Any + The data to save. + + Returns + ------- + Path + The full path to the saved artifact. + """ + pass + + def create_alias(self, source_path: Path | str, filename: str, sanitize: bool = True) -> Path: + """ + Create an alias for the artifact. + + The evaluation agent will save the inference outputs with a canonical file + formatting style that makes sense for the general case. + + If your metric requires a different file naming convention for evaluation, + you can use this method to create an alias for the artifact. + + This way we prevent duplicate artifacts from being saved and save storage space. + + By default, the alias will be created as a hardlink to the source artifact. + If the hardlink fails, a symlink will be created. + + Parameters + ---------- + source_path : Path | str + The path to the source artifact. + filename : str + The filename to create the alias for. + + Returns + ------- + Path + The full path to the alias. + """ + if sanitize: + filename = sanitize_filename(filename) + alias = Path(str(self.root)) / f"{filename}.{self.export_format}" + alias.parent.mkdir(parents=True, exist_ok=True) + try: + if alias.exists(): + alias.unlink() + alias.hardlink_to(source_path) + except Exception: + try: + if alias.exists(): + alias.unlink() + alias.symlink_to(source_path) + except Exception as e: + raise e + return alias + + +def sanitize_filename(name: str) -> str: + """Sanitize a filename to make it safe for the filesystem. Works for every OS. + + Parameters + ---------- + name: str + The name to sanitize. + max_length: int + The maximum length of the sanitized name. If it is exceeded, the name is truncated to max_length. + + Returns + ------- + str + The sanitized name. If the name is empty, "untitled" is returned. + + """ + name = str(name) + name = unicodedata.normalize('NFKD', name) + # Forbidden characters + name = re.sub(r'[<>:"/\\|?*]', '_', name) + # Whitespace -> underscore + name = re.sub(r'\s+', '_', name) + # Control chars removed + name = re.sub(r'[\x00-\x1f\x7f]', "", name) + # Collapse multiple underscores into one + name = re.sub(r'_+', '_', name) + # remove leading/trailing dots/spaces/underscores + name = name.strip(" ._") + if name == "": + name = "untitled" + return name diff --git a/src/pruna/evaluation/artifactsavers/image_artifactsaver.py b/src/pruna/evaluation/artifactsavers/image_artifactsaver.py new file mode 100644 index 00000000..f5d47f4f --- /dev/null +++ b/src/pruna/evaluation/artifactsavers/image_artifactsaver.py @@ -0,0 +1,84 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import secrets +import time +from pathlib import Path +from typing import Any + +import numpy as np +import torch +from PIL import Image + +from pruna.evaluation.artifactsavers.artifactsaver import ArtifactSaver + + +class ImageArtifactSaver(ArtifactSaver): + """ + Save image artifacts. + + Parameters + ---------- + root: Path | str | None = None + The root directory to save the artifacts. + export_format: str | None = "png" + The format to save the artifacts (e.g. "png", "jpg", "jpeg", "webp"). + """ + + export_format: str | None + root: Path | str | None + + def __init__(self, root: Path | str | None = None, export_format: str | None = "png") -> None: + self.root = Path(root) if root else Path.cwd() + (self.root / "canonical").mkdir(parents=True, exist_ok=True) + self.export_format = export_format if export_format else "png" + if self.export_format not in ["png", "jpg", "jpeg", "webp"]: + raise ValueError(f"Invalid format: {self.export_format}. Valid formats are: png, jpg, jpeg, webp.") + + def save_artifact(self, data: Any, saving_kwargs: dict = dict()) -> Path: + """ + Save the image artifact. + + Parameters + ---------- + data: Any + The data to save. + saving_kwargs: dict + The additional kwargs to pass to the saving utility function. + + Returns + ------- + Path + The path to the saved artifact. + """ + canonical_filename = f"{int(time.time())}_{secrets.token_hex(4)}.{self.export_format}" + canonical_path = Path(str(self.root)) / "canonical" / canonical_filename + + # We save the image as a PIL Image, so we need to convert the data to a PIL Image. + # Usually, the data is already a PIL.Image, so we don't need to convert it. + if isinstance(data, torch.Tensor): + data = np.transpose(data.cpu().numpy(), (1, 2, 0)) + data = np.clip(data * 255, 0, 255).astype(np.uint8) + if isinstance(data, np.ndarray): + data = Image.fromarray(data.astype(np.uint8)) + # Now data must be a PIL Image + if not isinstance(data, Image.Image): + raise ValueError("Model outputs must be torch.Tensor, numpy.ndarray, or PIL.Image.") + + # Save the image (export format is determined by the file extension) + data.save(canonical_path, **saving_kwargs.copy()) + + return canonical_path \ No newline at end of file diff --git a/src/pruna/evaluation/artifactsavers/utils.py b/src/pruna/evaluation/artifactsavers/utils.py new file mode 100644 index 00000000..ef903b93 --- /dev/null +++ b/src/pruna/evaluation/artifactsavers/utils.py @@ -0,0 +1,49 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
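+# Factory helper that selects the artifact saver matching the task modality (image or video).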
+ +from __future__ import annotations + +from pathlib import Path + +from pruna.evaluation.artifactsavers.artifactsaver import ArtifactSaver +from pruna.evaluation.artifactsavers.image_artifactsaver import ImageArtifactSaver +from pruna.evaluation.artifactsavers.video_artifactsaver import VideoArtifactSaver + + +def assign_artifact_saver( + modality: str, root: Path | str | None = None, export_format: str | None = None +) -> ArtifactSaver: + """ + Assign the appropriate artifact saver based on the modality. + + Parameters + ---------- + modality: str + The modality of the data. + root: str + The root directory to save the artifacts. + export_format: str + The format to save the artifacts. + + Returns + ------- + ArtifactSaver + The appropriate artifact saver. + """ + if modality == "video": + return VideoArtifactSaver(root=root, export_format=export_format) + if modality == "image": + return ImageArtifactSaver(root=root, export_format=export_format) + else: + raise ValueError(f"Modality {modality} is not supported") diff --git a/src/pruna/evaluation/artifactsavers/video_artifactsaver.py b/src/pruna/evaluation/artifactsavers/video_artifactsaver.py new file mode 100644 index 00000000..4acae51c --- /dev/null +++ b/src/pruna/evaluation/artifactsavers/video_artifactsaver.py @@ -0,0 +1,82 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import secrets +import time +from pathlib import Path +from typing import Any + +import numpy as np +import torch +from diffusers.utils import export_to_gif, export_to_video +from PIL import Image + +from pruna.evaluation.artifactsavers.artifactsaver import ArtifactSaver + + +class VideoArtifactSaver(ArtifactSaver): + """ + Save video artifacts. + + Parameters + ---------- + root: Path | str | None = None + The root directory to save the artifacts. + export_format: str | None = "mp4" + The format to save the artifacts. + """ + + export_format: str | None + root: Path | str | None + + def __init__(self, root: Path | str | None = None, export_format: str | None = "mp4") -> None: + self.root = Path(root) if root else Path.cwd() + (self.root / "canonical").mkdir(parents=True, exist_ok=True) + self.export_format = export_format if export_format else "mp4" + + def save_artifact(self, data: Any, saving_kwargs: dict = dict()) -> Path: + """ + Save the video artifact. + + Parameters + ---------- + data: Any + The data to save. + saving_kwargs: dict + The additional kwargs to pass to the saving utility function. + + Returns + ------- + Path + The path to the saved artifact. + """ + canonical_filename = f"{int(time.time())}_{secrets.token_hex(4)}.{self.export_format}" + canonical_path = Path(str(self.root)) / "canonical" / canonical_filename + + # all diffusers saving utility functions accept a list of PIL.Images, so we convert to PIL to be safe. 
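+        # The tensor branch assumes frames shaped (num_frames, channels, height, width) with values in [0, 1]
+        # before scaling to uint8.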
+ if isinstance(data, torch.Tensor): + data = np.transpose(data.cpu().numpy(), (0, 2, 3, 1)) + data = np.clip(data * 255, 0, 255).astype(np.uint8) + if isinstance(data, np.ndarray): + data = [Image.fromarray(frame.astype(np.uint8)) for frame in data] + + if self.export_format == "mp4": + export_to_video(data, str(canonical_path), **saving_kwargs.copy()) + elif self.export_format == "gif": + export_to_gif(data, str(canonical_path), **saving_kwargs.copy()) + else: + raise ValueError(f"Invalid format: {self.export_format}") + return canonical_path diff --git a/src/pruna/evaluation/evaluation_agent.py b/src/pruna/evaluation/evaluation_agent.py index 20cc2fd5..395b40fc 100644 --- a/src/pruna/evaluation/evaluation_agent.py +++ b/src/pruna/evaluation/evaluation_agent.py @@ -14,7 +14,10 @@ from __future__ import annotations -from typing import Any, List +import json +import tempfile +from pathlib import Path +from typing import Any, List, Literal import torch from torch import Tensor @@ -26,6 +29,7 @@ from pruna.data.utils import move_batch_to_device from pruna.engine.pruna_model import PrunaModel from pruna.engine.utils import get_device, move_to_device, safe_memory_cleanup, set_to_best_available_device +from pruna.evaluation.artifactsavers.utils import assign_artifact_saver from pruna.evaluation.metrics.metric_base import BaseMetric from pruna.evaluation.metrics.metric_stateful import StatefulMetric from pruna.evaluation.metrics.result import MetricResult @@ -33,6 +37,8 @@ from pruna.evaluation.task import Task from pruna.logging.logger import pruna_logger +OUTPUT_DIR = tempfile.mkdtemp(prefix="inference_outputs") + class EvaluationAgent: """ @@ -49,6 +55,18 @@ class EvaluationAgent: device : str | torch.device | None, optional The device to be used, e.g., 'cuda' or 'cpu'. Default is None. If None, the best available device will be used. + save_artifacts : bool, optional + Whether to save the artifacts. Default is False. + root_dir : str | Path | None, optional + The directory to save the artifacts. Default is None. + num_samples_per_input : int, optional + The number of samples to generate per input. Default is 1. + seed_strategy : Literal["per_sample", "no_seed"], optional + The seed strategy to use. Default is "no_seed". + global_seed : int | None, optional + The global seed to use. Default is None. + saving_kwargs : dict, optional + The kwargs to pass to the artifact saver. Default is an empty dict. 
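+    artifact_saver_export_format : str | None, optional
+        The file format used by the artifact saver (e.g. "png" for images, "mp4" for videos).
+        Default is None, in which case the saver falls back to its own default format.
+    save_in_out_metadata : bool, optional
+        Whether to append one record per prompt to a metadata.jsonl file in the output directory,
+        linking each saved artifact to its prompt. Default is False.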
""" def __init__( @@ -58,6 +76,14 @@ def __init__( request: str | List[str | BaseMetric | StatefulMetric] | None = None, datamodule: PrunaDataModule | None = None, device: str | torch.device | None = None, + save_artifacts: bool = False, + root_dir: str | Path | None = None, + num_samples_per_input: int = 1, + seed_strategy: Literal["per_sample", "no_seed"] = "no_seed", + global_seed: int | None = None, + artifact_saver_export_format: str | None = None, + save_in_out_metadata: bool = False, + saving_kwargs: dict = dict(), ) -> None: if task is not None: if request is not None or datamodule is not None or device is not None: @@ -70,12 +96,21 @@ def __init__( if request is None or datamodule is None: raise ValueError("When not using 'task' parameter, both 'request' and 'datamodule' must be provided.") self.task = Task(request=request, datamodule=datamodule, device=device) - self.first_model_results: List[MetricResult] = [] self.subsequent_model_results: List[MetricResult] = [] + self.seed_strategy = seed_strategy + self.num_samples_per_input = num_samples_per_input + self.global_seed = global_seed self.device = set_to_best_available_device(self.task.device) self.cache: List[Tensor] = [] self.evaluation_for_first_model: bool = True + self.save_in_out_metadata: bool = save_in_out_metadata + self.save_artifacts: bool = save_artifacts + if save_artifacts: + self.root_dir = root_dir if root_dir is not None else OUTPUT_DIR + self.artifact_saver = assign_artifact_saver(self.task.modality, self.root_dir, artifact_saver_export_format) + # for miscellaneous saving kwargs like fps, etc. + self.saving_kwargs = saving_kwargs def evaluate(self, model: Any) -> List[MetricResult]: """ @@ -153,7 +188,7 @@ def prepare_model(self, model: Any) -> PrunaModel: ) else: - smash_config = SmashConfig(device="cpu") + smash_config = SmashConfig(device=get_device(model)) model = PrunaModel(model, smash_config=smash_config) pruna_logger.info("Evaluating a base model.") is_base = True @@ -183,6 +218,10 @@ def prepare_model(self, model: Any) -> PrunaModel: self.device = self.task.device # Keeping the device map to move model back to the original device, when the agent is finished. self.device_map = get_device_map(model) + model.set_to_eval() + # Setup seeding for inference. + model.inference_handler.configure_seed(self.seed_strategy, self.global_seed) + return model def update_stateful_metrics( @@ -212,22 +251,48 @@ def update_stateful_metrics( move_to_device(model, self.device, device_map=self.device_map) for batch_idx, batch in enumerate(tqdm(self.task.dataloader, desc="Processing batches", unit="batch")): - processed_outputs = model.run_inference(batch) - - batch = move_batch_to_device(batch, self.device) - processed_outputs = move_batch_to_device(processed_outputs, self.device) - (x, gt) = batch - # Non-pairwise (aka single) metrics have regular update. - for stateful_metric in single_stateful_metrics: - stateful_metric.update(x, gt, processed_outputs) - - # Cache outputs once in the agent for pairwise metrics to save compute time and memory. - if self.task.is_pairwise_evaluation(): - if self.evaluation_for_first_model: - self.cache.append(processed_outputs) - else: - for pairwise_metric in pairwise_metrics: - pairwise_metric.update(x, self.cache[batch_idx], processed_outputs) + for sample_idx in range(self.num_samples_per_input): + processed_outputs = model.run_inference(batch) + if self.save_artifacts: + canonical_paths = [] + # We have to save the artifacts for each sample in the batch. 
+ for processed_output in processed_outputs: + canonical_path = self.artifact_saver.save_artifact(processed_output) + canonical_paths.append(canonical_path) + # Create aliases for the prompts if the user wants to save the artifacts with the prompt name. + # For doing that, the user needs to set the saving_kwargs["save_as_prompt_name"] to True. + if self.save_in_out_metadata: + self._create_input_output_metadata(batch, canonical_paths, sample_idx, batch_idx) + + batch = move_batch_to_device(batch, self.device) + processed_outputs = move_batch_to_device(processed_outputs, self.device) + (x, gt) = batch + # Non-pairwise (aka single) metrics have regular update. + for stateful_metric in single_stateful_metrics: + stateful_metric.update(x, gt, processed_outputs) + if self.save_artifacts and stateful_metric.create_alias: + # The evaluation agent saves the artifacts with a canonical filenaming convention. + # If the user wants to save the artifact with a different filename, + # here we give them the option to create an alias for the file. + for prompt_idx, prompt in enumerate(x): + if self.artifact_saver.export_format is None: + raise ValueError( + "Export format is not set. Please set the export format for the artifact saver." + ) + alias_filename = stateful_metric.create_filename( + filename=prompt, idx=sample_idx, file_extension=self.artifact_saver.export_format + ) + self.artifact_saver.create_alias(canonical_paths[prompt_idx], alias_filename) + + # Cache outputs once in the agent for pairwise metrics to save compute time and memory. + if self.task.is_pairwise_evaluation(): + if self.num_samples_per_input > 1: + raise ValueError("Pairwise evaluation with multiple samples per input is not supported.") + if self.evaluation_for_first_model: + self.cache.append(processed_outputs) + else: + for pairwise_metric in pairwise_metrics: + pairwise_metric.update(x, self.cache[batch_idx], processed_outputs) def compute_stateful_metrics( self, single_stateful_metrics: List[StatefulMetric], pairwise_metrics: List[StatefulMetric] @@ -286,3 +351,46 @@ def compute_stateless_metrics(self, model: PrunaModel, stateless_metrics: List[A for metric in children_of_base: results.append(metric.compute(model, self.task.dataloader)) return results + + def _create_input_output_metadata(self, batch, canonical_paths, sample_idx, batch_idx): + """ + Write prompt-level metadata for saved artifacts. + + If ``save_prompt_metadata`` is enabled, this function appends one JSONL + record per prompt to ``metadata.jsonl`` in the run output directory. + The canonical filename is used as the stable identifier. + + Args: + batch: Batch tuple where the first element contains the prompts. + canonical_paths: List of canonical file paths corresponding to each + prompt in the batch. + sample_idx: Index of the current sample within the evaluation run. + batch_idx: Index of the batch within the evaluation dataloader loop. 
+ + Returns: + ------- + None + """ + (x, _) = batch # x = prompts + + metadata_path = Path(self.root_dir) / "metadata.jsonl" + metadata_path.parent.mkdir(parents=True, exist_ok=True) + + model_role = "reference" if self.evaluation_for_first_model else "candidate" + + with metadata_path.open("a", encoding="utf-8") as f: + for prompt_idx, prompt in enumerate(x): + record = { + # Model role: reference or candidate + "model_role": model_role, + # stable ID (file actually on disk) + "file": Path(canonical_paths[prompt_idx]).name, + # full path + "canonical_path": str(canonical_paths[prompt_idx]), + # original prompt + "prompt": str(prompt), + "sample_idx": sample_idx, + "batch_idx": batch_idx, + "prompt_idx": prompt_idx, + } + f.write(json.dumps(record, ensure_ascii=False) + "\n") diff --git a/src/pruna/evaluation/metrics/__init__.py b/src/pruna/evaluation/metrics/__init__.py index 77ccef6a..f39de5c4 100644 --- a/src/pruna/evaluation/metrics/__init__.py +++ b/src/pruna/evaluation/metrics/__init__.py @@ -24,6 +24,8 @@ from pruna.evaluation.metrics.metric_pairwise_clip import PairwiseClipScore from pruna.evaluation.metrics.metric_sharpness import SharpnessMetric from pruna.evaluation.metrics.metric_torch import TorchMetricWrapper +from pruna.evaluation.metrics.metric_vbench_background_consistency import VBenchBackgroundConsistency +from pruna.evaluation.metrics.metric_vbench_dynamic_degree import VBenchDynamicDegree __all__ = [ "MetricRegistry", @@ -43,4 +45,6 @@ "DinoScore", "SharpnessMetric", "AestheticLAION", + "VBenchBackgroundConsistency", + "VBenchDynamicDegree", ] diff --git a/src/pruna/evaluation/metrics/metric_cmmd.py b/src/pruna/evaluation/metrics/metric_cmmd.py index 839cbe11..e93442ad 100644 --- a/src/pruna/evaluation/metrics/metric_cmmd.py +++ b/src/pruna/evaluation/metrics/metric_cmmd.py @@ -25,7 +25,7 @@ from pruna.evaluation.metrics.metric_stateful import StatefulMetric from pruna.evaluation.metrics.registry import MetricRegistry from pruna.evaluation.metrics.result import MetricResult -from pruna.evaluation.metrics.utils import SINGLE, get_call_type_for_single_metric, metric_data_processor +from pruna.evaluation.metrics.utils import IMAGE, SINGLE, get_call_type_for_single_metric, metric_data_processor from pruna.logging.logger import pruna_logger METRIC_CMMD = "cmmd" @@ -58,6 +58,7 @@ class CMMD(StatefulMetric): default_call_type: str = "gt_y" higher_is_better: bool = False metric_name: str = METRIC_CMMD + modality = {IMAGE} def __init__( self, @@ -99,6 +100,8 @@ def update(self, x: List[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.T The output images. 
""" inputs = metric_data_processor(x, gt, outputs, self.call_type, self.device) + if inputs[1].dtype == torch.bfloat16: + inputs[1] = inputs[1].to(torch.float16) gt_embeddings = self._get_embeddings(inputs[0]) output_embeddings = self._get_embeddings(inputs[1]) diff --git a/src/pruna/evaluation/metrics/metric_pairwise_clip.py b/src/pruna/evaluation/metrics/metric_pairwise_clip.py index 62e436e5..e541fa33 100644 --- a/src/pruna/evaluation/metrics/metric_pairwise_clip.py +++ b/src/pruna/evaluation/metrics/metric_pairwise_clip.py @@ -26,7 +26,7 @@ from pruna.evaluation.metrics.metric_stateful import StatefulMetric from pruna.evaluation.metrics.registry import MetricRegistry from pruna.evaluation.metrics.result import MetricResult -from pruna.evaluation.metrics.utils import PAIRWISE, metric_data_processor +from pruna.evaluation.metrics.utils import IMAGE, PAIRWISE, metric_data_processor from pruna.logging.logger import pruna_logger @@ -47,6 +47,7 @@ class PairwiseClipScore(CLIPScore, StatefulMetric): # type: ignore[misc] higher_is_better: bool = True metric_name: str = "pairwise_clip_score" + modality = {IMAGE} def __init__(self, **kwargs: Any) -> None: device = kwargs.pop("device", None) diff --git a/src/pruna/evaluation/metrics/metric_sharpness.py b/src/pruna/evaluation/metrics/metric_sharpness.py index b09067de..c0abaeac 100644 --- a/src/pruna/evaluation/metrics/metric_sharpness.py +++ b/src/pruna/evaluation/metrics/metric_sharpness.py @@ -24,7 +24,7 @@ from pruna.evaluation.metrics.metric_stateful import StatefulMetric from pruna.evaluation.metrics.registry import MetricRegistry from pruna.evaluation.metrics.result import MetricResult -from pruna.evaluation.metrics.utils import SINGLE, get_call_type_for_single_metric, metric_data_processor +from pruna.evaluation.metrics.utils import IMAGE, SINGLE, get_call_type_for_single_metric, metric_data_processor from pruna.logging.logger import pruna_logger METRIC_SHARPNESS = "sharpness" @@ -64,6 +64,7 @@ class SharpnessMetric(StatefulMetric): higher_is_better: bool = True metric_name: str = METRIC_SHARPNESS runs_on: List[str] = ["cpu", "cuda"] + modality = {IMAGE} def __init__(self, *args, kernel_size: int = 3, call_type: str = SINGLE, **kwargs) -> None: super().__init__(device=kwargs.pop("device", None)) diff --git a/src/pruna/evaluation/metrics/metric_stateful.py b/src/pruna/evaluation/metrics/metric_stateful.py index 39fddcf6..8407e96c 100644 --- a/src/pruna/evaluation/metrics/metric_stateful.py +++ b/src/pruna/evaluation/metrics/metric_stateful.py @@ -45,6 +45,8 @@ class StatefulMetric(ABC): metric_name: str call_type: str runs_on: list[str] = ["cuda", "cpu", "mps"] + create_alias: bool = False + modality: set[str] def __init__(self, device: str | torch.device | None = None, **kwargs) -> None: """Initialize the StatefulMetric class.""" @@ -167,3 +169,21 @@ def is_device_supported(self, device: str | torch.device) -> bool: """ dvc, _ = split_device(device_to_string(device)) return dvc in self.runs_on + + def create_filename(self, filename: str, idx: int, file_extension: str) -> str: + """ + Create a filename for the metric. + + Parameters + ---------- + filename: str + The name of the file. + file_extension: str + The extension of the file. + + Returns + ------- + str + The filename. 
+ """ + return f"{filename}-{idx}.{file_extension}" diff --git a/src/pruna/evaluation/metrics/metric_torch.py b/src/pruna/evaluation/metrics/metric_torch.py index e1dfb0e0..e79d8a71 100644 --- a/src/pruna/evaluation/metrics/metric_torch.py +++ b/src/pruna/evaluation/metrics/metric_torch.py @@ -42,8 +42,11 @@ from pruna.evaluation.metrics.result import MetricResult from pruna.evaluation.metrics.utils import ( CALL_TYPES, + IMAGE, + MODALITIES, PAIRWISE, SINGLE, + TEXT, get_pairwise_pairing, get_single_pairing, metric_data_processor, @@ -124,9 +127,7 @@ def arniqa_update(metric: ARNIQA, preds: Any) -> None: def ssim_update( - metric: StructuralSimilarityIndexMeasure | MultiScaleStructuralSimilarityIndexMeasure, - preds: Any, - target: Any + metric: StructuralSimilarityIndexMeasure | MultiScaleStructuralSimilarityIndexMeasure, preds: Any, target: Any ) -> None: """ Update handler for SSIM or MS-SSIM metric. @@ -170,23 +171,24 @@ class TorchMetrics(Enum): The starting value for the enum. """ - fid = (partial(FrechetInceptionDistance), fid_update, "gt_y") - accuracy = (partial(Accuracy), None, "y_gt") - perplexity = (partial(Perplexity), None, "y_gt") - clip_score = (partial(CLIPScore), None, "y_x") - precision = (partial(Precision), None, "y_gt") - recall = (partial(Recall), None, "y_gt") - psnr = (partial(PeakSignalNoiseRatio), None, "pairwise_y_gt") - ssim = (partial(StructuralSimilarityIndexMeasure), ssim_update, "pairwise_y_gt") - msssim = (partial(MultiScaleStructuralSimilarityIndexMeasure), ssim_update, "pairwise_y_gt") - lpips = (partial(LearnedPerceptualImagePatchSimilarity), lpips_update, "pairwise_y_gt") - arniqa = (partial(ARNIQA), arniqa_update, "y") - clipiqa = (partial(CLIPImageQualityAssessment), None, "y") + fid = (partial(FrechetInceptionDistance), fid_update, "gt_y", {IMAGE}) + accuracy = (partial(Accuracy), None, "y_gt", MODALITIES) + perplexity = (partial(Perplexity), None, "y_gt", {TEXT}) + clip_score = (partial(CLIPScore), None, "y_x", {IMAGE}) + precision = (partial(Precision), None, "y_gt", MODALITIES) + recall = (partial(Recall), None, "y_gt", MODALITIES) + psnr = (partial(PeakSignalNoiseRatio), None, "pairwise_y_gt", {IMAGE}) + ssim = (partial(StructuralSimilarityIndexMeasure), ssim_update, "pairwise_y_gt", {IMAGE}) + msssim = (partial(MultiScaleStructuralSimilarityIndexMeasure), ssim_update, "pairwise_y_gt", {IMAGE}) + lpips = (partial(LearnedPerceptualImagePatchSimilarity), lpips_update, "pairwise_y_gt", {IMAGE}) + arniqa = (partial(ARNIQA), arniqa_update, "y", {IMAGE}) + clipiqa = (partial(CLIPImageQualityAssessment), None, "y", {IMAGE}) def __init__(self, *args, **kwargs) -> None: self.tm = self.value[0] self.update_fn = self.value[1] or default_update self.call_type = self.value[2] + self.modality = self.value[3] def __call__(self, **kwargs) -> Metric: """ @@ -260,6 +262,7 @@ def __init__(self, metric_name: str, call_type: str = "", **kwargs) -> None: # Get the specific update function for the metric, or use the default if not found. self.update_fn = TorchMetrics[metric_name].update_fn + self.modality = TorchMetrics[metric_name].modality except KeyError: raise ValueError(f"Metric {metric_name} is not supported.") diff --git a/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py b/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py new file mode 100644 index 00000000..acf3df56 --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_vbench_background_consistency.py @@ -0,0 +1,140 @@ +# Copyright 2025 - Pruna AI GmbH. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any, List + +import clip +import torch +import torch.nn.functional as F # noqa: N812 +from vbench.utils import clip_transform, init_submodules + +from pruna.engine.utils import set_to_best_available_device +from pruna.evaluation.metrics.metric_stateful import StatefulMetric +from pruna.evaluation.metrics.registry import MetricRegistry +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.utils import PAIRWISE, SINGLE, get_call_type_for_single_metric, metric_data_processor +from pruna.evaluation.metrics.vbench_utils import VBenchMixin +from pruna.logging.logger import pruna_logger + +METRIC_VBENCH_BACKGROUND_CONSISTENCY = "background_consistency" + + +@MetricRegistry.register(METRIC_VBENCH_BACKGROUND_CONSISTENCY) +class VBenchBackgroundConsistency(StatefulMetric, VBenchMixin): + """ + Background Consistency metric for VBench. + + Parameters + ---------- + *args : Any + The arguments to pass to the metric. + device : str | None + The device to run the metric on. + call_type : str + The call type to use for the metric. + **kwargs : Any + The keyword arguments to pass to the metric. + """ + + metric_name: str = METRIC_VBENCH_BACKGROUND_CONSISTENCY + default_call_type: str = "y" # We just need the outputs + higher_is_better: bool = True + # https://github.com/Vchitect/VBench/blob/dc62783c0fb4fd333249c0b669027fe102696682/evaluate.py#L111 + # explicitly sets the device to cuda. We respect this here. + runs_on: List[str] = ["cuda"] + modality: List[str] = ["video"] + # state + similarity_scores: torch.Tensor + n_samples: torch.Tensor + + def __init__( + self, + *args: Any, + device: str | None = None, + call_type: str = SINGLE, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + + if device is not None and str(device).split(":")[0] not in self.runs_on: + pruna_logger.error(f"Unsupported device {device}; supported: {self.runs_on}") + raise ValueError() + + if call_type == PAIRWISE: + # VBench itself does not support pairwise. + # We can work on this in the future. + pruna_logger.error("VBench does not support pairwise metrics. Please use single mode.") + raise ValueError() + + submodules_dict = init_submodules([METRIC_VBENCH_BACKGROUND_CONSISTENCY]) + model_path = submodules_dict[METRIC_VBENCH_BACKGROUND_CONSISTENCY][0] + + self.device = set_to_best_available_device(device) + self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) + + self.clip_model, self.preprocessor = clip.load(model_path, device=self.device) + self.video_transform = clip_transform(224) + + self.add_state("similarity_scores", torch.tensor(0.0)) + self.add_state("n_samples", torch.tensor(0)) + + def update(self, x: List[str], gt: Any, outputs: Any) -> None: + """ + Update the similarity scores for the batch. + + Parameters + ---------- + x : List[Any] | torch.Tensor + The input data. 
+ gt : torch.Tensor + The ground truth / cached images. + outputs : torch.Tensor + The output images. + """ + outputs = metric_data_processor(x, gt, outputs, self.call_type, device=self.device) + # Background consistency metric only supports a batch size of 1. + # To support larger batch sizes, we stack the outputs. + outputs = torch.stack([self.video_transform(output) for output in outputs[0]]) + features = torch.stack([self.clip_model.encode_image(output) for output in outputs]) + features = F.normalize(features, dim=-1, p=2) + + first_feature = features[0].unsqueeze(0) + + similarity_to_first = F.cosine_similarity(first_feature, features[1:]).clamp(min=0.0) + similarity_to_prev = F.cosine_similarity(features[:-1], features[1:]).clamp(min=0.0) + + similarities = (similarity_to_first + similarity_to_prev) / 2 + + # Update stats + self.similarity_scores += similarities.sum().item() + self.n_samples += similarities.numel() + + def compute(self) -> MetricResult: + """ + Aggregate the final score. + + Returns + ------- + MetricResult + The final score. + """ + score = self.similarity_scores / self.n_samples + return MetricResult(self.metric_name, self.__dict__, score) + + def reset(self) -> None: + """Reset the metric states.""" + self.similarity_scores = torch.tensor(0.0) + self.n_samples = torch.tensor(0) diff --git a/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py b/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py new file mode 100644 index 00000000..dd3f1218 --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_vbench_dynamic_degree.py @@ -0,0 +1,169 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any, List + +import numpy as np +import torch +from easydict import EasyDict +from vbench.dynamic_degree import DynamicDegree +from vbench.third_party.RAFT.core.utils_core.utils import InputPadder +from vbench.utils import init_submodules + +from pruna.engine.utils import set_to_best_available_device +from pruna.evaluation.metrics.metric_stateful import StatefulMetric +from pruna.evaluation.metrics.registry import MetricRegistry +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.utils import PAIRWISE, SINGLE, get_call_type_for_single_metric, metric_data_processor +from pruna.evaluation.metrics.vbench_utils import VBenchMixin +from pruna.logging.logger import pruna_logger + +METRIC_VBENCH_DYNAMIC_DEGREE = "dynamic_degree" + + +class PrunaDynamicDegree(DynamicDegree): + """Helper class to compute Dynamic Degree score for a given video.""" + + def infer(self, frames: torch.Tensor) -> bool: + """ + Compute Dynamic Degree score for a given video. + + Uses the RAFT Model for given video frames. + + Parameters + ---------- + frames: torch.Tensor + The video frames to compute the Dynamic Degree score for. + + Returns + ------- + bool + Whether the video contains large motions. 
+ """ + self.set_params(frame=frames[0], count=len(frames)) + static_score = [] + for image1, image2 in zip(frames[:-1], frames[1:]): + padder = InputPadder(image1.shape) + image1, image2 = padder.pad(image1, image2) + _, flow_up = self.model(image1, image2, iters=20, test_mode=True) + max_rad = self.get_score(image1, flow_up) + static_score.append(max_rad) + whether_move = self.check_move(static_score) + return whether_move + + +@MetricRegistry.register(METRIC_VBENCH_DYNAMIC_DEGREE) +class VBenchDynamicDegree(StatefulMetric, VBenchMixin): + """ + Dynamic Degree Dimension from the Vbench video benchmark suite. + + It measures the degree of dynamics (i.e. whether it contains large motions) generated by the model. + + This is important since a completely static video can score well + in temporal quality metrics but is not actually useful. + + Parameters + ---------- + device: str | None, optional + The device to be used, e.g., 'cuda' or 'cpu'. Default is None. + If None, the best available device will be used. + call_type: str, default="y" + The call type to be used, e.g., 'y' or 'y_gt'. Default is "y". + """ + + metric_name: str = METRIC_VBENCH_DYNAMIC_DEGREE + default_call_type: str = "y" # We just need the outputs + higher_is_better: bool = True + # https://github.com/Vchitect/VBench/blob/dc62783c0fb4fd333249c0b669027fe102696682/evaluate.py#L111 + # explicitly sets the device to cuda. We respect this here. + runs_on: List[str] = ["cuda"] + modality: List[str] = ["video"] + # state + scores: List[float] + + def __init__( + self, + *args: Any, + device: str | None = None, + call_type: str = SINGLE, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + + if device is not None and str(device).split(":")[0] not in self.runs_on: + pruna_logger.error(f"Unsupported device {device}; supported: {self.runs_on}") + raise ValueError() + + if call_type == PAIRWISE: + # VBench itself does not support pairwise. + # We can work on this in the future. + pruna_logger.error("VBench does not support pairwise metrics. Please use single mode.") + raise ValueError() + + submodules_dict = init_submodules([METRIC_VBENCH_DYNAMIC_DEGREE]) + model_path = submodules_dict[METRIC_VBENCH_DYNAMIC_DEGREE]["model"] + + self.device = set_to_best_available_device(device) + self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) + # RAFT models expect arguments to be passed as an object with attributes. + # So we need to convert the arguments to an EasyDict. + args_new = EasyDict({"model": model_path, "small": False, "mixed_precision": False, "alternate_corr": False}) + self.DynamicDegree = PrunaDynamicDegree(args_new, device) + self.add_state("scores", []) + + @torch.no_grad() + def update(self, x: List[str], gt: Any, outputs: Any) -> None: + """ + Calculate the dynamic degree score for the given video. + + The video is preprocessed to have approx. 8 frames per second. + + Then passed to the RAFT model to calculate the dynamic degree score. + + Each video is ranked as a 1 or 0 based on whether it contains large motions. + The final score is the mean of the scores for all videos. + + Parameters + ---------- + x: List[str] + The list of input videos. + gt: Any + The ground truth videos. + outputs: Any + The generated videos. + """ + outputs = metric_data_processor(x, gt, outputs, self.call_type, device=self.device) + score = self.DynamicDegree.infer(outputs) + self.scores.append(score) + + def compute(self) -> MetricResult: + """ + Calculate the final dynamic degree score. 
+ + The final score is the mean of the scores for all videos. + + Returns + ------- + MetricResult + The dynamic degree score. + """ + final_score = np.mean(self.scores) + return MetricResult(name=self.metric_name, params=self.__dict__, result=final_score) + + def reset(self) -> None: + """Reset the state variables for the metric.""" + super().reset() + self.scores.clear() diff --git a/src/pruna/evaluation/metrics/utils.py b/src/pruna/evaluation/metrics/utils.py index 29342701..2232b198 100644 --- a/src/pruna/evaluation/metrics/utils.py +++ b/src/pruna/evaluation/metrics/utils.py @@ -41,6 +41,11 @@ SINGLE = "single" PAIRWISE = "pairwise" CALL_TYPES = (SINGLE, PAIRWISE) +IMAGE = "image" +VIDEO = "video" +TEXT = "text" + +MODALITIES = {IMAGE, VIDEO, TEXT} def metric_data_processor( diff --git a/src/pruna/evaluation/metrics/vbench_utils.py b/src/pruna/evaluation/metrics/vbench_utils.py new file mode 100644 index 00000000..c64a9bba --- /dev/null +++ b/src/pruna/evaluation/metrics/vbench_utils.py @@ -0,0 +1,365 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import re +from pathlib import Path +from typing import Any, Callable, Iterable, List + +import numpy as np +import torch +from diffusers.utils import export_to_gif, export_to_video, load_video +from PIL.Image import Image +from torchvision.transforms import ToTensor +from vbench import VBench + +from pruna.data.pruna_datamodule import PrunaDataModule +from pruna.engine.utils import safe_memory_cleanup, set_to_best_available_device +from pruna.logging.logger import pruna_logger + + +class VBenchMixin: + """ + Mixin class for VBench metrics. + + Handles benchmark specific initilizations and artifact saving conventions. + + Parameters + ---------- + *args: Any + The arguments to pass to the metric. + **kwargs: Any + The keyword arguments to pass to the metric. + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + if VBench is None: + # Additional debug info to debug installation issues + pruna_logger.debug("Initialization failed: VBench is None") + raise ImportError("VBench is not installed. Please check your pruna installation.") + + def create_filename(self, prompt: str, idx: int, file_extension: str, special_str: str = "") -> str: + """ + Create filename according to VBench formatting conventions. + + Parameters + ---------- + prompt: str + The prompt to create the filename from. + idx: int + The index of the video. Vbench uses 5 seeds for each prompt. + file_extension: str + The file extension to use. Vbench supports mp4 and gif. + special_str: str + A special string to add to the filename if you wish to add a specific identifier. + + Returns + ------- + str + The filename. + """ + return create_vbench_file_name(sanitize_prompt(prompt), idx, special_str, file_extension) + + +def load_videos(path: str | Path, return_type: str = "pt") -> List[Image] | np.ndarray | torch.Tensor: + """ + Load videos from a path. 
+ + Parameters: + ---------- + path: str | Path + The path to the videos. + return_type: str + The type to return the videos as. Can be "pt", "np", "pil". + + Returns: + ------- + List[torch.Tensor] + The videos. + """ + video = load_video(str(path)) + if return_type == "pt": + return torch.stack([ToTensor()(frame) for frame in video]) + elif return_type == "np": + return np.stack([np.array(frame) for frame in video]) + elif return_type == "pil": + return video + else: + raise ValueError(f"Invalid return_type: {return_type}. Use 'pt', 'np', or 'pil'.") + + +def sanitize_prompt(prompt: str) -> str: + """ + Return a filesystem-safe version of a prompt. + + Replaces characters illegal in filenames and collapses whitespace so that + generated files are portable across file systems. + + Parameters: + ---------- + prompt : str + The prompt to sanitize. + + Returns: + ------- + str + The sanitized prompt. + """ + prompt = re.sub(r"[\\/:*?\"<>|]", " ", prompt) # remove illegal chars + prompt = re.sub(r"\s+", " ", prompt) # collapse multiple spaces + prompt = prompt.strip() # remove leading and trailing whitespace + return prompt + + +def prepare_batch(batch: str | tuple[str | List[str], Any]) -> str: + """ + Prepare the batch to be used in the generate_videos function. + + Pruna datamodules are expected to yield tuples where the first element is + a sequence of inputs; this utility enforces batch_size == 1 for simplicity. + + + Parameters: + ---------- + batch: str | tuple[str | List[str], Any] + The batch to prepare. + + Returns: + ------- + str + The prompt string. + """ + if isinstance(batch, str): + return batch + # for pruna datamodule. always returns a tuple where the first element is the input to the model. + elif isinstance(batch, tuple): + if not hasattr(batch[0], "__len__"): + raise ValueError(f"Batch[0] is not a sequence (got {type(batch[0])})") + if len(batch[0]) != 1: + raise ValueError(f"Only batch size 1 is supported; got {len(batch[0])}") + return batch[0][0] + else: + raise ValueError(f"Invalid batch type: {type(batch)}") + + +def _normalize_save_format(save_format: str) -> tuple[str, Callable]: + """ + Normalize the save format to be used in the generate_videos function. + + Parameters: + ---------- + save_format : str + The format to save the videos in. VBench supports mp4 and gif. + + Returns: + ------- + tuple[str, Callable] + The normalized save format and the save function. + """ + save_format = save_format.lower().strip() + if save_format == "mp4": + return ".mp4", export_to_video + if save_format == "gif": + return ".gif", export_to_gif + raise ValueError(f"Invalid save_format: {save_format}. Use 'mp4' or 'gif'.") + + +def _normalize_prompts( + prompts: str | List[str] | PrunaDataModule, split: str = "test", batch_size: int = 1 +) -> Iterable[str]: + """ + Normalize prompts to an iterable format to be used in the generate_videos function. + + Parameters: + ---------- + prompts : str | List[str] | PrunaDataModule + The prompts to normalize. + + Returns: + ------- + Iterable[str] + The normalized prompts. + """ + if isinstance(prompts, str): + return [prompts] + elif isinstance(prompts, PrunaDataModule): + return getattr(prompts, f"{split}_dataloader")(batch_size=batch_size) + else: # list of prompts, already iterable + return prompts + + +def _ensure_dir(p: Path) -> None: + """ + Ensure the directory exists. + + Parameters: + ---------- + p : Path + The path to ensure the directory exists. 
+    """
+    p.mkdir(parents=True, exist_ok=True)
+
+
+def create_vbench_file_name(prompt: str, idx: int, special_str: str = "", postfix: str = ".mp4") -> str:
+    """
+    Create a file name for the video in accordance with the VBench format.
+
+    Parameters
+    ----------
+    prompt: str
+        The prompt to create the file name from.
+    idx: int
+        The index of the video. VBench uses 5 seeds for each prompt.
+    special_str: str
+        A special string to add to the file name if you wish to add a specific identifier.
+    postfix: str
+        The format of the video file. VBench supports mp4 and gif.
+
+    Returns
+    -------
+    str
+        The file name for the video.
+    """
+    return f"{prompt}{special_str}-{idx}{postfix}"
+
+
+def sample_video_from_pipelines(
+    pipeline: Any, seeder: Any, prompt: str, device: str | torch.device | None = None, **kwargs: Any
+) -> Any:
+    """
+    Sample a video from a diffusers pipeline.
+
+    Parameters
+    ----------
+    pipeline: Any
+        The pipeline to sample from.
+    seeder: Any
+        The seeding generator passed to the pipeline.
+    prompt: str
+        The prompt to sample from.
+    device: str | torch.device | None
+        The device to sample on. Accepted for compatibility with the sampler interface.
+    **kwargs: Any
+        Additional keyword arguments to pass to the pipeline.
+
+    Returns
+    -------
+    Any
+        The generated video frames.
+    """
+    with torch.inference_mode():
+        out = pipeline(prompt=prompt, generator=seeder, **kwargs).frames[0]
+
+    return out
+
+
+def _wrap_sampler(model: Any, sampling_fn: Callable[..., Any]) -> Callable[..., Any]:
+    """
+    Wrap a user-provided sampling function into a uniform callable.
+
+    The returned callable has a keyword-only signature:
+        sampler(*, prompt: str, seeder: Any, device: str | torch.device, **kwargs)
+
+    This wrapper always passes `model` as the first positional argument, so
+    custom functions can name their first parameter `model` or `pipeline`, etc.
+
+    Parameters
+    ----------
+    model: Any
+        The model to sample from.
+    sampling_fn: Callable[..., Any]
+        The sampling function to wrap.
+
+    Returns
+    -------
+    Callable[..., Any]
+        The wrapped sampler.
+    """
+    if sampling_fn is not sample_video_from_pipelines:
+        pruna_logger.info(
+            "Using custom sampling function. Ensure it accepts (model, *, prompt, seeder, device, **kwargs)."
+        )
+
+    # The sampling function may expect the model as "pipeline", so we pass it positionally and not as a kwarg.
+    def sampler(*, prompt: str, seeder: Any, device: str | torch.device, **kwargs: Any) -> Any:
+        return sampling_fn(model, prompt=prompt, seeder=seeder, device=device, **kwargs)
+
+    return sampler
+
+
+def generate_videos(
+    model: Any,
+    prompts: str | List[str] | PrunaDataModule,
+    split: str = "test",
+    unique_sample_per_video_count: int = 5,
+    global_seed: int = 42,
+    sampling_fn: Callable[..., Any] = sample_video_from_pipelines,
+    fps: int = 16,
+    save_dir: str | Path = "./saved_videos",
+    save_format: str = "mp4",
+    special_str: str = "",
+    device: str | torch.device | None = None,
+    **model_kwargs: Any,
+) -> None:
+    """
+    Generate N samples per prompt and save them to disk with seed tracking.
+
+    This function:
+    1) Normalizes prompts (string, list, or datamodule).
+    2) Uses an RNG seeded with `global_seed` for reproducibility across runs.
+    3) Saves videos as MP4 or GIF.
+
+    Parameters
+    ----------
+    model : Any
+        The model to sample from.
+    prompts : str | List[str] | PrunaDataModule
+        The prompts to sample from.
+    split : str
+        The split to sample from.
+        Default is "test" since most benchmarking datamodules in Pruna are configured to use the test split.
+    unique_sample_per_video_count : int
+        The number of samples generated per prompt. Default is 5, as required by VBench.
+    global_seed : int
+        The global seed used to initialize the random number generator.
+    sampling_fn : Callable[..., Any]
+        The sampling function to use.
+ fps : int + The frames per second of the video. + save_dir : str | Path + The directory to save the videos to. + save_format : str + The format to save the videos in. VBench supports mp4 and gif. + special_str : str + A special string to add to the file name if you wish to add a specific identifier. + **model_kwargs : Any + Additional keyword arguments to pass to the sampling function. + """ + file_extension, save_fn = _normalize_save_format(save_format) + + device = set_to_best_available_device(device) + + prompt_iterable = _normalize_prompts(prompts, split, batch_size=1) + + save_dir = Path(save_dir) + _ensure_dir(save_dir) + + # set a run-level seed (VBench suggests this) (important for reproducibility) + seed_rng = torch.Generator().manual_seed(global_seed) + sampler = _wrap_sampler(model=model, sampling_fn=sampling_fn) + + for batch in prompt_iterable: + prompt = prepare_batch(batch) + for idx in range(unique_sample_per_video_count): + file_name = create_vbench_file_name(sanitize_prompt(prompt), idx, special_str, file_extension) + out_path = save_dir / file_name + + vid = sampler(prompt=prompt, seeder=seed_rng, device=device, **model_kwargs) + save_fn(vid, out_path, fps=fps) + + del vid + safe_memory_cleanup() diff --git a/src/pruna/evaluation/task.py b/src/pruna/evaluation/task.py index e8c63688..dd6f10de 100644 --- a/src/pruna/evaluation/task.py +++ b/src/pruna/evaluation/task.py @@ -70,6 +70,7 @@ def __init__( self.stateful_metric_device = self._get_stateful_metric_device_from_task_device() self.metrics = _safe_build_metrics(request, self.device, self.stateful_metric_device) + self.modality = self.validate_and_get_task_modality() self.datamodule = datamodule self.dataloader = datamodule.test_dataloader() @@ -140,6 +141,31 @@ def _get_stateful_metric_device_from_task_device(self) -> str: else: return self.device # for when we pass a specific cuda device, or cpu or mps. + def validate_and_get_task_modality(self) -> str: + """ + Check if the task has a single modality of metrics. + + Inference handling is different for different modalities. + The task should have one consistent modality across metrics. + Stateless metrics and stateful metrics with general modalities are allowed. + + Returns + ------- + str + The modality of the task. + """ + if not self.get_single_stateful_metrics() and not self.get_pairwise_stateful_metrics(): + return "general" + modality_intersection = set.intersection( + *[metric.modality for metric in self.metrics if isinstance(metric, StatefulMetric)] + ) + if len(modality_intersection) == 1: + return modality_intersection.pop() + elif len(modality_intersection) == 0: + raise ValueError("The task should have a single modality across all quality metrics.") + else: # More than one modality, fine for evaluation, can't save artifacts (for now). 
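+            # If the intersection contains more than one modality, every stateful metric
+            # supports all of them, so no single modality can be inferred for saving artifacts.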
+ return "general" + def _safe_build_metrics( request: str | List[str | BaseMetric | StatefulMetric], inference_device: str, stateful_metric_device: str diff --git a/tests/algorithms/testers/flash_attn3.py b/tests/algorithms/testers/flash_attn3.py index 575f8802..90738424 100644 --- a/tests/algorithms/testers/flash_attn3.py +++ b/tests/algorithms/testers/flash_attn3.py @@ -13,4 +13,4 @@ class TestFlashAttn3(AlgorithmTesterBase): reject_models = ["opt_tiny_random"] allow_pickle_files = False algorithm_class = FlashAttn3 - metrics = ["latency"] + metrics = ["background_consistency"] diff --git a/tests/engine/test_handler.py b/tests/engine/test_handler.py new file mode 100644 index 00000000..5501dc80 --- /dev/null +++ b/tests/engine/test_handler.py @@ -0,0 +1,116 @@ +import numpy as np +import pytest +import torch +from pruna.engine.handler.handler_inference import ( + validate_seed_strategy, +) +from pruna.engine.handler.handler_diffuser import DiffuserHandler +from pruna.engine.handler.handler_standard import StandardHandler +from pruna.engine.pruna_model import PrunaModel +from pruna.engine.utils import move_to_device +# Default handler tests, mainly for checking seeding. + +def test_validate_seed_strategy_valid(): + '''Test to see validate_seed_strategy is valid for valid strategies''' + validate_seed_strategy("per_sample", 42) + validate_seed_strategy("no_seed", None) + + +@pytest.mark.parametrize("strategy,seed", [ + ("per_sample", None), + ("no_seed", 42), +]) +def test_validate_seed_strategy_invalid(strategy, seed): + '''Test to see validate_seed_strategy raises an error for invalid strategies''' + with pytest.raises(ValueError): + validate_seed_strategy(strategy, seed) + +def test_set_seed_reproducibility(): + inference_handler = StandardHandler() + inference_handler.set_seed(42) + torch_random_tensor = torch.randn(3) + numpy_random_tensor = np.random.randn(3) + inference_handler.set_seed(42) + torch_expected = torch.randn(3) + numpy_expected = np.random.randn(3) + assert torch.equal(torch_random_tensor, torch_expected) + assert np.array_equal(numpy_random_tensor, numpy_expected) + + +# Diffuser handler tests, checking output processing and seeding. 
+@pytest.mark.parametrize("model_fixture", + [ + pytest.param("flux_tiny_random", marks=pytest.mark.cuda), + ], + indirect=["model_fixture"], + ) +def test_assignment_of_diffuser_handler(model_fixture): + """Check if a diffusion model is assigned to the DiffuserHandler""" + model, smash_config = model_fixture + + pruna_model = PrunaModel(model, smash_config=smash_config) + assert isinstance(pruna_model.inference_handler, DiffuserHandler) + +@pytest.mark.parametrize("model_fixture, seed, output_attr, return_dict, device", + [ + pytest.param("flux_tiny_random", 42, "images", True, "cpu", marks=pytest.mark.cpu), + pytest.param("wan_tiny_random", 42, "frames" ,True, "cuda", marks=pytest.mark.cuda), + pytest.param("flux_tiny_random", 42, "none", False, "cpu", marks=pytest.mark.cpu), + ], + indirect=["model_fixture"],) +def test_process_output_images(model_fixture, seed, output_attr, return_dict, device): + """Check if the output of the model is processed correctly""" + input_text = "a photo of a cute prune" + + # Get the output from PrunaModel + model, smash_config = model_fixture + smash_config.device = device + move_to_device(model, device) + pruna_model = PrunaModel(model, smash_config=smash_config) + pruna_model.inference_handler.configure_seed("per_sample", global_seed=seed) + result = pruna_model.run_inference(input_text) + + # Get the output from the pipeline. + pipe_output = model(input_text, output_type="pt", generator=torch.Generator("cpu").manual_seed(seed), return_dict=return_dict) + if output_attr != "none": + pipe_output = getattr(pipe_output, output_attr) + pipe_output = pipe_output[0] + + assert (result == pipe_output).all().item() + + +@pytest.mark.parametrize("model_fixture", + [ + pytest.param("flux_tiny_random", marks=pytest.mark.cpu), + ], + indirect=["model_fixture"],) +def test_per_sample_seed_is_applied(model_fixture): + """Check if samples change per inference run when per_sample seed is applied""" + model, smash_config = model_fixture + smash_config.device = "cpu" + move_to_device(model, "cpu") + input_text = "a photo of a cute prune" + pruna_model = PrunaModel(model, smash_config=smash_config) + pruna_model.inference_handler.configure_seed("per_sample", global_seed=42) + first_result = pruna_model.run_inference(input_text) + second_result = pruna_model.run_inference(input_text) + # If seeding is successfull, each sample should create a different output. + assert not torch.equal(first_result, second_result) + +@pytest.mark.parametrize("model_fixture", + [ + pytest.param("flux_tiny_random", marks=pytest.mark.cpu), + ], + indirect=["model_fixture"],) +def test_seed_is_removed(model_fixture): + """ Check if seed is removed when no_seed seed is applied""" + model, smash_config = model_fixture + smash_config.device = "cpu" + move_to_device(model, "cpu") + pruna_model = PrunaModel(model, smash_config=smash_config) + pruna_model.inference_handler.configure_seed("per_sample", global_seed=42) + # First check if the seed is set. + assert pruna_model.inference_handler.model_args["generator"] is not None + pruna_model.inference_handler.configure_seed("no_seed", None) + # Then check if the seed generator is removed. 
+    assert pruna_model.inference_handler.model_args["generator"] is None
diff --git a/tests/evaluation/test_artifactsaver.py b/tests/evaluation/test_artifactsaver.py
new file mode 100644
index 00000000..7a765237
--- /dev/null
+++ b/tests/evaluation/test_artifactsaver.py
@@ -0,0 +1,180 @@
+import pytest
+import tempfile
+from pathlib import Path
+
+import numpy as np
+import torch
+
+from pruna.evaluation.artifactsavers.video_artifactsaver import VideoArtifactSaver
+from pruna.evaluation.artifactsavers.image_artifactsaver import ImageArtifactSaver
+from pruna.evaluation.artifactsavers.utils import assign_artifact_saver
+from pruna.evaluation.metrics.vbench_utils import load_videos
+from PIL import Image
+
+
+def test_create_alias():
+    """Test that we can create an alias for an existing video and image."""
+    with tempfile.TemporaryDirectory() as tmp_path:
+
+        # --- Test video artifact saver ---
+        # First, we create a random video and save it.
+        saver = VideoArtifactSaver(root=tmp_path, export_format="mp4")
+        dummy_video = np.random.randint(0, 255, (10, 16, 16, 3), dtype=np.uint8)
+        source_filename = saver.save_artifact(dummy_video, saving_kwargs={"fps": 5})
+        # Then, we create an alias for the video.
+        alias = saver.create_alias(source_filename, "alias_filename")
+        # Finally, we reload the alias and check that it has the same shape as the original video.
+        reloaded_alias_video = load_videos(str(alias), return_type="np")
+        assert reloaded_alias_video.shape == dummy_video.shape
+        assert alias.exists()
+        assert alias.name.endswith(".mp4")
+
+        # --- Test image artifact saver ---
+        saver = ImageArtifactSaver(root=tmp_path, export_format="png")
+        dummy_image = np.random.randint(0, 255, (16, 16, 3), dtype=np.uint8)
+        source_filename = saver.save_artifact(dummy_image, saving_kwargs={"quality": 95})
+        # Then, we create an alias for the image.
+        alias = saver.create_alias(source_filename, "alias_filename")
+        # Finally, we reload the alias and check that it has the same shape as the original image.
+        reloaded_alias_image = np.array(Image.open(str(alias)))
+        assert reloaded_alias_image.shape == dummy_image.shape
+        assert alias.exists()
+        assert alias.name.endswith(".png")
+
+def test_assign_all_artifact_savers(tmp_path: Path):
+    """Test that each artifact saver is assigned correctly."""
+    saver = assign_artifact_saver("video", root=tmp_path, export_format="mp4")
+    assert isinstance(saver, VideoArtifactSaver)
+    assert saver.export_format == "mp4"
+    saver = assign_artifact_saver("image", root=tmp_path, export_format="png")
+    assert isinstance(saver, ImageArtifactSaver)
+    assert saver.export_format == "png"
+
+def test_assign_artifact_saver_invalid():
+    """Test that we raise an error if an unknown modality is requested."""
+    with pytest.raises(ValueError):
+        assign_artifact_saver("nonexistent_modality")
+
+@pytest.mark.parametrize(
+    "export_format, save_from_type, save_from_dtype",
+    [pytest.param("gif", "np", "uint8"),
+     pytest.param("gif", "np", "float32"),
+     # numpy's random generators do not support float16, so we do not test it here.
+     pytest.param("gif", "pt", "float32"),
+     pytest.param("gif", "pt", "float16"),
+     pytest.param("gif", "pt", "uint8"),
+     # PIL does not support creating images from float numpy arrays, so we only test uint8.
+ pytest.param("gif", "pil", "uint8"), + pytest.param("mp4", "np", "uint8"), + pytest.param("mp4", "np", "float32"), + pytest.param("mp4", "pt", "float32"), + pytest.param("mp4", "pt", "float16"), + pytest.param("mp4", "pt", "uint8"), + # PIL doesnot support creating images from float numpy arrays, so we only test uint8. + pytest.param("mp4", "pil", "uint8"),] +) +def test_video_artifact_saver_tensor(export_format: str, save_from_type: str, save_from_dtype: str): + """ Test that we can save a video from numpy, torch and PIL in mp4 and gif formats. """ + with tempfile.TemporaryDirectory() as tmp_path: + saver = VideoArtifactSaver(root=tmp_path, export_format=export_format) + # create a fake video: + if save_from_type == "pt": + # Unfortunately, neither torch nor numpy have one random generator function that can support all dtypes. + # Therefore, we need to use different functions for int and float dtypes. + if save_from_dtype == "uint8": + dtype = getattr(torch, save_from_dtype) + dummy_video = torch.randint(0, 256, (2, 3, 16, 16), dtype=dtype) + else: + dtype = getattr(torch, save_from_dtype) + dummy_video = torch.randn(2, 3, 16, 16, dtype=dtype) + elif save_from_type == "np": + if save_from_dtype == "uint8": + dtype = getattr(np, save_from_dtype) + dummy_video = np.random.randint(0, 256, (2, 16, 16, 3), dtype=dtype) + else: + rng = np.random.default_rng() + dtype = getattr(np, save_from_dtype) + dummy_video = rng.random((2, 16, 16, 3), dtype=dtype) + elif save_from_type == "pil": + dtype = getattr(np, save_from_dtype) + dummy_video = np.random.randint(0, 256, (2, 16, 16, 3), dtype=dtype) + dummy_video = [Image.fromarray(frame.astype(np.uint8)) for frame in dummy_video] + path = saver.save_artifact(dummy_video) + assert path.exists() + assert path.suffix == f".{export_format}" + +@pytest.mark.parametrize( + "export_format, save_from_type, save_from_dtype", + [ + # --- Test png format --- + # numpy + pytest.param("png", "np", "uint8"), + pytest.param("png", "np", "float32"), + # torch + pytest.param("png", "pt", "float32"), + pytest.param("png", "pt", "float16"), + pytest.param("png", "pt", "uint8"), + # PIL + pytest.param("png", "pil", "uint8"), + # --- Test jpg format --- + # numpy + pytest.param("jpg", "np", "uint8"), + pytest.param("jpg", "np", "float32"), + # torch + pytest.param("jpg", "pt", "float32"), + pytest.param("jpg", "pt", "float16"), + pytest.param("jpg", "pt", "uint8"), + # PIL + pytest.param("jpg", "pil", "uint8"), + # --- Test webp format --- + # numpy + pytest.param("webp", "np", "uint8"), + pytest.param("webp", "np", "float32"), + # torch + pytest.param("webp", "pt", "float32"), + pytest.param("webp", "pt", "float16"), + pytest.param("webp", "pt", "uint8"), + # PIL + pytest.param("webp", "pil", "uint8"), + # --- Test jpeg format --- + # numpy + pytest.param("jpeg", "np", "uint8"), + pytest.param("jpeg", "np", "float32"), + # torch + pytest.param("jpeg", "pt", "float32"), + pytest.param("jpeg", "pt", "float16"), + pytest.param("jpeg", "pt", "uint8"), + # PIL + pytest.param("jpeg", "pil", "uint8"), + ] + ) +def test_image_artifact_saver_tensor(export_format: str, save_from_type: str, save_from_dtype: str): + """ Test that we can save an image from a tensor.""" + with tempfile.TemporaryDirectory() as tmp_path: + saver = ImageArtifactSaver(root=tmp_path, export_format=export_format) + # Create fake image: + if save_from_type == "pt": + # Note: torch convention is (C, H, W) + if save_from_dtype == "uint8": + dtype = getattr(torch, save_from_dtype) + dummy_image = 
torch.randint(0, 256, (3, 16, 16), dtype=dtype) + else: + dtype = getattr(torch, save_from_dtype) + dummy_image = torch.randn(3, 16, 16, dtype=dtype) + elif save_from_type == "np": + # Note: Numpy arrays as images follow the convention (H, W, C) + if save_from_dtype == "uint8": + dtype = getattr(np, save_from_dtype) + dummy_image = np.random.randint(0, 256, (16, 16, 3), dtype=dtype) + else: + rng = np.random.default_rng() + dtype = getattr(np, save_from_dtype) + dummy_image = rng.random((16, 16, 3), dtype=dtype) + elif save_from_type == "pil": + # Note: PIL images by default have shape (H, W, C) and are usually uint8 (standard for ".jpg", etc.) + dtype = getattr(np, save_from_dtype) + dummy_image = np.random.randint(0, 256, (16, 16, 3), dtype=dtype) + dummy_image = Image.fromarray(dummy_image.astype(np.uint8)) + path = saver.save_artifact(dummy_image) + assert path.exists() + assert path.suffix == f".{export_format}" \ No newline at end of file diff --git a/tests/evaluation/test_evalagent.py b/tests/evaluation/test_evalagent.py new file mode 100644 index 00000000..e3dc33ba --- /dev/null +++ b/tests/evaluation/test_evalagent.py @@ -0,0 +1,42 @@ +import pytest +import torch +from pathlib import Path +import tempfile +from pruna.evaluation.evaluation_agent import EvaluationAgent +from pruna.engine.pruna_model import PrunaModel + + +@pytest.mark.cuda +@pytest.mark.parametrize("model_fixture, export_format", +[pytest.param("wan_tiny_random", "mp4", marks=pytest.mark.cuda), +pytest.param("wan_tiny_random", "gif", marks=pytest.mark.cuda)], indirect=["model_fixture"]) +def test_agent_saves_artifacts(model_fixture, export_format): + """ Test that the agent runs inference and saves the inference output artifacts correctly.""" + model, smash_config = model_fixture + # Artifact path + temp_path = tempfile.mkdtemp() + + # Let's limit the number of data points + dm = smash_config.data + data_points = 2 + dm.limit_datasets(data_points) + + + agent = EvaluationAgent( + request=['background_consistency'], + datamodule=dm, + device="cuda", + save_artifacts=True, + root_dir=temp_path, + saving_kwargs={"fps":4}, + artifact_saver_export_format=export_format, + ) + + pruna_model = PrunaModel(model, smash_config) + pruna_model.inference_handler.model_args["num_inference_steps"] = 1 + + agent.evaluate(model=pruna_model) + files = list(Path(temp_path).rglob(f"*.{export_format}")) + + # Check that we saved the correct number of files + assert len(files) == data_points diff --git a/tests/evaluation/test_task.py b/tests/evaluation/test_task.py index 5420a17b..218edd8b 100644 --- a/tests/evaluation/test_task.py +++ b/tests/evaluation/test_task.py @@ -107,3 +107,22 @@ def test_task_from_string_request(): assert isinstance(task.metrics[0], CMMD) assert isinstance(task.metrics[1], PairwiseClipScore) assert isinstance(task.metrics[2], TorchMetricWrapper) + + +@pytest.mark.parametrize("metrics, modality", [ + (["cmmd", "ssim", "lpips", "latency"], "image"), + (["perplexity", "disk_memory"], "text"), + ([MetricRegistry().get_metric("accuracy", task="binary"), MetricRegistry().get_metric("recall", task="binary")], "general"), + (["background_consistency", "dynamic_degree"], "video") +]) +def test_task_modality(metrics, modality): + """ Test that the task modality is assigned correctly for image, text, general and video metrics.""" + datamodule = type("dm", (), {"test_dataloader": lambda self: []})() + task = Task(request=metrics, datamodule=datamodule) + assert task.modality == modality + +def 
test_task_modality_mixed_raises():
+    """Test that we raise an error when the task mixes metrics of different modalities."""
+    datamodule = type("dm", (), {"test_dataloader": lambda self: []})()
+    with pytest.raises(ValueError):
+        Task(request=["cmmd", "background_consistency"], datamodule=datamodule)
diff --git a/tests/evaluation/test_vbench_metrics.py b/tests/evaluation/test_vbench_metrics.py
new file mode 100644
index 00000000..c1b98d0a
--- /dev/null
+++ b/tests/evaluation/test_vbench_metrics.py
@@ -0,0 +1,164 @@
+from __future__ import annotations
+
+import pytest
+import torch
+import numpy as np
+
+from pruna.evaluation.metrics.metric_vbench_background_consistency import VBenchBackgroundConsistency
+from pruna.evaluation.metrics.metric_vbench_dynamic_degree import VBenchDynamicDegree
+from pruna.evaluation.metrics.utils import PAIRWISE
+from pruna.evaluation.metrics.registry import MetricRegistry
+
+
+@pytest.mark.cuda
+def test_metric_background_consistency():
+    """Test the background consistency metric."""
+    # let's create a batch of 2 random RGB videos with 2 frames each that are 16 x 16 pixels.
+    random_input_video_batch = torch.randn(2, 2, 3, 16, 16)
+    # let's create a batch of 2 all-black RGB videos with 2 frames each that are 16 x 16 pixels.
+    all_black_input_video_batch = torch.zeros(2, 2, 3, 16, 16)
+    metric = VBenchBackgroundConsistency()
+    metric.update(random_input_video_batch, random_input_video_batch, random_input_video_batch)
+    random_result = metric.compute()
+    metric.reset()
+    metric.update(all_black_input_video_batch, all_black_input_video_batch, all_black_input_video_batch)
+    all_black_result = metric.compute()
+    metric.reset()
+    # Background consistency measures the cosine similarity between frames.
+    # Therefore we would expect a completely blacked-out video (even though meaningless) to have a much higher score
+    # than a completely random set of frames.
+    assert all_black_result.result >= random_result.result
+    # Since the all-black video has identical frames, the cosine similarity between frames should be 1.0.
+    assert np.isclose(all_black_result.result, 1.0)
+
+@pytest.mark.cuda
+@pytest.mark.parametrize("model_fixture", ["wan_tiny_random"], indirect=["model_fixture"])
+def test_metric_dynamic_degree_dynamic(model_fixture):
+    """Test the dynamic degree metric with a sample from the VBench dataset that yields a dynamic video."""
+    model, smash_config = model_fixture
+    model.to("cuda")
+    model.to(torch.float32)
+    # this is a prompt from the vbench dataset under the dynamic degree dimension.
+    output_video = model("a dog running happily", num_inference_steps=10, output_type="pt").frames[0].unsqueeze(0)
+
+    metric = VBenchDynamicDegree(interval=4)
+    metric.update(output_video, output_video, output_video)
+    result = metric.compute()
+    # a video of a dog running should ideally have a dynamic degree of 1.0 (since it contains large movements)
+    assert result.result == 1.0
+
+@pytest.mark.cuda
+def test_metric_dynamic_degree_static():
+    """Test the dynamic degree metric on a static video."""
+    # Testing for a lack of movement is much easier than testing for movement.
+    # We create a completely black video to test the metric.
+ video = torch.zeros(1,4,3,64,64) # a completely black video doesn't contain any movements + metric = VBenchDynamicDegree(interval=1) + metric.update(video, video, video) + result = metric.compute() + assert result.result == 0.0 + +@pytest.mark.cuda +@pytest.mark.parametrize("vbench_metric", [VBenchDynamicDegree, VBenchBackgroundConsistency]) +def test_metric_pairwise_call_type(vbench_metric): + """Test that VBenchBackgroundConsistency raises ValueError for PAIRWISE call_type.""" + with pytest.raises(ValueError): + vbench_metric(call_type=PAIRWISE) + + +@pytest.mark.cuda +@pytest.mark.parametrize("vbench_metric", ['background_consistency', 'dynamic_degree']) +def test_vbench_metrics_invalid_tensor_dimensions(vbench_metric): + """Test that validate_batch raises ValueError for invalid tensor dimensions.""" + metric = MetricRegistry.get_metric(vbench_metric) + # Test 3D tensor (should fail) + invalid_tensor = torch.randn(2, 3, 16) + with pytest.raises(ValueError, match="4 or 5 dimensional"): + metric.validate_batch(invalid_tensor) + + +@pytest.mark.cuda +@pytest.mark.parametrize("vbench_metric", ["background_consistency", "dynamic_degree"]) +def test_vbench_metrics_compute_without_updates(vbench_metric): + """Test compute() returns 0.0 when no updates have been made.""" + metric = MetricRegistry.get_metric(vbench_metric) + result = metric.compute() + assert result.result == 0.0 + assert result.name == vbench_metric + + +@pytest.mark.cuda +@pytest.mark.parametrize("vbench_metric", ["background_consistency", "dynamic_degree"]) +def test_vbench_metrics_4d_tensor(vbench_metric): + """Test that 4D tensors (B, C, H, W) are properly converted to 5D.""" + # 4D tensor should be converted to 5D by validate_batch + four_d_video = torch.randn(2, 3, 16, 16) # Missing time dimension + metric = MetricRegistry.get_metric(vbench_metric) + metric.update(four_d_video, four_d_video, four_d_video) + result = metric.compute() + assert result.result >= 0.0 + assert result.name == vbench_metric + + +@pytest.mark.cuda +@pytest.mark.parametrize("vbench_metric", ["background_consistency", "dynamic_degree"]) +def test_vbench_metrics_different_batch_sizes(vbench_metric): + """Test background consistency with different batch sizes.""" + metric = MetricRegistry.get_metric(vbench_metric) + for batch_size in [1, 2, 3]: + video_batch = torch.randn(batch_size, 5, 3, 64, 64) + metric.update(video_batch, video_batch, video_batch) + result = metric.compute() + assert 0.0 <= result.result <= 1.0 + metric.reset() + + + +@pytest.mark.cuda +@pytest.mark.parametrize("vbench_metric", ["background_consistency", "dynamic_degree"]) +def test_vbench_metrics_different_resolutions(vbench_metric): + """Test background consistency with different video resolutions.""" + resolutions = [(64, 64), (128, 128), (224, 224)] + for h, w in resolutions: + video = torch.randn(1, 5, 3, h, w) + metric = MetricRegistry.get_metric(vbench_metric) + metric.update(video, video, video) + result = metric.compute() + assert 0.0 <= result.result <= 1.0 + metric.reset() + + + +@pytest.mark.cuda +@pytest.mark.parametrize("vbench_metric", ["background_consistency", "dynamic_degree"]) +def test_vbench_metrics_reset_clears_state(vbench_metric): + """Test that reset() properly clears the metric state.""" + metric = MetricRegistry.get_metric(vbench_metric) + video = torch.randn(1, 5, 3, 64, 64) + + # First update and compute + metric.update(video, video, video) + result1 = metric.compute() + + # Reset and verify state is cleared + metric.reset() + result2 = 
metric.compute() + assert result2.result == 0.0 # Should be 0.0 after reset with no updates + + # Update again and verify it works + metric.update(video, video, video) + result3 = metric.compute() + assert result3.result >= 0.0 + + + +@pytest.mark.cuda +@pytest.mark.parametrize("interval", [1, 3, 5, 10]) +def test_metric_dynamic_degree_different_intervals(interval): + """Test dynamic degree with different interval values.""" + video = torch.zeros(1, 12, 3, 64, 64) # 12 frames to allow various intervals + metric = VBenchDynamicDegree(interval=interval) + metric.update(video, video, video) + result = metric.compute() + assert 0.0 <= result.result <= 1.0 # Score should be in valid range
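
For context, a minimal usage sketch of the new VBench helpers (plain Python, not part of the diff above). It assumes a diffusers text-to-video pipeline is already available as `pipe`; the custom sampler `my_sampler` is hypothetical and only illustrates the keyword-only contract enforced by `_wrap_sampler`, and the metric calls mirror the tests above:

    import torch

    from pruna.evaluation.metrics.metric_vbench_background_consistency import VBenchBackgroundConsistency
    from pruna.evaluation.metrics.vbench_utils import generate_videos, load_videos


    def my_sampler(pipeline, *, prompt, seeder, device, **kwargs):
        # Hypothetical custom sampler: the model arrives positionally, everything else keyword-only.
        return pipeline(prompt=prompt, generator=seeder, **kwargs).frames[0]


    # Generate the 5 videos per prompt that VBench expects and save them as MP4 files.
    generate_videos(
        model=pipe,  # assumed: a diffusers text-to-video pipeline
        prompts=["a dog running happily"],
        sampling_fn=my_sampler,
        save_dir="./saved_videos",
        save_format="mp4",
        num_inference_steps=10,  # forwarded to the pipeline via **model_kwargs
    )

    # Score one of the generated videos; update() takes three arguments, as in the tests.
    video = load_videos("./saved_videos/a dog running happily-0.mp4", return_type="pt").unsqueeze(0)
    metric = VBenchBackgroundConsistency()
    metric.update(video, video, video)
    print(metric.compute().result)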