feat: img-saver-extended #490
base: main
Conversation
… in the metric implementations.
- feat: remove algorithm groups from algorithms folder
- feat: simplify new algorithm registration to smash space
- refactor: add new smash config interface
- refactor: remove unused tokenizer name function
- refactor: adjust order implementation
- feat: add new graph-based path finding for algorithm execution order
- tests: add first version of pre-smash-routines tests
- tests: narrow down pre-smash routine tests
- refactor: rename PRUNA_ALGORITHMS
- refactor: enhance algorithm tags
- refactor: remove `incompatible` specification
- feat: add `smash_config` utility
- style: initial fix all linting complaints
- tests: adjust test structure to new refactoring
- style: address PR comments
- fix: conditionally register algorithms
- fix: adjust smash config access in algorithms
- fix: support older smash configs
- fix: handle target module exception
- fix: deprecated save/load imports
- tests: update to fit recent interface changes
- fix: add `global_utils` exception to algorithm registry
- fix: extending compatible methods
- fix: deprecate old hyperparameter interface properly
- tests: add symmetry checks for algorithm order
- style: address PR comments
- feat: add utility to register custom algorithm
- fix: insufficient docstring descriptions
- fix: test references to HQQ
- style: fix remaining linting errors
- style: fix typing error w.r.t. compatibility setter
- style: import sorting
- fix: return type of registry function
- fix: model context docstring
- fix: some final bugs
- fix: duplicate pyproject.toml key
- fix: address cursorbot slander
- style: move inline comments
- fix: unify registry logic
- feat: additional check in algorithm order overwrite
- fix: documentation wording
- fix: device function patching in tests
… in prime intellect
- update pre-commit
- rm redundant filters
- fix nits and whitespacing issues
- Update versions
…nclude batch_idx in metadata
… have one branch for optimization agent and image artifact saver
```python
import json
import tempfile
from pathlib import Path
from traceback import print_tb
```
```python
# Usually, the data is already a PIL.Image, so we don't need to convert it.
if isinstance(data, torch.Tensor):
    data = np.transpose(data.cpu().numpy(), (1, 2, 0))
    data = np.clip(data * 255, 0, 255).astype(np.uint8)
```
uint8 tensors corrupted by unconditional scaling
High Severity
When saving tensor data, the code unconditionally multiplies values by 255, assuming the input is in [0, 1] float range. If a uint8 tensor (values 0-255) is passed, this corrupts the data — for example, a pixel value of 2 becomes 2 * 255 = 510, which clips to 255. Effectively, all non-zero values become 255, turning the image into a near-binary output. The tests only verify file existence, not content correctness.
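A minimal sketch of the dtype-aware conversion the review is asking for, written against a numpy array for brevity (in the actual code path a `torch.Tensor` would first go through `data.cpu().numpy()`); the function name `to_uint8_image` is hypothetical:

```python
import numpy as np


def to_uint8_image(array: np.ndarray) -> np.ndarray:
    """Scale only float inputs (assumed to be in [0, 1]) by 255.

    Integer inputs are assumed to already be in [0, 255] and are
    passed through unchanged, avoiding the corruption described above.
    """
    if np.issubdtype(array.dtype, np.floating):
        array = np.clip(array * 255, 0, 255)
    return array.astype(np.uint8)
```

With this guard, a uint8 pixel value of 2 stays 2 instead of clipping to 255, while a float value of 0.5 is scaled to 127 as before.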
```python
# https://github.com/Vchitect/VBench/blob/dc62783c0fb4fd333249c0b669027fe102696682/evaluate.py#L111
# explicitly sets the device to cuda. We respect this here.
runs_on: List[str] = ["cuda"]
modality: List[str] = ["video"]
```
VBench metrics use wrong type for modality
Medium Severity
The VBench metrics declare modality: List[str] = ["video"] but the base class StatefulMetric declares modality: set[str]. All other metrics in the codebase (CMMD, PairwiseClipScore, SharpnessMetric) use sets like modality = {IMAGE}. While set.intersection() in validate_and_get_task_modality happens to work with lists, this type inconsistency violates the interface contract and the constant VIDEO from utils.py is not being used.
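A sketch of the fix, using hypothetical stand-in names (`VideoMetric` for the VBench metrics, a local `VIDEO` constant mirroring the one the review says exists in `utils.py`):

```python
VIDEO = "video"  # stand-in for the VIDEO constant from utils.py


class VideoMetric:
    # A set, not a list, matches the base class contract `modality: set[str]`.
    modality: set[str] = {VIDEO}


def validate_and_get_task_modality(task_modalities: set[str], metric) -> set[str]:
    # set.intersection happens to accept lists too, which is why the list
    # declaration slipped through, but a set keeps the annotation honest.
    return task_modalities.intersection(metric.modality)
```

This also makes the membership semantics explicit: a modality is an unordered collection of tags, not a sequence.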
```python
# Test 3D tensor (should fail)
invalid_tensor = torch.randn(2, 3, 16)
with pytest.raises(ValueError, match="4 or 5 dimensional"):
    metric.validate_batch(invalid_tensor)
```
Tests call non-existent validate_batch method
Medium Severity
The test test_vbench_metrics_invalid_tensor_dimensions calls metric.validate_batch(invalid_tensor), but neither VBenchBackgroundConsistency nor VBenchDynamicDegree implements a validate_batch method. The test expects a ValueError with message "4 or 5 dimensional" but will instead raise AttributeError. The test comments also reference validate_batch for 4D tensor conversion, indicating this method was expected but never implemented.
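One way the missing method could look, as a standalone sketch (numpy is used here for brevity; the real metric operates on torch tensors, where `np.expand_dims` would be `unsqueeze(0)`):

```python
import numpy as np


def validate_batch(batch: np.ndarray) -> np.ndarray:
    """Accept 4D (F, C, H, W) or 5D (B, F, C, H, W) video input.

    Raises ValueError for any other rank, matching the message the
    test expects, and promotes 4D input to a batch of one.
    """
    if batch.ndim not in (4, 5):
        raise ValueError(f"Input must be a 4 or 5 dimensional tensor, got {batch.ndim}D")
    if batch.ndim == 4:
        batch = np.expand_dims(batch, 0)
    return batch
```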
```python
# So we need to convert the arguments to an EasyDict.
args_new = EasyDict({"model": model_path, "small": False, "mixed_precision": False, "alternate_corr": False})
self.DynamicDegree = PrunaDynamicDegree(args_new, device)
self.add_state("scores", [])
```
interval parameter silently ignored in VBenchDynamicDegree
Medium Severity
The VBenchDynamicDegree.__init__ accepts an interval parameter via **kwargs but never uses it. The parameter is absorbed into kwargs and passed to the parent __init__ but is not used when creating PrunaDynamicDegree. Tests explicitly pass different interval values (1, 3, 4, 5, 10) expecting different sampling behavior, but all values produce identical results because the parameter has no effect.
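To illustrate the intended behavior, a hypothetical frame-sampling helper showing what `interval` is presumably meant to control (the actual fix would forward `interval` into `PrunaDynamicDegree` rather than letting `**kwargs` absorb it):

```python
def sample_frames(frames: list, interval: int = 1) -> list:
    """Keep every `interval`-th frame before scoring.

    Hypothetical helper: with interval=1 all frames are scored; larger
    values subsample, which is why the tests pass different intervals
    expecting different results.
    """
    return frames[::interval]
```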
```python
    MetricResult
        The final score.
    """
    score = self.similarity_scores / self.n_samples
```
Division by zero and empty mean return nan
Medium Severity
Both VBench metrics lack validation for empty state before computing results. VBenchBackgroundConsistency.compute() performs self.similarity_scores / self.n_samples where n_samples is initialized to 0, causing division by zero and returning nan. Similarly, VBenchDynamicDegree.compute() calls np.mean(self.scores) where scores is initialized to an empty list, also returning nan. The test test_vbench_metrics_compute_without_updates expects 0.0 in both cases, but will fail due to these missing zero-checks.
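Minimal guard sketches for both cases, assuming (per the test) that an empty state should yield `0.0`; the function names are illustrative, not from the codebase:

```python
import numpy as np


def safe_ratio_score(similarity_scores: float, n_samples: int) -> float:
    """Guard for the VBenchBackgroundConsistency-style running ratio."""
    if n_samples == 0:
        return 0.0  # no updates yet; avoid 0/0 -> nan
    return similarity_scores / n_samples


def safe_list_mean(scores: list) -> float:
    """Guard for the VBenchDynamicDegree-style accumulated score list."""
    return float(np.mean(scores)) if scores else 0.0
```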
Description
Added an image artifact saver and corresponding tests. Also added a small method for the optimization agent in evaluation_agent.py which creates a JSON file mapping the model's input prompts to the generated output images. This allows logging each generated image with its corresponding prompt as the file name.
Related Issue
/
Type of Change
How Has This Been Tested?
Image artifact saver: tests implemented.
JSON-file creation for the optimization agent: checked the created JSON file manually.
Checklist
Additional Notes
This PR is a combination of the branches "algo-sweeper" and "image-artifactsaver". Originally, I started this branch based on Begüm's branch, where the video artifact saver was implemented. It seems there were a few bugs in that branch which are now also present in this one.