Changes from all commits (42 commits)
- 4d5d496 feat: 2 vbench dimensions and vbench dependencies (begumcig, Oct 9, 2025)
- b8d0392 test: vbench metric tests (begumcig, Oct 10, 2025)
- d350a70 docs: add more comprehensive docstring explanations for important par… (begumcig, Oct 10, 2025)
- d16c583 feat: add additional helper tools to utilities (begumcig, Oct 13, 2025)
- 82342cb refactor: small updates to utilities and docstrings (begumcig, Oct 14, 2025)
- 102e91f refactor: add support for more calltypes in video eval utils (begumcig, Oct 15, 2025)
- 2f579fd refactor: make utilities more vbench independent and fix small things… (begumcig, Oct 28, 2025)
- 9abf898 refactor: address PR comments (begumcig, Nov 3, 2025)
- e4bc717 test: adding more tests for dynamic degree and background consistency (begumcig, Nov 10, 2025)
- 2f37e88 feat: artifact saving and vbench related agent updates (begumcig, Oct 6, 2025)
- b8f3e8d test: add tests for the artifact savers (begumcig, Oct 7, 2025)
- d8c7a28 test: add artifact related evaluation tests and task modality tests (begumcig, Oct 8, 2025)
- fb15eed refactor: add some comments (begumcig, Oct 8, 2025)
- c2c1d27 refactor: better initialization for artifact savers (begumcig, Oct 23, 2025)
- ceead22 test: add more dtype tests for artifact saver (begumcig, Oct 24, 2025)
- 2329abb feat: metric modalities as sets (begumcig, Oct 24, 2025)
- 34a9e82 refactor: comments tests task modality (begumcig, Oct 28, 2025)
- 6e3eb76 feat: add video inference support and seeding strategies to inference… (begumcig, Oct 6, 2025)
- a3f7cbc feat: remove per evaluation seed and add tests (begumcig, Oct 6, 2025)
- c2259cc chore: add comments (begumcig, Oct 6, 2025)
- 8a340bb fix: bfloats cannot be moved to cpu error in cmmd metric (begumcig, Oct 14, 2025)
- 2441e9d fix: pre commit file fix (begumcig, Oct 14, 2025)
- 52133e1 refactor: configure seeding and tests (begumcig, Oct 28, 2025)
- 6c1d741 refactor: algorithm compatibility (#401) (johannaSommer, Nov 5, 2025)
- 566ebb9 feat:stratification to vbench datasets (begumcig, Nov 27, 2025)
- b581c31 feat: data stratification by indexing (begumcig, Nov 27, 2025)
- f2059b9 Add image artifactsaver and modified utils to use it for algo sweeper… (Dec 3, 2025)
- 7d92f78 Changes to vbench-utils via ruff (Dec 3, 2025)
- 6a58933 Add filename sanitizer which modifies invalid filename´aliases (Dec 10, 2025)
- e821e47 Change file format via ruff (Dec 10, 2025)
- 8879195 Add helper function for creating aliases as prompt names for outputs (Dec 10, 2025)
- 40d09bd Enable TruffleHog in pre-commit (#439) (ParagEkbote, Dec 9, 2025)
- 607edef feat: 0 vbench dimensions and vbench dependencies (begumcig, Oct 9, 2025)
- 8287501 Undo limiting prompt length as name for image logging (Dec 22, 2025)
- c019dd6 Integrate uncommented method from task.py (Dec 22, 2025)
- 4a968af Update function in evaluation agent for generating metadata-json file (Dec 22, 2025)
- fc3d21c Add evaluation-agent parameter for optional JSON metadata creation; i… (Marius-Graml, Dec 24, 2025)
- e3b19ed Adjust doc string for _maybe_create_input_output_metadat() in evaluat… (Marius-Graml, Dec 24, 2025)
- f177313 Add model role in metadata json when applying pairwise metrics (Marius-Graml, Dec 29, 2025)
- f936d5f Change name and if logic of json file creation function in evaluation… (Marius-Graml, Jan 13, 2026)
- 5a5d3e7 Copy the tests from image-artifactsaver to feat/img-saver-extended to… (Marius-Graml, Jan 13, 2026)
- 3ae00ee Remove unused import (Marius-Graml, Jan 13, 2026)
40 changes: 29 additions & 11 deletions .pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
rev: v6.0.0
hooks:
- id: check-added-large-files
- id: check-case-conflict
@@ -13,21 +13,45 @@ repos:
- id: trailing-whitespace
exclude: \.md$
- id: no-commit-to-branch

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.12
rev: v0.14.6
hooks:
- id: ruff-check
args: [--fix]
files: \.py$
- id: ruff-format
files: \.py$

- repo: https://github.com/trufflesecurity/trufflehog
rev: v3.91.1
hooks:
- id: trufflehog
name: TruffleHog Secrets Scanner
entry: trufflehog
language: golang
types_or: [python, yaml, json, text]
args:
[
"filesystem",
"src",
"tests",
".github/workflows",
"--results=verified,unknown",
"--exclude-paths=.venv",
"--fail"
]
stages: ["pre-commit", "pre-push"]

- repo: local
hooks:
- id: ty
name: ty check
name: type checking using ty
entry: uvx ty check .
language: system
types: [python]
pass_filenames: false
files: \.py$

- repo: local
hooks:
@@ -39,7 +63,7 @@ repos:
grep -v "^D" |
cut -f2- |
while IFS= read -r file; do
if [ -f "$file" ] && ["$file" != ".pre-commit-config.yaml"] && grep -q "pruna_pro" "$file"; then
if [ -f "$file" ] && [ "$file" != ".pre-commit-config.yaml" ] && grep -q "pruna_pro" "$file"; then
echo "Error: pruna_pro found in staged file $file"
exit 1
fi
@@ -48,10 +72,4 @@ repos:
language: system
stages: [pre-commit]
types: [python]
exclude: "^docs/"
- id: trufflehog
name: TruffleHog
description: Detect secrets in your data.
entry: bash -c 'git diff --cached --name-only | xargs -I {} trufflehog filesystem {} --fail --no-update'
language: system
stages: ["pre-commit", "pre-push"]
files: \.py$
7 changes: 6 additions & 1 deletion pyproject.toml
@@ -31,10 +31,13 @@ unsupported-operator = "ignore" # mypy supports | syntax with from __future__ im
invalid-argument-type = "ignore" # mypy is more permissive with argument types
invalid-return-type = "ignore" # mypy is more permissive with return types
invalid-parameter-default = "ignore" # mypy is more permissive with parameter defaults
possibly-missing-attribute = "ignore" # mypy is more permissive with attribute access
possibly-unbound-attribute = "ignore"
possibly-missing-import = "ignore" # mypy is more permissive with imports
no-matching-overload = "ignore" # mypy is more permissive with overloads
unresolved-reference = "ignore" # mypy is more permissive with references
possibly-unbound-import = "ignore"
missing-argument = "ignore"
possibly-unbound-import = "ignore"

[tool.coverage.run]
source = ["src/pruna"]
@@ -75,6 +78,7 @@ gptqmodel = [
{ index = "pruna_internal", marker = "sys_platform != 'darwin' or platform_machine != 'arm64'"},
{ index = "pypi", marker = "sys_platform == 'darwin' and platform_machine == 'arm64'"},
]
clip = {git = "https://github.com/openai/CLIP.git", rev = "dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1"}

[project]
name = "pruna"
@@ -186,6 +190,7 @@ dev = [
]
cpu = []


[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
15 changes: 12 additions & 3 deletions src/pruna/data/utils.py
@@ -184,7 +184,7 @@ def recover_text_from_dataloader(dataloader: DataLoader, tokenizer: Any) -> list
return texts


def stratify_dataset(dataset: Dataset, sample_size: int, seed: int = 42) -> Dataset:
def stratify_dataset(dataset: Dataset, sample_size: int, seed: int = 42, partition_strategy: str = "random", partition_index: int = 0) -> Dataset:
"""
Stratify the dataset into a specific size.

@@ -196,6 +196,10 @@ def stratify_dataset(dataset: Dataset, sample_size: int, seed: int = 42) -> Data
The size to stratify.
seed : int
The seed to use for sampling the dataset.
partition_strategy : str
The strategy to use for partitioning the dataset. Can be "indexed" or "random".
partition_index : int
The index to use for partitioning the dataset.

Returns
-------
@@ -211,8 +215,13 @@ def stratify_dataset(dataset: Dataset, sample_size: int, seed: int = 42) -> Data
return dataset

indices = list(range(dataset_length))
random.Random(seed).shuffle(indices)
selected_indices = indices[:sample_size]
if partition_strategy == "indexed":
selected_indices = indices[sample_size*partition_index:sample_size*(partition_index+1)]
elif partition_strategy == "random":
random.Random(seed).shuffle(indices)
selected_indices = indices[:sample_size]
else:
raise ValueError(f"Invalid partition strategy: {partition_strategy}")
dataset = dataset.select(selected_indices)
return dataset
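
A quick sketch of the new partitioning behavior (the toy dataset and sizes below are illustrative, not taken from the PR):

```python
from datasets import Dataset

from pruna.data.utils import stratify_dataset

# Toy dataset with 10 rows (illustrative only).
dataset = Dataset.from_dict({"prompt": [f"prompt-{i}" for i in range(10)]})

# "indexed" partitioning: disjoint, contiguous slices of size sample_size,
# so successive partition_index values walk through the dataset in order.
part0 = stratify_dataset(dataset, 4, partition_strategy="indexed", partition_index=0)
part1 = stratify_dataset(dataset, 4, partition_strategy="indexed", partition_index=1)
print(part0["prompt"])  # rows 0-3
print(part1["prompt"])  # rows 4-7

# "random" partitioning (the default): seeded shuffle, then the first sample_size rows.
rand = stratify_dataset(dataset, 4, seed=42)
```

Note that with "indexed" the seed is effectively unused, since the indices are sliced in their natural order rather than shuffled.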

56 changes: 40 additions & 16 deletions src/pruna/engine/handler/handler_diffuser.py
@@ -15,10 +15,9 @@
from __future__ import annotations

import inspect
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Literal, Optional, Tuple

import torch
from torchvision import transforms

from pruna.engine.handler.handler_inference import InferenceHandler
from pruna.logging.logger import pruna_logger
@@ -28,10 +27,6 @@ class DiffuserHandler(InferenceHandler):
"""
Handle inference arguments, inputs and outputs for diffusers models.

A generator with a fixed seed (42) is passed as an argument to the model for reproducibility.
The first element of the batch is passed as input to the model.
The generated outputs are expected to have .images attribute.

Parameters
----------
call_signature : inspect.Signature
@@ -40,12 +35,18 @@ class DiffuserHandler(InferenceHandler):
The arguments to pass to the model.
"""

def __init__(self, call_signature: inspect.Signature, model_args: Optional[Dict[str, Any]] = None) -> None:
default_args = {"generator": torch.Generator("cpu").manual_seed(42)}
def __init__(
self,
call_signature: inspect.Signature,
model_args: Optional[Dict[str, Any]] = None,
seed_strategy: Literal["per_sample", "no_seed"] = "no_seed",
global_seed: int | None = None,
) -> None:
self.call_signature = call_signature
if model_args:
default_args.update(model_args)
self.model_args = default_args
self.model_args = model_args if model_args else {}
# We want the default output type to be pytorch tensors.
self.model_args["output_type"] = "pt"
self.configure_seed(seed_strategy, global_seed)

def prepare_inputs(
self, batch: List[str] | torch.Tensor | Tuple[List[str] | torch.Tensor | dict[str, Any], ...] | dict[str, Any]
@@ -83,13 +84,36 @@ def process_output(self, output: Any) -> torch.Tensor:
torch.Tensor
The processed images.
"""
generated = output.images
return torch.stack([transforms.PILToTensor()(g) for g in generated])
if hasattr(output, "images"):
generated = output.images
# For video models.
elif hasattr(output, "frames"):
generated = output.frames
else:
# The pipeline may have been called with return_dict=False,
# which returns the generated image / video in a tuple.
generated = output[0]
return generated.float()

def log_model_info(self) -> None:
"""Log information about the inference handler."""
pruna_logger.info(
"Detected diffusers model. Using DiffuserHandler with fixed seed.\n"
"- The first element of the batch is passed as input.\n"
"- The generated outputs are expected to have .images attribute."
"Detected diffusers model. Using DiffuserHandler.\n- The first element of the batch is passed as input.\n"
"Inference outputs are expected to have either have an `images` attribute or a `frames` attribute."
"Or be a tuple with the generated image / video as the first element."
)

def set_seed(self, seed: int) -> None:
"""
Set the random seed for the current process.

Parameters
----------
seed : int
The seed to set.
"""
self.model_args["generator"] = torch.Generator("cpu").manual_seed(seed)

def remove_seed(self) -> None:
"""Remove the seed from the current process."""
self.model_args["generator"] = None
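
A minimal sketch of how the reworked handler might be constructed with the new seeding arguments (the stand-in pipeline function below is hypothetical; a real call would pass a diffusers pipeline's `__call__` signature):

```python
import inspect

import torch

from pruna.engine.handler.handler_diffuser import DiffuserHandler


# Hypothetical stand-in for a diffusers pipeline's __call__.
def fake_pipeline(prompt, generator=None, output_type="pt"):
    ...


handler = DiffuserHandler(
    call_signature=inspect.signature(fake_pipeline),
    seed_strategy="per_sample",
    global_seed=42,
)

# configure_seed() validated the strategy and seeded a CPU generator,
# and the handler now defaults outputs to PyTorch tensors.
assert isinstance(handler.model_args["generator"], torch.Generator)
assert handler.model_args["output_type"] == "pt"
```

Requesting `output_type="pt"` by default is what lets `process_output` return tensors directly via `generated.float()` instead of converting PIL images, which is why the `torchvision` import could be dropped.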
76 changes: 75 additions & 1 deletion src/pruna/engine/handler/handler_inference.py
@@ -14,9 +14,11 @@

from __future__ import annotations

import random
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Literal, Tuple

import numpy as np
import torch

from pruna.data.utils import move_batch_to_device
@@ -98,3 +100,75 @@ def move_inputs_to_device(
return move_batch_to_device(inputs, device)
except torch.cuda.OutOfMemoryError as e:
raise e

def configure_seed(self, seed_strategy: Literal["per_sample", "no_seed"], global_seed: int | None) -> None:
"""
Set the random seed according to the chosen strategy.

- If `seed_strategy="per_sample"`, the `global_seed` is used as a base to derive a different seed for each
sample. This ensures reproducibility while still producing variation across samples,
making it the preferred option for benchmarking.
- If `seed_strategy="no_seed"`, no seed is set internally.
The user is responsible for managing seeds if reproducibility is required.

Parameters
----------
seed_strategy : Literal["per_sample", "no_seed"]
The seeding strategy to apply.
global_seed : int | None
The base seed value to use (if applicable).
"""
self.seed_strategy = seed_strategy
validate_seed_strategy(seed_strategy, global_seed)
if global_seed is not None:
self.global_seed = global_seed
self.set_seed(global_seed)
else:
self.remove_seed()

def set_seed(self, seed: int) -> None:
"""
Set the random seed for the current process.

Parameters
----------
seed : int
The seed to set.
"""
# With the default handler, we can't assume anything about the model,
# so we are setting the seed for all RNGs available.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)

def remove_seed(self) -> None:
"""Remove the seed from the current process."""
random.seed(None)
np.random.seed(None)
# We can't really remove the seed from the PyTorch RNG, so we are reseeding with torch.seed().
# torch.seed() creates a non-deterministic random number.
torch.manual_seed(torch.seed())
if torch.cuda.is_available():
torch.cuda.manual_seed_all(torch.seed())


def validate_seed_strategy(seed_strategy: Literal["per_sample", "no_seed"], global_seed: int | None) -> None:
"""
Check the consistency of the seed strategy and the global seed.

If the seed strategy is "no_seed", the global seed must be None.
If the seed strategy is "per_sample", the user must provide a global seed.

Parameters
----------
seed_strategy : Literal["per_sample", "no_seed"]
The seeding strategy to apply.
global_seed : int | None
The base seed value to use (if applicable).
"""
if seed_strategy != "no_seed" and global_seed is None:
raise ValueError("Global seed must be provided if seed strategy is not 'no_seed'.")
elif global_seed is not None and seed_strategy == "no_seed":
raise ValueError("Seed strategy cannot be 'no_seed' if global seed is provided.")
4 changes: 3 additions & 1 deletion src/pruna/engine/pruna_model.py
@@ -24,7 +24,7 @@

from pruna.config.smash_config import SmashConfig
from pruna.engine.handler.handler_utils import register_inference_handler
from pruna.engine.load import load_pruna_model, load_pruna_model_from_pretrained
from pruna.engine.load import filter_load_kwargs, load_pruna_model, load_pruna_model_from_pretrained
from pruna.engine.save import save_pruna_model, save_pruna_model_to_hub
from pruna.engine.utils import get_device, get_nn_modules, set_to_eval
from pruna.logging.filter import apply_warning_filter
@@ -108,6 +108,8 @@ def run_inference(self, batch: Any) -> Any:
)
inference_function = getattr(self, inference_function_name)

self.inference_handler.model_args = filter_load_kwargs(self.model.__call__, self.inference_handler.model_args)

if prepared_inputs is None:
outputs = inference_function(**self.inference_handler.model_args)
elif isinstance(prepared_inputs, dict):
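
The new `filter_load_kwargs(self.model.__call__, ...)` call presumably drops handler `model_args` that the wrapped model's `__call__` does not accept, e.g. a `generator` argument passed to a pipeline without such a parameter. A minimal sketch of that filtering idea (a hypothetical illustration, not the actual implementation in `pruna.engine.load`):

```python
import inspect
from typing import Any, Callable, Dict


def filter_kwargs_for(fn: Callable[..., Any], kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the kwargs present in fn's signature (hypothetical helper)."""
    params = inspect.signature(fn).parameters
    # If fn accepts **kwargs, everything passes through unchanged.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)
    return {k: v for k, v in kwargs.items() if k in params}
```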