From 50cdef31a5e139f005ea820046d0e72452355159 Mon Sep 17 00:00:00 2001 From: Julius Parulek Date: Wed, 20 Aug 2025 22:56:03 +0200 Subject: [PATCH 1/2] Refactor GenKW to represent a single scalar parameter GenKW now models one scalar parameter instead of a group. This enables per-parameter control of `input_source` (e.g. design_matrix or sampled) and allows the `update` flag to be set individually. Key changes: - `genkw.name` now refers to the parameter name. - A new `group_name` property exposes the group (for GenKW it returns the new `group` field); for fields and surfaces it currently returns `name`. - All GenKW parameters are stored in `ens/SCALARS.parquet`: - rows = realizations - columns = parameters - mandatory `realization` column. - Introduced `ParameterCardinality` to clarify how parameters map to storage: - `one_param_config_per_ensemble_dataset`: one config per param, one dataset per ensemble. - `multiple_configs_per_ensemble_dataset`: one config per group, one dataset per ensemble. - `one_param_config_per_realization_dataset`: one config per group, one dataset per realization. - Removed `transform_function_definitions`; GenKW now only supports a single distribution. - Simplified merging of GenKW with design_matrix: overlapping parameters default to using `design_matrix` as `input_source`, with `update` turned off and a raw distribution. Includes storage migration to version 13. 
--- src/ert/config/__init__.py | 7 +- src/ert/config/design_matrix.py | 110 ++---- src/ert/config/distribution.py | 14 + src/ert/config/ensemble_config.py | 15 +- src/ert/config/ert_config.py | 52 +-- src/ert/config/gen_kw_config.py | 250 +++++------- src/ert/config/parameter_config.py | 60 ++- src/ert/gui/ertwidgets/models/ertsummary.py | 10 +- .../manage_experiments_panel.py | 23 +- .../manage_experiments/storage_widget.py | 11 +- src/ert/run_models/_create_run_path.py | 3 +- .../run_models/initial_ensemble_run_model.py | 22 +- src/ert/run_models/run_model.py | 24 +- src/ert/run_models/update_run_model.py | 11 +- src/ert/sample_prior.py | 47 ++- src/ert/shared/storage/extraction.py | 19 +- src/ert/storage/local_ensemble.py | 228 ++++++----- src/ert/storage/local_experiment.py | 39 +- src/ert/storage/local_storage.py | 18 +- src/ert/storage/migration/to13.py | 95 +++++ tests/ert/performance_tests/test_analysis.py | 4 +- .../test_obs_and_responses_performance.py | 53 ++- .../cli/analysis/test_design_matrix.py | 22 +- .../ui_tests/cli/analysis/test_es_update.py | 10 +- .../ert/ui_tests/cli/test_field_parameter.py | 10 +- .../gui/test_manage_experiments_tool.py | 31 +- .../ert/unit_tests/analysis/test_es_update.py | 119 +++--- .../unit_tests/config/test_ensemble_config.py | 2 +- .../ert/unit_tests/config/test_ert_config.py | 28 -- .../unit_tests/config/test_gen_kw_config.py | 373 +++++++----------- .../unit_tests/config/test_surface_config.py | 2 +- .../dark_storage/test_http_endpoints.py | 14 +- .../gui/ertwidgets/models/test_ertsummary.py | 36 +- .../gui/ertwidgets/test_ensembleselector.py | 11 +- .../gui/tools/plot/test_plot_api.py | 43 +- .../heat_equationconfig.ert/config.json | 38 +- .../poly_examplepoly.ert/poly.json | 61 +-- .../snake_oilsnake_oil.ert/snake_oil.json | 208 ++++++---- .../heat_equationconfig.ert/config.json | 38 +- .../poly_examplepoly.ert/poly.json | 61 +-- .../snake_oilsnake_oil.ert/snake_oil.json | 208 ++++++---- 
.../heat_equationconfig.ert/config.json | 38 +- .../poly_examplepoly.ert/poly.json | 61 +-- .../snake_oilsnake_oil.ert/snake_oil.json | 208 ++++++---- .../heat_equationconfig.ert/config.json | 38 +- .../poly_examplepoly.ert/poly.json | 61 +-- .../snake_oilsnake_oil.ert/snake_oil.json | 208 ++++++---- .../run_models/test_base_run_model.py | 21 +- .../test_experiment_serialization.py | 17 +- .../test_design_matrix.py | 106 ++--- .../unit_tests/storage/migration/test_to13.py | 48 +++ .../storage/migration/test_version_1.py | 28 -- .../storage/migration/test_version_2.py | 55 --- .../storage/migration/test_version_3.py | 36 -- .../14.2/design_matrix_snapshot.json | 40 +- .../test_that_storage_matches/parameters | 17 +- .../unit_tests/storage/test_local_storage.py | 73 ++-- .../storage/test_parameter_sample_types.py | 50 +-- .../storage/test_storage_migration.py | 55 +-- .../ert/unit_tests/test_run_path_creation.py | 23 +- 60 files changed, 1790 insertions(+), 1823 deletions(-) create mode 100644 src/ert/storage/migration/to13.py create mode 100644 tests/ert/unit_tests/storage/migration/test_to13.py delete mode 100644 tests/ert/unit_tests/storage/migration/test_version_1.py delete mode 100644 tests/ert/unit_tests/storage/migration/test_version_2.py delete mode 100644 tests/ert/unit_tests/storage/migration/test_version_3.py diff --git a/src/ert/config/__init__.py b/src/ert/config/__init__.py index fdca7267e4f..ae421ae9179 100644 --- a/src/ert/config/__init__.py +++ b/src/ert/config/__init__.py @@ -25,10 +25,10 @@ ForwardModelStepWarning, ) from .gen_data_config import GenDataConfig -from .gen_kw_config import GenKwConfig, PriorDict, TransformFunction +from .gen_kw_config import DataSource, GenKwConfig, PriorDict from .lint_file import lint_file from .model_config import ModelConfig -from .parameter_config import ParameterConfig, ParameterMetadata +from .parameter_config import ParameterCardinality, ParameterConfig, ParameterMetadata from .parsing import ( 
ConfigValidationError, ConfigWarning, @@ -74,6 +74,7 @@ "AnalysisModule", "ConfigValidationError", "ConfigWarning", + "DataSource", "DesignMatrix", "ESSettings", "EnsembleConfig", @@ -109,6 +110,7 @@ "ObservationSettings", "ObservationType", "OutlierSettings", + "ParameterCardinality", "ParameterConfig", "ParameterMetadata", "PostExperimentFixtures", @@ -125,7 +127,6 @@ "ResponseMetadata", "SummaryConfig", "SurfaceConfig", - "TransformFunction", "WarningInfo", "Workflow", "WorkflowConfigs", diff --git a/src/ert/config/design_matrix.py b/src/ert/config/design_matrix.py index cbd292e257f..ee746c8d0c6 100644 --- a/src/ert/config/design_matrix.py +++ b/src/ert/config/design_matrix.py @@ -1,7 +1,6 @@ from __future__ import annotations from collections import Counter -from collections.abc import Iterable from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, cast @@ -10,12 +9,12 @@ import polars as pl from polars.exceptions import InvalidOperationError -from .gen_kw_config import GenKwConfig, TransformFunctionDefinition +from .distribution import RawSettings +from .gen_kw_config import DataSource, GenKwConfig from .parsing import ConfigValidationError, ErrorInfo if TYPE_CHECKING: from ert.config import ParameterConfig - from ert.storage import Ensemble DESIGN_MATRIX_GROUP = "DESIGN_MATRIX" @@ -32,7 +31,7 @@ def __post_init__(self) -> None: ( self.active_realizations, self.design_matrix_df, - self.parameter_configuration, + self.parameter_configurations, ) = self.read_and_validate_design_matrix() except (ValueError, AttributeError) as exc: raise ConfigValidationError.with_context( @@ -42,26 +41,6 @@ def __post_init__(self) -> None: str(self.xls_filename), ) from exc - def save_to_ensemble( - self, - ensemble: Ensemble, - active_realizations: Iterable[int], - design_group_name: str = DESIGN_MATRIX_GROUP, - ) -> None: - design_matrix_df = self.design_matrix_df - assert not design_matrix_df.is_empty() - if not set(active_realizations) <= 
set( - design_matrix_df["realization"].to_list() - ): - raise KeyError("Active realization mask is not in design matrix!") - ensemble.save_parameters( - design_group_name, - realization=None, - dataset=design_matrix_df.filter( - pl.col("realization").is_in(list(active_realizations)) - ), - ) - @classmethod def from_config_list(cls, config_list: list[str | dict[str, str]]) -> DesignMatrix: filename = Path(cast(str, config_list[0])) @@ -149,66 +128,44 @@ def merge_with_other(self, dm_other: DesignMatrix) -> None: f"{dm_other.default_sheet or ''})': {exc}!" ) from exc - for tfd in dm_other.parameter_configuration.transform_function_definitions: - if tfd.name not in common_keys: - self.parameter_configuration.transform_function_definitions.append(tfd) + self.parameter_configurations.extend( + cfg + for cfg in dm_other.parameter_configurations + if cfg.name not in common_keys + ) def merge_with_existing_parameters( self, existing_parameters: list[ParameterConfig] - ) -> tuple[list[ParameterConfig], GenKwConfig]: + ) -> list[ParameterConfig]: """ This method merges the design matrix parameters with the existing parameters and - returns the new list of existing parameters, wherein we drop GEN_KW group having - a full overlap with the design matrix group. GEN_KW group that was dropped will - acquire a new name from the design matrix group. Additionally, the - ParameterConfig which is the design matrix group is returned separately. + returns the new list of existing parameters. 
Args: existing_parameters (List[ParameterConfig]): List of existing parameters - Raises: - ConfigValidationError: If there is a partial overlap between the design - matrix group and any existing GEN_KW group - Returns: - tuple[List[ParameterConfig], ParameterConfig]: List of existing parameters - and the dedicated design matrix group + List[ParameterConfig]: List of new parameters after merge """ - new_param_config: list[ParameterConfig] = [] + new_param_configs: list[ParameterConfig] = [] - design_parameter_group = self.parameter_configuration - design_keys = [e.name for e in design_parameter_group.transform_functions] + design_cfgs = {cfg.name: cfg for cfg in self.parameter_configurations} - design_group_added = False - for parameter_group in existing_parameters: - if not isinstance(parameter_group, GenKwConfig): - new_param_config += [parameter_group] - continue - existing_keys = [e.name for e in parameter_group.transform_functions] - if set(existing_keys) == set(design_keys): - if design_group_added: - raise ConfigValidationError( - "Multiple overlapping groups with design matrix found in " - "existing parameters!\n" - f"{design_parameter_group.name} and {parameter_group.name}" - ) - - design_parameter_group.name = parameter_group.name - design_group_added = True - elif set(design_keys) & set(existing_keys): - raise ConfigValidationError( - "Overlapping parameter names found in design matrix!\n" - f"{DESIGN_MATRIX_GROUP}:{design_keys}\n{parameter_group.name}:{existing_keys}" - "\nThey need to match exactly or not at all." 
- ) - else: - new_param_config += [parameter_group] - return new_param_config, design_parameter_group + for param_cfg in existing_parameters: + if isinstance(param_cfg, GenKwConfig) and param_cfg.name in design_cfgs: + param_cfg.input_source = DataSource.DESIGN_MATRIX + param_cfg.update = False + param_cfg.distribution = RawSettings() + del design_cfgs[param_cfg.name] + new_param_configs += [param_cfg] + if design_cfgs.values(): + new_param_configs += list(design_cfgs.values()) + return new_param_configs def read_and_validate_design_matrix( self, - ) -> tuple[list[bool], pl.DataFrame, GenKwConfig]: + ) -> tuple[list[bool], pl.DataFrame, list[GenKwConfig]]: # Read the parameter names (first row) as strings to prevent polars from # modifying them. This ensures that duplicate or empty column names are # preserved exactly as they appear in the Excel sheet. By doing this, we @@ -308,23 +265,24 @@ def read_and_validate_design_matrix( design_matrix_df = design_matrix_df.with_row_index(name="realization") design_matrix_df = convert_numeric_string_columns(design_matrix_df) - transform_function_definitions = [ - TransformFunctionDefinition(name=col, param_name="RAW", values=[]) + + parameter_configurations: list[GenKwConfig] = [ + GenKwConfig( + name=col, + update=False, + group=DESIGN_MATRIX_GROUP, + input_source=DataSource.DESIGN_MATRIX, + distribution={"name": "raw"}, + ) for col in design_matrix_df.columns if col != "realization" ] - parameter_configuration = GenKwConfig( - name=DESIGN_MATRIX_GROUP, - forward_init=False, - transform_function_definitions=transform_function_definitions, - update=False, - ) reals = design_matrix_df.get_column("realization").to_list() return ( [x in reals for x in range(max(reals) + 1)], design_matrix_df, - parameter_configuration, + parameter_configurations, ) @staticmethod diff --git a/src/ert/config/distribution.py b/src/ert/config/distribution.py index 78db0e17144..95d38b2fd17 100644 --- a/src/ert/config/distribution.py +++ 
b/src/ert/config/distribution.py @@ -315,6 +315,20 @@ def transform(self, x: float) -> float: return float(result) +DistributionSettings = ( + UnifSettings + | LogNormalSettings + | LogUnifSettings + | DUnifSettings + | RawSettings + | ConstSettings + | NormalSettings + | TruncNormalSettings + | ErrfSettings + | DerrfSettings + | TriangularSettings +) + DISTRIBUTION_CLASSES: dict[str, type[TransSettingsValidation]] = { "NORMAL": NormalSettings, "LOGNORMAL": LogNormalSettings, diff --git a/src/ert/config/ensemble_config.py b/src/ert/config/ensemble_config.py index feb8891d0e9..99276a0f523 100644 --- a/src/ert/config/ensemble_config.py +++ b/src/ert/config/ensemble_config.py @@ -48,9 +48,6 @@ def set_derived_fields(self) -> Self: [p.name for p in self.parameter_configs.values()], [key for config in self.response_configs.values() for key in config.keys], ) - self._check_for_duplicate_gen_kw_param_names( - [p for p in self.parameter_configs.values() if isinstance(p, GenKwConfig)] - ) return self @@ -69,9 +66,7 @@ def _check_for_duplicate_names( @staticmethod def _check_for_duplicate_gen_kw_param_names(gen_kw_list: list[GenKwConfig]) -> None: - gen_kw_param_count = Counter( - keyword.name for p in gen_kw_list for keyword in p.transform_functions - ) + gen_kw_param_count = Counter(p.name for p in gen_kw_list) duplicate_gen_kw_names = [ (n, c) for n, c in gen_kw_param_count.items() if c > 1 ] @@ -156,12 +151,16 @@ def make_field(field_list: list[str | dict[str, str]]) -> FieldConfig: return FieldConfig.from_config_list(grid_file_path, dims, field_list) + gen_kw_cfgs = [ + cfg for g in gen_kw_list for cfg in GenKwConfig.from_config_list(g) + ] + parameter_configs = ( - [GenKwConfig.from_config_list(g) for g in gen_kw_list] + gen_kw_cfgs + [SurfaceConfig.from_config_list(s) for s in surface_list] + [make_field(f) for f in field_list] ) - + EnsembleConfig._check_for_duplicate_gen_kw_param_names(gen_kw_cfgs) response_configs: list[KnownResponseTypes] = [] for config_cls in 
_KNOWN_RESPONSE_TYPES: diff --git a/src/ert/config/ert_config.py b/src/ert/config/ert_config.py index c94fd3cdddb..0f0371c0345 100644 --- a/src/ert/config/ert_config.py +++ b/src/ert/config/ert_config.py @@ -30,7 +30,6 @@ make_observation_declarations, ) from .analysis_config import AnalysisConfig -from .design_matrix import DESIGN_MATRIX_GROUP from .ensemble_config import EnsembleConfig from .forward_model_step import ( ForwardModelJSON, @@ -759,10 +758,10 @@ def validate_genkw_parameter_name_overlap(self) -> Self: def validate_dm_parameter_name_overlap(self) -> Self: if not self.analysis_config.design_matrix: return self - dm_param_config = self.analysis_config.design_matrix.parameter_configuration + dm_param_configs = self.analysis_config.design_matrix.parameter_configurations overlapping_parameter_names = [ parameter_definition.name - for parameter_definition in dm_param_config.transform_function_definitions + for parameter_definition in dm_param_configs if f"<{parameter_definition.name}>" in self.substitutions or parameter_definition.name in ErtConfig.RESERVED_KEYWORDS ] @@ -998,39 +997,20 @@ def from_dict(cls, config_dict: ConfigDict) -> Self: raise ConfigValidationError.from_collected(errors) if dm := analysis_config.design_matrix: - dm_errors = [] - dm_params = { - x.name - for x in dm.parameter_configuration.transform_function_definitions - } - for group_name, config in ensemble_config.parameter_configs.items(): - if not isinstance(config, GenKwConfig): - continue - group_params = {x.name for x in config.transform_function_definitions} - if group_name == DESIGN_MATRIX_GROUP: - dm_errors.append( - ConfigValidationError( - f"Cannot have GEN_KW with group name {DESIGN_MATRIX_GROUP} " - "when using DESIGN_MATRIX keyword." - ) - ) - if dm_params == group_params: - ConfigWarning.warn( - f"Parameters {group_params} from GEN_KW group '{group_name}' " - "will be overridden by design matrix. This will cause " - "updates to be turned off for these parameters." 
- ) - elif intersection := dm_params & group_params: - dm_errors.append( - ConfigValidationError( - "Only full overlaps of design matrix and " - "one genkw group are supported.\n" - f"design matrix parameters: {dm_params}\n" - f"parameters in genkw group <{group_name}>: " - f"{group_params}\n" - f"overlap between them: {intersection}" - ) - ) + dm_errors: list[ErrorInfo | ConfigValidationError] = [] + dm_params = {x.name for x in dm.parameter_configurations} + overwrite_params = [ + cfg.name + for cfg in ensemble_config.parameter_configs.values() + if isinstance(cfg, GenKwConfig) and cfg.name in dm_params + ] + if overwrite_params: + ConfigWarning.warn( + f"Parameters {dm_params} " + "will be overridden by design matrix. This will cause " + "updates to be turned off for these parameters." + ) + if dm_errors: raise ConfigValidationError.from_collected(dm_errors) diff --git a/src/ert/config/gen_kw_config.py b/src/ert/config/gen_kw_config.py index 3588ea62a80..55ea9a1e7f2 100644 --- a/src/ert/config/gen_kw_config.py +++ b/src/ert/config/gen_kw_config.py @@ -2,15 +2,15 @@ import os from collections.abc import Callable, Iterator -from dataclasses import dataclass +from enum import StrEnum from pathlib import Path -from typing import TYPE_CHECKING, Annotated, Any, Literal, Self, cast, overload +from typing import TYPE_CHECKING, Annotated, Literal, Self, cast, overload import networkx as nx import numpy as np import polars as pl import xarray as xr -from pydantic import BaseModel, Field, PrivateAttr, ValidationError, model_validator +from pydantic import Field, ValidationError from typing_extensions import TypedDict from ._str_to_bool import str_to_bool @@ -29,8 +29,8 @@ UnifSettings, get_distribution, ) -from .parameter_config import ParameterConfig, ParameterMetadata -from .parsing import ConfigValidationError, ConfigWarning, ErrorInfo +from .parameter_config import ParameterCardinality, ParameterConfig, ParameterMetadata +from .parsing import ConfigValidationError, 
ConfigWarning if TYPE_CHECKING: import numpy.typing as npt @@ -60,15 +60,13 @@ def _get_abs_path(file: str | None) -> str | None: return file -class TransformFunctionDefinition(BaseModel): - name: str - param_name: str - values: list[Any] +class DataSource(StrEnum): + DESIGN_MATRIX = "design_matrix" + SAMPLED = "sampled" -@dataclass -class TransformFunction: - name: str +class GenKwConfig(ParameterConfig): + type: Literal["gen_kw"] = "gen_kw" distribution: Annotated[ UnifSettings | LogNormalSettings @@ -83,79 +81,34 @@ class TransformFunction: | TriangularSettings, Field(discriminator="name"), ] - - @property - def parameter_list(self) -> dict[str, float]: - """Return the parameters of the distribution as a dictionary.""" - return self.distribution.model_dump(exclude={"name"}) - - -class GenKwConfig(ParameterConfig): - type: Literal["gen_kw"] = "gen_kw" - transform_function_definitions: list[TransformFunctionDefinition] - - _transform_functions: list[TransformFunction] = PrivateAttr() - - @model_validator(mode="after") - def validate_and_setup_transform_functions(self) -> Self: - transform_functions: list[TransformFunction] = [] - - errors = [] - for e in self.transform_function_definitions: - try: - if isinstance(e, dict): - transform_functions.append( - self._parse_transform_function_definition( - TransformFunctionDefinition(**e) - ) - ) - else: - transform_functions.append( - self._parse_transform_function_definition(e) - ) - except ConfigValidationError as e: - errors.append(e) - - self._transform_functions = transform_functions - - try: - self._validate() - except ConfigValidationError as e: - errors.append(e) - - if errors: - raise ConfigValidationError.from_collected(errors) - - return self + forward_init: bool = False + update: bool = True + group: str = "DEFAULT" + input_source: DataSource = DataSource.SAMPLED def __contains__(self, item: str) -> bool: - return item in [v.name for v in self.transform_function_definitions] + return item == self.name def 
__len__(self) -> int: - return len(self.transform_functions) - - @property - def transform_functions(self) -> list[TransformFunction]: - return self._transform_functions + return 1 @property def parameter_keys(self) -> list[str]: - keys = [] - for tf in self.transform_functions: - keys.append(tf.name) + return [self.name] - return keys + @property + def data_cardinality(self) -> ParameterCardinality: + return ParameterCardinality.one_param_config_per_ensemble_dataset @property def metadata(self) -> list[ParameterMetadata]: return [ ParameterMetadata( - key=f"{self.name}:{tf.name}", - transformation=tf.distribution.name.upper(), + key=f"{self.group}:{self.name}", + transformation=self.distribution.name.upper(), dimensionality=1, userdata={"data_origin": "GEN_KW"}, ) - for tf in self.transform_functions ] @classmethod @@ -191,7 +144,7 @@ def templates_from_config( return None @classmethod - def from_config_list(cls, gen_kw: list[str | dict[str, str]]) -> Self: + def from_config_list(cls, gen_kw: list[str | dict[str, str]]) -> list[Self]: gen_kw_key = cast(str, gen_kw[0]) options = cast(dict[str, str], gen_kw[-1]) @@ -217,7 +170,7 @@ def from_config_list(cls, gen_kw: list[str | dict[str, str]]) -> Self: f"Unexpected positional arguments: {positional_args}" ) - transform_function_definitions: list[TransformFunctionDefinition] = [] + distributions_spec: list[list[str]] = [] for line_number, item in enumerate(parameter_file_contents.splitlines()): item = item.split("--")[0] # remove comments if item.strip(): # only lines with content @@ -231,14 +184,9 @@ def from_config_list(cls, gen_kw: list[str | dict[str, str]]) -> Self: ) ) else: - transform_function_definitions.append( - TransformFunctionDefinition( - name=items[0], - param_name=items[1], - values=items[2:], - ) - ) - if not transform_function_definitions: + distributions_spec.append(items) + + if not distributions_spec: errors.append( ConfigValidationError.with_context( f"No parameters specified in 
{parameter_file_context}", @@ -257,35 +205,29 @@ def from_config_list(cls, gen_kw: list[str | dict[str, str]]) -> Self: gen_kw_key, ) try: - return cls( - name=gen_kw_key, - forward_init=False, - transform_function_definitions=transform_function_definitions, - update=update_parameter, - ) + return [ + cls( + name=params[0], + group=gen_kw_key, + distribution=GenKwConfig._parse_distribution( + params[0], params[1], params[2:] + ), + forward_init=False, + update=update_parameter, + ) + for params in distributions_spec + ] + except ConfigValidationError as e: + raise ConfigValidationError.from_collected( + [err.set_context(gen_kw_key) for err in e.errors] + ) from e except ValidationError as e: raise ConfigValidationError.from_pydantic(e, gen_kw) from e - def _validate(self) -> None: - errors = [] - unique_keys = set() - for prior in self.get_priors(): - key = prior["key"] - if key in unique_keys: - errors.append( - ErrorInfo( - f"Duplicate GEN_KW keys {key!r} found, keys must be unique." - ).set_context(self.name) - ) - unique_keys.add(key) - - if errors: - raise ConfigValidationError.from_collected(errors) - def load_parameter_graph(self) -> nx.Graph[int]: # Create a graph with no edges graph_independence: nx.Graph[int] = nx.Graph() - graph_independence.add_nodes_from(range(len(self.transform_functions))) + graph_independence.add_nodes_from([0]) return graph_independence def read_from_runpath( @@ -307,15 +249,13 @@ def write_to_runpath( ) assert isinstance(df, pl.DataFrame) - if not df.width == len(self.transform_functions): + if not df.width == 1: raise ValueError( - f"The configuration of GEN_KW parameter {self.name}" - f" has {len(self.transform_functions)} parameters, but ensemble dataset" - f" for realization {real_nr} has {df.width} parameters." + f"GEN_KW {self.group_name}:{self.name} should be a single parameter!" 
) data = df.to_dicts()[0] - return {self.name: data} + return {self.group_name: data} def load_parameters( self, ensemble: Ensemble, realizations: npt.NDArray[np.int_] @@ -337,15 +277,15 @@ def create_storage_datasets( pl.DataFrame( { "realization": iens_active_index, + self.name: pl.Series(from_data.flatten()), } - ).with_columns( - [ - pl.Series(from_data[i, :]).alias(param_name.name) - for i, param_name in enumerate(self.transform_functions) - ] ), ) + @property + def group_name(self) -> str: + return self.group + def copy_parameters( self, source_ensemble: Ensemble, @@ -353,69 +293,69 @@ def copy_parameters( realizations: npt.NDArray[np.int_], ) -> None: df = source_ensemble.load_parameters(self.name, realizations) - target_ensemble.save_parameters(self.name, realization=None, dataset=df) + target_ensemble.save_parameters(dataset=df) def get_priors(self) -> list[PriorDict]: - priors: list[PriorDict] = [] - for tf in self.transform_functions: - priors.append( - { - "key": tf.name, - "function": tf.distribution.name.upper(), - "parameters": { - k.upper(): v - for k, v in tf.parameter_list.items() - if k != "name" - }, - } - ) - return priors + dist_json = self.distribution.model_dump(exclude={"name"}) + return [ + { + "key": self.name, + "function": self.distribution.name.upper(), + "parameters": {k.upper(): v for k, v in dist_json.items()}, + } + ] - def transform_col(self, param_name: str) -> Callable[[float], float]: - tf: TransformFunction | None = None - for tf in self.transform_functions: - if tf.name == param_name: - break - assert tf is not None, f"Transform function {param_name} not found" - return tf.distribution.transform + def transform_data(self) -> Callable[[float], float]: + return self.distribution.transform - def _parse_transform_function_definition( - self, - t: TransformFunctionDefinition, - ) -> TransformFunction: - if t.param_name not in DISTRIBUTION_CLASSES: + @classmethod + def _parse_distribution( + cls, param_name: str, dist_name: str, 
values: list[str] + ) -> ( + UnifSettings + | LogNormalSettings + | LogUnifSettings + | DUnifSettings + | RawSettings + | ConstSettings + | NormalSettings + | TruncNormalSettings + | ErrfSettings + | DerrfSettings + | TriangularSettings + ): + if dist_name not in DISTRIBUTION_CLASSES: raise ConfigValidationError( - f"Unknown distribution provided: {t.param_name}, for variable {t.name}", - self.name, + f"Unknown distribution provided: {dist_name}" + f", for variable {param_name}", + param_name, ) + dist_cls = DISTRIBUTION_CLASSES[dist_name] - cls = DISTRIBUTION_CLASSES[t.param_name] - - if len(t.values) != len(cls.get_param_names()): + if len(values) != len(dist_cls.get_param_names()): raise ConfigValidationError.with_context( - f"Incorrect number of values: {t.values}, provided for variable " - f"{t.name} with distribution {t.param_name}.", - self.name, + f"Incorrect number of values: {values}, provided for variable " + f"{param_name} with distribution {dist_name}.", + param_name, ) param_floats = [] - for p in t.values: + for p in values: try: param_floats.append(float(p)) except ValueError as e: raise ConfigValidationError.with_context( f"Unable to convert '{p}' to float number for variable " - f"{t.name} with distribution {t.param_name}.", - self.name, + f"{param_name} with distribution {dist_name}.", + param_name, ) from e try: - dist = get_distribution(t.param_name, param_floats) + dist = get_distribution(dist_name, param_floats) except ValidationError as e: error_to_raise = ConfigValidationError.from_pydantic( - error=e, context=self.name + error=e, context=param_name ) for error_info in error_to_raise.errors: - error_info.message += f" parameter {t.name}" + error_info.message += f" parameter {param_name}" raise error_to_raise from e - - return TransformFunction(name=t.name, distribution=dist) + return dist diff --git a/src/ert/config/parameter_config.py b/src/ert/config/parameter_config.py index 1b1ae87fd51..30b4c3040f0 100644 --- 
a/src/ert/config/parameter_config.py +++ b/src/ert/config/parameter_config.py @@ -1,7 +1,8 @@ from __future__ import annotations from abc import abstractmethod -from collections.abc import Iterator +from collections.abc import Callable, Iterator +from enum import StrEnum, auto from hashlib import sha256 from pathlib import Path from typing import TYPE_CHECKING, Any, Literal @@ -18,6 +19,23 @@ from ert.storage import Ensemble +class ParameterCardinality(StrEnum): + """ + one_param_config_per_ensemble_dataset: one config instance per single param, one + dataset per ensemble + + multiple_configs_per_ensemble_dataset: one config instance per group, one + dataset per ensemble + + one_param_config_per_realization_dataset: one config instance per group, one + dataset per realization + """ + + one_param_config_per_ensemble_dataset = auto() + multiple_configs_per_ensemble_dataset = auto() + one_param_config_per_realization_dataset = auto() + + class ParameterMetadata(BaseModel): key: str transformation: str | None @@ -102,7 +120,7 @@ def copy_parameters( # Converts to standard python scalar due to mypy realization_int = int(realization) ds = source_ensemble.load_parameters(self.name, realization_int) - target_ensemble.save_parameters(self.name, realization_int, ds) + target_ensemble.save_parameters(ds, self.name, realization_int) @abstractmethod def load_parameters( @@ -120,12 +138,23 @@ def load_parameter_graph(self) -> nx.Graph[int]: Often a neighbourhood graph. 
""" + @property + def data_cardinality(self) -> ParameterCardinality: + return ParameterCardinality.one_param_config_per_realization_dataset + def save_experiment_data( self, experiment_path: Path, ) -> None: pass + @property + def group_name(self) -> str: + return self.name + + def transform_data(self) -> Callable[[float], float]: + return lambda x: x + def sample_value( self, global_seed: str, @@ -156,18 +185,15 @@ def sample_value( before generating a single sample, enhancing efficiency by avoiding the generation of large, unused sample sets. """ - parameter_values = [] - for key in self.parameter_keys: - key_hash = sha256( - global_seed.encode("utf-8") + f"{self.name}:{key}".encode() - ) - seed = np.frombuffer(key_hash.digest(), dtype="uint32") - rng = np.random.default_rng(seed) - - # Advance the RNG state to the realization point - rng.standard_normal(realization) - - # Generate a single sample - value = rng.standard_normal(1) - parameter_values.append(value[0]) - return np.array(parameter_values) + key_hash = sha256( + global_seed.encode("utf-8") + f"{self.group_name}:{self.name}".encode() + ) + seed = np.frombuffer(key_hash.digest(), dtype="uint32") + rng = np.random.default_rng(seed) + + # Advance the RNG state to the realization point + rng.standard_normal(realization) + + # Generate a single sample + value = rng.standard_normal(1) + return np.array([value[0]]) diff --git a/src/ert/gui/ertwidgets/models/ertsummary.py b/src/ert/gui/ertwidgets/models/ertsummary.py index b0b67c6e309..797628668e9 100644 --- a/src/ert/gui/ertwidgets/models/ertsummary.py +++ b/src/ert/gui/ertwidgets/models/ertsummary.py @@ -1,3 +1,5 @@ +from collections import defaultdict + from typing_extensions import TypedDict from ert.config import ErtConfig, Field, GenKwConfig, SurfaceConfig @@ -17,6 +19,7 @@ def getForwardModels(self) -> list[str]: def getParameters(self) -> tuple[list[str], int]: parameters = [] + genkw_groups: dict[str, int] = defaultdict(int) count = 0 for ( key, @@ 
-24,14 +27,17 @@ def getParameters(self) -> tuple[list[str], int]: ) in self.ert_config.ensemble_config.parameter_configs.items(): match config: case GenKwConfig(): - parameters.append(f"{key} ({len(config)})") - count += len(config) + genkw_groups[config.group_name] += 1 + count += 1 case Field(nx=nx, ny=ny, nz=nz): parameters.append(f"{key} ({nx}, {ny}, {nz})") count += len(config) case SurfaceConfig(ncol=ncol, nrow=nrow): parameters.append(f"{key} ({ncol}, {nrow})") count += len(config) + parameters += [ + f"{group_name} ({cnt})" for group_name, cnt in genkw_groups.items() + ] return sorted(parameters, key=lambda k: k.lower()), count def getObservations(self) -> list[ObservationCount]: diff --git a/src/ert/gui/tools/manage_experiments/manage_experiments_panel.py b/src/ert/gui/tools/manage_experiments/manage_experiments_panel.py index 2b66341e36b..173b1e7126d 100644 --- a/src/ert/gui/tools/manage_experiments/manage_experiments_panel.py +++ b/src/ert/gui/tools/manage_experiments/manage_experiments_panel.py @@ -93,14 +93,13 @@ def _add_initialize_from_scratch_tab(self) -> None: center_layout = QHBoxLayout() design_matrix = self.ert_config.analysis_config.design_matrix parameters_config = self.ert_config.ensemble_config.parameter_configuration - design_matrix_group = None realizations: Collection[int] = range( self.ert_config.runpath_config.num_realizations ) if design_matrix is not None: try: - parameters_config, design_matrix_group = ( - design_matrix.merge_with_existing_parameters(parameters_config) + parameters_config = design_matrix.merge_with_existing_parameters( + parameters_config ) realizations = [ real @@ -119,8 +118,8 @@ def _add_initialize_from_scratch_tab(self) -> None: return parameter_model = SelectableListModel( - [p.name for p in parameters_config] + [design_matrix_group.name] - if design_matrix_group + [p.name for p in parameters_config] + if design_matrix else self.ert_config.ensemble_config.parameters ) parameter_check_list = 
CheckList(parameter_model, "Parameters") @@ -145,22 +144,14 @@ def initialize_from_scratch(_: bool) -> None: active_realizations = [int(i) for i in members_model.getSelectedItems()] with self.notifier.write_storage() as storage: - if ( - design_matrix is not None - and design_matrix_group is not None - and design_matrix_group.name in parameters - ): - parameters.remove(design_matrix_group.name) - design_matrix.save_to_ensemble( - storage.get_ensemble(ensemble_selector.currentData()), - active_realizations, - design_group_name=design_matrix_group.name, - ) sample_prior( ensemble=storage.get_ensemble(ensemble_selector.currentData()), active_realizations=active_realizations, parameters=parameters, random_seed=self.ert_config.random_seed, + design_matrix_df=( + design_matrix.design_matrix_df if design_matrix else None + ), ) @Slot() diff --git a/src/ert/gui/tools/manage_experiments/storage_widget.py b/src/ert/gui/tools/manage_experiments/storage_widget.py index c1da99bd171..6a3718b5966 100644 --- a/src/ert/gui/tools/manage_experiments/storage_widget.py +++ b/src/ert/gui/tools/manage_experiments/storage_widget.py @@ -169,11 +169,10 @@ def _addItem(self) -> None: if create_experiment_dialog.exec(): parameters_config = self._ert_config.ensemble_config.parameter_configuration design_matrix = self._ert_config.analysis_config.design_matrix - design_matrix_group = None if design_matrix is not None: try: - parameters_config, design_matrix_group = ( - design_matrix.merge_with_existing_parameters(parameters_config) + parameters_config = design_matrix.merge_with_existing_parameters( + parameters_config ) except ConfigValidationError as exc: QMessageBox.warning( @@ -188,11 +187,7 @@ def _addItem(self) -> None: try: with self._notifier.write_storage() as storage: ensemble = storage.create_experiment( - parameters=( - [*parameters_config, design_matrix_group] - if design_matrix_group is not None - else parameters_config - ), + parameters=parameters_config, 
responses=self._ert_config.ensemble_config.response_configuration, observations=self._ert_config.observations, name=create_experiment_dialog.experiment_name, diff --git a/src/ert/run_models/_create_run_path.py b/src/ert/run_models/_create_run_path.py index beb61840d27..0776f8ff383 100644 --- a/src/ert/run_models/_create_run_path.py +++ b/src/ert/run_models/_create_run_path.py @@ -118,7 +118,8 @@ def _generate_parameter_files( continue export_values = param.write_to_runpath(Path(run_path), iens, fs) if export_values: - exports.update(export_values) + for group, vals in export_values.items(): + exports.setdefault(group, {}).update(vals) continue _value_export_txt(run_path, export_base_name, exports) diff --git a/src/ert/run_models/initial_ensemble_run_model.py b/src/ert/run_models/initial_ensemble_run_model.py index 8d7c9caf4da..abb16383d7c 100644 --- a/src/ert/run_models/initial_ensemble_run_model.py +++ b/src/ert/run_models/initial_ensemble_run_model.py @@ -56,17 +56,14 @@ def _sample_and_evaluate_ensemble( rerun_failed_realizations: bool = False, ensemble_storage: LocalEnsemble | None = None, ) -> LocalEnsemble: - parameters_config, design_matrix, design_matrix_group = ( - self._merge_parameters_from_design_matrix( - cast(list[ParameterConfig], self.parameter_configuration), - self.design_matrix, - rerun_failed_realizations, - ) + parameters_config, design_matrix = self._merge_parameters_from_design_matrix( + cast(list[ParameterConfig], self.parameter_configuration), + self.design_matrix, + rerun_failed_realizations, ) if ensemble_storage is None: experiment_storage = self._storage.create_experiment( - parameters=parameters_config - + ([design_matrix_group] if design_matrix_group else []), + parameters=parameters_config, observations=self.observations, responses=cast(list[ResponseConfig], self.response_configuration), simulation_arguments=simulation_arguments, @@ -78,12 +75,6 @@ def _sample_and_evaluate_ensemble( ensemble_size=self.ensemble_size, 
name=ensemble_name, ) - if design_matrix_group is not None and design_matrix is not None: - design_matrix.save_to_ensemble( - ensemble_storage, - np.where(self.active_realizations)[0], - design_matrix_group.name, - ) if hasattr(self, "_ensemble_id"): setattr(self, "_ensemble_id", ensemble_storage.id) # noqa: B010 @@ -96,6 +87,9 @@ def _sample_and_evaluate_ensemble( np.where(self.active_realizations)[0], parameters=[param.name for param in parameters_config], random_seed=self.random_seed, + design_matrix_df=design_matrix.design_matrix_df + if design_matrix is not None + else None, ) prior_args = create_run_arguments( diff --git a/src/ert/run_models/run_model.py b/src/ert/run_models/run_model.py index 6013addecc6..9db54e51381 100644 --- a/src/ert/run_models/run_model.py +++ b/src/ert/run_models/run_model.py @@ -30,7 +30,6 @@ ConfigValidationError, DesignMatrix, ForwardModelStep, - GenKwConfig, HookedWorkflowFixtures, HookRuntime, ModelConfig, @@ -238,12 +237,14 @@ def log_at_startup(self) -> None: "run_model": self.name(), "num_realizations": self.runpath_config.num_realizations, "num_active_realizations": self.active_realizations.count(True), - "num_parameters": sum( - len(param_config.parameter_keys) - for param_config in self.parameter_configuration - ) - if hasattr(self, "parameter_configuration") - else "NA", + "num_parameters": ( + sum( + len(param_config.parameter_keys) + for param_config in self.parameter_configuration + ) + if hasattr(self, "parameter_configuration") + else "NA" + ), "localization": getattr( settings_dict.get("analysis_settings", {}), "localization", "NA" ), @@ -824,17 +825,16 @@ def _merge_parameters_from_design_matrix( parameters_config: list[ParameterConfig], design_matrix: DesignMatrix | None, rerun_failed_realizations: bool, - ) -> tuple[list[ParameterConfig], DesignMatrix | None, GenKwConfig | None]: - design_matrix_group = None + ) -> tuple[list[ParameterConfig], DesignMatrix | None]: # If a design matrix is present, we try to merge 
design matrix parameters # to the experiment parameters and set new active realizations if design_matrix is not None and not rerun_failed_realizations: try: - parameters_config, design_matrix_group = ( - design_matrix.merge_with_existing_parameters(parameters_config) + parameters_config = design_matrix.merge_with_existing_parameters( + parameters_config ) except ConfigValidationError as exc: raise ErtRunError(str(exc)) from exc - return parameters_config, design_matrix, design_matrix_group + return parameters_config, design_matrix diff --git a/src/ert/run_models/update_run_model.py b/src/ert/run_models/update_run_model.py index b8c6bf0b870..124123415fb 100644 --- a/src/ert/run_models/update_run_model.py +++ b/src/ert/run_models/update_run_model.py @@ -14,7 +14,6 @@ from ert.config import ( DesignMatrix, ESSettings, - GenKwConfig, HookRuntime, ObservationSettings, ParameterConfig, @@ -171,11 +170,9 @@ def _merge_parameters_from_design_matrix( parameters_config: list[ParameterConfig], design_matrix: DesignMatrix | None, rerun_failed_realizations: bool, - ) -> tuple[list[ParameterConfig], DesignMatrix | None, GenKwConfig | None]: - parameters_config, design_matrix, design_matrix_group = ( - super()._merge_parameters_from_design_matrix( - parameters_config, design_matrix, rerun_failed_realizations - ) + ) -> tuple[list[ParameterConfig], DesignMatrix | None]: + parameters_config, design_matrix = super()._merge_parameters_from_design_matrix( + parameters_config, design_matrix, rerun_failed_realizations ) if design_matrix and not any(p.update for p in parameters_config): @@ -183,4 +180,4 @@ def _merge_parameters_from_design_matrix( "No parameters to update as all parameters were set to update:false!" 
) - return parameters_config, design_matrix, design_matrix_group + return parameters_config, design_matrix diff --git a/src/ert/sample_prior.py b/src/ert/sample_prior.py index 947d72c1bc4..94363e5c91d 100644 --- a/src/ert/sample_prior.py +++ b/src/ert/sample_prior.py @@ -8,7 +8,7 @@ from ert.utils import log_duration -from .config import GenKwConfig +from .config import DataSource, GenKwConfig from .storage import Ensemble logger = logging.getLogger(__name__) @@ -22,6 +22,7 @@ def sample_prior( active_realizations: Iterable[int], random_seed: int, parameters: list[str] | None = None, + design_matrix_df: pl.DataFrame | None = None, ) -> None: """This function is responsible for getting the prior into storage, in the case of GEN_KW we sample the data and store it, and if INIT_FILES @@ -41,23 +42,41 @@ def sample_prior( f"for realizations {active_realizations}" ) if isinstance(config_node, GenKwConfig): - datasets = [ - Ensemble.sample_parameter( - config_node, - realization_nr, - random_seed=random_seed, - ) - for realization_nr in active_realizations - ] - if datasets: + dataset: pl.DataFrame | None = None + if ( + config_node.input_source == DataSource.DESIGN_MATRIX + and design_matrix_df is not None + ): + cols = {"realization", config_node.name} + missing = cols - set(design_matrix_df.columns) + if missing: + raise KeyError( + f"Design matrix is missing column(s): {', '.join(missing)}" + ) + dataset = design_matrix_df.select( + ["realization", config_node.name] + ).filter(pl.col("realization").is_in(list(active_realizations))) + if dataset.is_empty(): + raise KeyError("Active realization mask is not in design matrix!") + elif config_node.input_source == DataSource.SAMPLED: + datasets = [ + Ensemble.sample_parameter( + config_node, + realization_nr, + random_seed=random_seed, + ) + for realization_nr in active_realizations + ] + if datasets: + dataset = pl.concat(datasets, how="vertical") + + if dataset is not None: ensemble.save_parameters( - parameter, - 
realization=None, - dataset=pl.concat(datasets, how="vertical"), + dataset=dataset, ) else: for realization_nr in active_realizations: ds = config_node.read_from_runpath(Path(), realization_nr, 0) - ensemble.save_parameters(parameter, realization_nr, ds) + ensemble.save_parameters(ds, parameter, realization_nr) ensemble.refresh_ensemble_state() diff --git a/src/ert/shared/storage/extraction.py b/src/ert/shared/storage/extraction.py index 488551ba615..8273f3e5650 100644 --- a/src/ert/shared/storage/extraction.py +++ b/src/ert/shared/storage/extraction.py @@ -25,18 +25,11 @@ def create_priors( ) -> Mapping[str, dict[str, str | float]]: priors_dict = {} - for group, priors in experiment.parameter_configuration.items(): - if isinstance(priors, GenKwConfig): - for func in priors.transform_functions: - prior: dict[str, str | float] = { - "function": _PRIOR_NAME_MAP[func.distribution.name.upper()], - } - for name, value in func.parameter_list.items(): - # Libres calls it steps, but normal stats uses bins - if name == "STEPS": - name = "bins" - prior[name.lower()] = value - - priors_dict[f"{group}:{func.name}"] = prior + for param in experiment.parameter_configuration.values(): + if isinstance(param, GenKwConfig): + prior: dict[str, str | float] = { + "function": _PRIOR_NAME_MAP[param.distribution.name.upper()], + } + priors_dict[f"{param.group}:{param.name}"] = prior return priors_dict diff --git a/src/ert/storage/local_ensemble.py b/src/ert/storage/local_ensemble.py index a31502e10e2..a2c6175465b 100644 --- a/src/ert/storage/local_ensemble.py +++ b/src/ert/storage/local_ensemble.py @@ -19,7 +19,7 @@ from pydantic import BaseModel from typing_extensions import TypedDict -from ert.config import GenKwConfig, ParameterConfig +from ert.config import ParameterCardinality, ParameterConfig from ert.config.response_config import InvalidResponseFile from .load_status import LoadResult @@ -40,6 +40,9 @@ class EverestRealizationInfo(TypedDict): perturbation: int # -1 means it stems 
from unperturbed controls +SCALAR_FILENAME = "SCALAR" + + class _Index(BaseModel): id: UUID experiment_id: UUID @@ -255,23 +258,29 @@ def get_realization_mask_with_responses(self) -> npt.NDArray[np.bool_]: @cached_property def _existing_scalars(self) -> dict[str, list[int]]: - genkw_mask: dict[str, list[int]] = {} - for parameter in self.experiment.parameter_configuration.values(): - if isinstance(parameter, GenKwConfig): - genkw_mask[parameter.name] = [] - group_path = ( - self.mount_point / f"{_escape_filename(parameter.name)}.parquet" - ) - if group_path.exists(): - genkw_mask[parameter.name] = ( - pl.scan_parquet(group_path) - .select("realization") - .unique() - .collect() # only fetching reals, so use standard engine - .get_column("realization") - .to_list() - ) - return genkw_mask + group_path = self.mount_point / f"{_escape_filename(SCALAR_FILENAME)}.parquet" + genkw_mask: dict[str, list[int]] = { + param: [] + for param in self.experiment.parameter_configuration + if self.experiment.param_cardinality[param] + == ParameterCardinality.one_param_config_per_ensemble_dataset + } + if not group_path.exists(): + return genkw_mask + df = pl.scan_parquet(group_path) + cols = df.collect_schema().names() + real = ( + df.select("realization") + .unique() + .collect() + .get_column("realization") + .to_list() + ) + return { + param: real + for param in cols + if param != "realization" and param in genkw_mask + } def has_data(self) -> bool: """ @@ -543,6 +552,39 @@ def _load_parameters_lazy( df = pl.scan_parquet(group_path) return df + def _load_scalar_keys( + self, + keys: list[str], + realizations: int | npt.NDArray[np.int_] | None = None, + transformed: bool = False, + ) -> pl.DataFrame: + df_lazy = self._load_parameters_lazy(SCALAR_FILENAME) + df_lazy = df_lazy.select(["realization", *keys]) + if realizations is not None: + if isinstance(realizations, int): + realizations = np.array([realizations]) + df_lazy = 
df_lazy.filter(pl.col("realization").is_in(realizations)) + df = df_lazy.collect(engine="streaming") + if df.is_empty(): + raise IndexError( + f"No matching realizations {realizations} found for {keys}" + ) + + if transformed: + df = df.with_columns( + [ + pl.col(col) + .map_elements( + self.experiment.parameter_configuration[col].transform_data(), + return_dtype=df[col].dtype, + ) + .alias(col) + for col in df.columns + if col != "realization" + ] + ) + return df + def load_parameters( self, group: str, @@ -555,44 +597,45 @@ def load_parameters( otherwise it will return the raw values. """ - if group not in self.experiment.parameter_configuration: + if group not in self.experiment.param_cardinality: raise KeyError(f"{group} is not registered to the experiment.") - config = self.experiment.parameter_configuration[group] - if isinstance(config, GenKwConfig): - df_lazy = self._load_parameters_lazy(group) - if realizations is not None: - if isinstance(realizations, int): - realizations = np.array([realizations]) - df_lazy = df_lazy.filter(pl.col("realization").is_in(realizations)) - df = df_lazy.collect(engine="streaming") - if df.is_empty(): - raise IndexError( - f"No matching realizations {realizations} found for {group}" - ) - if transformed: - df = df.with_columns( - [ - pl.col(col) - .map_elements( - config.transform_col(col), return_dtype=df[col].dtype - ) - .alias(col) - for col in df.columns - if col != "realization" - ] - ) - return df - ds = self._load_dataset( + + if ( + self.experiment.param_cardinality[group] + == ParameterCardinality.multiple_configs_per_ensemble_dataset + ): + return self._load_scalar_keys( + self.experiment.param_groups[group], realizations, transformed + ) + elif ( + self.experiment.param_cardinality[group] + == ParameterCardinality.one_param_config_per_ensemble_dataset + ): + return self._load_scalar_keys([group], realizations, transformed) + return self._load_dataset( group, - realizations - if realizations is not None - else 
np.flatnonzero(self.get_realization_mask_with_parameters()), + ( + realizations + if realizations is not None + else np.flatnonzero(self.get_realization_mask_with_parameters()) + ), ) - return ds def load_parameters_numpy( self, group: str, realizations: npt.NDArray[np.int_] ) -> npt.NDArray[np.float64]: + if ( + self.experiment.param_cardinality[group] + == ParameterCardinality.multiple_configs_per_ensemble_dataset + ): + return ( + self._load_scalar_keys( + self.experiment.param_groups[group], realizations + ) + .drop("realization") + .to_numpy() + .T.copy() + ) config = self.experiment.parameter_configuration[group] return config.load_parameters(self, realizations) @@ -606,25 +649,30 @@ def save_parameters_numpy( for real, ds in config_node.create_storage_datasets( parameters, iens_active_index ): - self.save_parameters(config_node.name, real, ds) + self.save_parameters(ds, config_node.name, real) def load_scalars( self, group: str | None = None, realizations: npt.NDArray[np.int_] | None = None ) -> pl.DataFrame: dataframes = [] gen_kws = [ - config - for config in self.experiment.parameter_configuration.values() - if isinstance(config, GenKwConfig) + p + for p in self.experiment.parameter_configuration.values() + if p.data_cardinality + == ParameterCardinality.one_param_config_per_ensemble_dataset ] - if group: - gen_kws = [config for config in gen_kws if config.name == group] + if group and group in self.experiment.param_groups: + gen_kws = [ + self.experiment.parameter_configuration[key] + for key in self.experiment.param_groups[group] + ] + for config in gen_kws: df = self.load_parameters(config.name, realizations, transformed=True) assert isinstance(df, pl.DataFrame) df = df.rename( { - col: f"{config.name}:{col}" + col: f"{config.group_name}:{col}" for col in df.columns if col != "realization" } @@ -657,22 +705,16 @@ def sample_parameter( real_nr: int, random_seed: int, ) -> pl.DataFrame: - keys = parameter.parameter_keys - if not keys: - return 
pl.DataFrame([]) parameter_value = parameter.sample_value( str(random_seed), real_nr, ) - parameter_dict = { - parameter_name: parameter_value[idx] - for idx, parameter_name in enumerate(keys) - } + parameter_dict = {parameter.name: parameter_value[0]} parameter_dict["realization"] = real_nr return pl.DataFrame( parameter_dict, - schema=dict.fromkeys(keys, pl.Float64) | {"realization": pl.Int64}, + schema={parameter.name: pl.Float64, "realization": pl.Int64}, ) def load_responses(self, key: str, realizations: tuple[int, ...]) -> pl.DataFrame: @@ -787,9 +829,9 @@ def load_all_gen_kw_data( @require_write def save_parameters( self, - group: str, - realization: int | None, dataset: xr.Dataset | pl.DataFrame, + group: str | None = None, + realization: int | None = None, ) -> None: """ Saves the provided dataset under a parameter group and realization index(es) @@ -798,32 +840,29 @@ def save_parameters( if isinstance(dataset, pl.DataFrame): try: # since all realizations are saved in a single parquet file, - # this makes sure that we only append new realizations. - df = self._load_parameters_lazy(group) - existing_realizations = ( - df.select("realization") - .unique() - .collect() # only fetch reals, so use standard engine - .get_column("realization") + # this makes sure that we only add / replace new data. 
+ df = self._load_parameters_lazy(SCALAR_FILENAME).collect( + engine="streaming" ) - new_data = dataset.filter( - ~pl.col("realization").is_in(existing_realizations.implode()) + df = df.drop( + [c for c in dataset.columns if c != "realization"], strict=False + ) + df_full = ( + df.join(dataset, on="realization", how="left") + .unique(subset=["realization"], keep="first") + .sort("realization") ) - if new_data.height > 0: - df_full = pl.concat( - [df.collect(), new_data], - # needs all data in memory anyway so using standard engine - how="vertical", - ).sort("realization") - else: - return except KeyError: df_full = dataset - group_path = self.mount_point / f"{_escape_filename(group)}.parquet" + group_path = ( + self.mount_point / f"{_escape_filename(SCALAR_FILENAME)}.parquet" + ) self._storage._to_parquet_transaction(group_path, df_full) return + assert group is not None, "Group must be provided for xarray Dataset" + assert realization is not None, ( "Realization must be provided for xarray Dataset" ) @@ -896,9 +935,6 @@ def save_response( def calculate_std_dev_for_parameter_group( self, parameter_group: str ) -> npt.NDArray[np.float64]: - if parameter_group not in self.experiment.parameter_configuration: - raise ValueError(f"{parameter_group} is not registered to the experiment.") - data = self.load_parameters(parameter_group) if isinstance(data, pl.DataFrame): return data.drop("realization").std().to_numpy().reshape(-1) @@ -925,9 +961,11 @@ def get_response_state( response_configs = self.experiment.response_configuration path = self._realization_dir(realization) return { - e: RealizationStorageState.RESPONSES_LOADED - if (path / f"{e}.parquet").exists() - else RealizationStorageState.UNDEFINED + e: ( + RealizationStorageState.RESPONSES_LOADED + if (path / f"{e}.parquet").exists() + else RealizationStorageState.UNDEFINED + ) for e in response_configs } @@ -1167,9 +1205,11 @@ def all_parameters_and_gen_data(self) -> pl.DataFrame | None: params_wide = pl.concat( [ - 
pdf.sort("realization").drop("realization") - if i > 0 - else pdf.sort("realization") + ( + pdf.sort("realization").drop("realization") + if i > 0 + else pdf.sort("realization") + ) for i, pdf in enumerate(param_dfs) ], how="horizontal", @@ -1239,7 +1279,7 @@ async def _read_parameters( extra={"Time": f"{(time.perf_counter() - start_time):.4f}s"}, ) start_time = time.perf_counter() - ensemble.save_parameters(config.name, realization, ds) + ensemble.save_parameters(ds, config.name, realization) await asyncio.sleep(0) logger.debug( f"Saved {config.name} to storage", diff --git a/src/ert/storage/local_experiment.py b/src/ert/storage/local_experiment.py index 5b3f8125b00..70d3c91b4e6 100644 --- a/src/ert/storage/local_experiment.py +++ b/src/ert/storage/local_experiment.py @@ -2,6 +2,7 @@ import json import shutil +from collections import defaultdict from collections.abc import Generator from datetime import datetime from functools import cached_property @@ -24,9 +25,8 @@ SummaryConfig, SurfaceConfig, ) -from ert.config import ( - Field as FieldConfig, -) +from ert.config import Field as FieldConfig +from ert.config.parameter_config import ParameterCardinality from ert.config.parsing.context_values import ContextBoolEncoder from .mode import BaseMode, Mode, require_write @@ -53,12 +53,10 @@ class _Index(BaseModel): ] ) -_parameters_adapter = TypeAdapter( - list[ - Annotated[ - (GenKwConfig | SurfaceConfig | FieldConfig | ExtParamConfig), - Field(discriminator="type"), - ] +_parameters_adapter = TypeAdapter( # type: ignore + Annotated[ + (GenKwConfig | SurfaceConfig | FieldConfig | ExtParamConfig), + Field(discriminator="type"), ] ) @@ -377,10 +375,8 @@ def get_surface(self, name: str) -> IrapSurface: @cached_property def parameter_configuration(self) -> dict[str, ParameterConfig]: return { - instance.name: instance - for instance in _parameters_adapter.validate_python( - self.parameter_info.values() - ) + name: _parameters_adapter.validate_python(cfg) + for name, cfg 
in self.parameter_info.items() } @cached_property @@ -391,6 +387,23 @@ def parameter_keys(self) -> list[str]: return keys + @cached_property + def param_groups(self) -> dict[str, list[str]]: + dict_param: dict[str, list[str]] = defaultdict(list) + for p in self.parameter_configuration.values(): + dict_param[p.group_name].append(p.name) + return dict_param + + @cached_property + def param_cardinality(self) -> dict[str, ParameterCardinality]: + return { + p.name: p.data_cardinality for p in self.parameter_configuration.values() + } | { + p.group_name: ParameterCardinality.multiple_configs_per_ensemble_dataset + for p in self.parameter_configuration.values() + if p.name != p.group_name + } + @cached_property def parameter_group_to_parameter_keys(self) -> dict[str, list[str]]: return { diff --git a/src/ert/storage/local_storage.py b/src/ert/storage/local_storage.py index e1b93c75258..0de78313dc0 100644 --- a/src/ert/storage/local_storage.py +++ b/src/ert/storage/local_storage.py @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) -_LOCAL_STORAGE_VERSION = 12 +_LOCAL_STORAGE_VERSION = 13 class _Migrations(BaseModel): @@ -487,6 +487,7 @@ def _migrate(self, version: int) -> None: to10, to11, to12, + to13, ) try: @@ -534,7 +535,20 @@ def _migrate(self, version: int) -> None: elif version < _LOCAL_STORAGE_VERSION: migrations = list( enumerate( - [to2, to3, to4, to5, to6, to7, to8, to9, to10, to11, to12], + [ + to2, + to3, + to4, + to5, + to6, + to7, + to8, + to9, + to10, + to11, + to12, + to13, + ], start=1, ) ) diff --git a/src/ert/storage/migration/to13.py b/src/ert/storage/migration/to13.py new file mode 100644 index 00000000000..729525f6ec9 --- /dev/null +++ b/src/ert/storage/migration/to13.py @@ -0,0 +1,95 @@ +import json +import os +from pathlib import Path +from typing import Any + +import polars as pl + +from ert.storage.local_ensemble import _escape_filename + +info = "Convert GenKw group concept into a single parameter" + + +tfd_to_distributions = { + "NORMAL": 
["name", "mean", "std"], + "LOGNORMAL": ["name", "mean", "std"], + "UNIFORM": ["name", "min", "max"], + "LOGUNIF": ["name", "min", "max"], + "TRUNCATED_NORMAL": ["name", "mean", "std", "min", "max"], + "RAW": ["name"], + "CONST": ["name", "value"], + "DUNIF": ["name", "steps", "min", "max"], + "TRIANGULAR": ["name", "min", "mode", "max"], + "ERRF": ["name", "min", "max", "skewness", "width"], + "DERRF": ["name", "steps", "min", "max", "skewness", "width"], +} + + +def migrate_gen_kw_param(parameters_json: dict[str, Any]) -> dict[str, Any]: + new_configs = {} + for param_config in parameters_json.values(): + if param_config["type"] == "gen_kw": + group = param_config["name"] + tfds = param_config["transform_function_definitions"] + for tfd in tfds: + dist_type = tfd["param_name"] + keys = tfd_to_distributions[dist_type] + vals = [dist_type.lower()] + tfd["values"] + input_source = ( + "design_matrix" + if tfd["param_name"] == "RAW" and not param_config["update"] + else "sampled" + ) + new_configs[tfd["name"]] = { + "name": tfd["name"], + "type": "gen_kw", + "group": group, + "distribution": dict(zip(keys, vals, strict=False)), + "forward_init": False, + "update": param_config["update"], + "input_source": input_source, + } + else: + new_configs[param_config["name"]] = param_config + return new_configs + + +def migrate_genkw(path: Path) -> None: + for experiment in path.glob("experiments/*"): + ensembles = path.glob("ensembles/*") + + experiment_id = None + with open(experiment / "index.json", encoding="utf-8") as f: + exp_index = json.load(f) + experiment_id = exp_index["id"] + + with open(experiment / "parameter.json", encoding="utf-8") as fin: + parameters_json = json.load(fin) + + new_parameter_configs = migrate_gen_kw_param(parameters_json) + with open(experiment / "parameter.json", "w", encoding="utf-8") as fout: + fout.write(json.dumps(new_parameter_configs, indent=3)) + + # migrate parquet files + for ens in ensembles: + with open(ens / "index.json", 
encoding="utf-8") as f: + ens_file = json.load(f) + if ens_file["experiment_id"] != experiment_id: + continue + + group_dfs = {} + for param_config in parameters_json.values(): + if param_config["type"] == "gen_kw": + group = param_config["name"] + group_path = ens / f"{_escape_filename(group)}.parquet" + if group_path.exists(): + group_dfs[group] = pl.read_parquet(group_path) + os.remove(group_path) + if group_dfs: + df = pl.concat(list(group_dfs.values()), how="align") + df = df.unique(subset=["realization"], keep="first").sort("realization") + df.write_parquet(ens / "SCALAR.parquet") + + +def migrate(path: Path) -> None: + migrate_genkw(path) diff --git a/tests/ert/performance_tests/test_analysis.py b/tests/ert/performance_tests/test_analysis.py index 5ceb7f3627c..5cda3d1a18c 100644 --- a/tests/ert/performance_tests/test_analysis.py +++ b/tests/ert/performance_tests/test_analysis.py @@ -117,8 +117,6 @@ def g(X): for iens in range(prior_ensemble.ensemble_size): prior_ensemble.save_parameters( - param_group, - iens, xr.Dataset( { "values": xr.DataArray( @@ -127,6 +125,8 @@ def g(X): ), } ), + param_group, + iens, ) prior_ensemble.save_response( diff --git a/tests/ert/performance_tests/test_obs_and_responses_performance.py b/tests/ert/performance_tests/test_obs_and_responses_performance.py index 2a1b958adbf..e574ea55046 100644 --- a/tests/ert/performance_tests/test_obs_and_responses_performance.py +++ b/tests/ert/performance_tests/test_obs_and_responses_performance.py @@ -15,7 +15,6 @@ ObservationSettings, SummaryConfig, ) -from ert.config.gen_kw_config import TransformFunctionDefinition from ert.sample_prior import sample_prior from ert.storage import open_storage @@ -36,7 +35,7 @@ def _add_noise_to_df_values(df: pl.DataFrame): @dataclass class ExperimentInfo: - gen_kw_config: GenKwConfig + gen_kw_configs: list[GenKwConfig] gen_data_config: GenDataConfig summary_config: SummaryConfig summary_observations: pl.DataFrame @@ -55,19 +54,14 @@ def 
create_experiment_args( num_summary_timesteps: int, num_summary_obs: int, ) -> ExperimentInfo: - gen_kw_config = GenKwConfig( - name="all_my_parameters_live_here", - forward_init=False, - update=True, - transform_function_definitions=[ - TransformFunctionDefinition( - name=f"param_{i}", - param_name="NORMAL", - values=[10, 0.1], - ) - for i in range(num_parameters) - ], - ) + gen_kw_configs = [ + GenKwConfig( + name=f"param_{i}", + group="all_my_parameters_live_here", + distribution={"name": "normal", "mean": 10, "std": 0.1}, + ) + for i in range(num_parameters) + ] gen_data_config = GenDataConfig( name="gen_data", report_steps_list=[list(range(num_gen_data_report_steps))] * num_gen_data_keys, @@ -186,7 +180,7 @@ def create_experiment_args( ) return ExperimentInfo( - gen_kw_config=gen_kw_config, + gen_kw_configs=gen_kw_configs, gen_data_config=gen_data_config, summary_config=summary_config, summary_observations=summary_observations, @@ -383,7 +377,7 @@ def setup_benchmark(tmp_path, request): with open_storage(tmp_path / "storage", mode="w") as storage: experiment = storage.create_experiment( responses=[info.gen_data_config, info.summary_config], - parameters=[info.gen_kw_config], + parameters=info.gen_kw_configs, observations={ "gen_data": info.gen_data_observations, "summary": info.summary_observations, @@ -460,7 +454,7 @@ def setup_es_benchmark(tmp_path, request): with open_storage(tmp_path / "storage", mode="w") as storage: experiment = storage.create_experiment( responses=[info.gen_data_config, info.summary_config], - parameters=[info.gen_kw_config], + parameters=info.gen_kw_configs, observations={ "gen_data": info.gen_data_observations, "summary": info.summary_observations, @@ -483,7 +477,10 @@ def setup_es_benchmark(tmp_path, request): ) sample_prior( - prior, range(config.num_realizations), 42, [info.gen_kw_config.name] + prior, + range(config.num_realizations), + 42, + [c.name for c in info.gen_kw_configs], ) posterior = experiment.create_ensemble( 
ensemble_size=config.num_realizations, @@ -496,20 +493,20 @@ def setup_es_benchmark(tmp_path, request): alias, prior, posterior, - info.gen_kw_config.name, + [cfg.name for cfg in info.gen_kw_configs], expected_performance, ) @pytest.mark.memory_test def test_memory_performance_of_doing_es_update(setup_es_benchmark, tmp_path): - _, prior, posterior, gen_kw_name, expected_performance = setup_es_benchmark + _, prior, posterior, gen_kw_names, expected_performance = setup_es_benchmark with memray.Tracker(tmp_path / "memray.bin"): smoother_update( prior, posterior, prior.experiment.observation_keys, - [gen_kw_name], + gen_kw_names, ObservationSettings(), ESSettings(), ) @@ -520,7 +517,7 @@ def test_memory_performance_of_doing_es_update(setup_es_benchmark, tmp_path): def test_speed_performance_of_doing_es_update(setup_es_benchmark, benchmark): - alias, prior, posterior, gen_kw_name, _ = setup_es_benchmark + alias, prior, posterior, gen_kw_names, _ = setup_es_benchmark if alias != "small": pytest.skip() @@ -530,7 +527,7 @@ def run(): prior, posterior, prior.experiment.observation_keys, - [gen_kw_name], + gen_kw_names, ObservationSettings(), ESSettings(), ) @@ -540,13 +537,13 @@ def run(): @pytest.mark.memory_test def test_memory_performance_of_doing_enif_update(setup_es_benchmark, tmp_path): - _, prior, posterior, gen_kw_name, expected_performance = setup_es_benchmark + _, prior, posterior, gen_kw_names, expected_performance = setup_es_benchmark with memray.Tracker(tmp_path / "memray.bin"): enif_update( prior, posterior, prior.experiment.observation_keys, - [gen_kw_name], + gen_kw_names, 12345, ) @@ -556,7 +553,7 @@ def test_memory_performance_of_doing_enif_update(setup_es_benchmark, tmp_path): def test_speed_performance_of_doing_enif_update(setup_es_benchmark, benchmark): - alias, prior, posterior, gen_kw_name, _ = setup_es_benchmark + alias, prior, posterior, gen_kw_names, _ = setup_es_benchmark if alias != "small": pytest.skip() @@ -566,7 +563,7 @@ def run(): prior, 
posterior, prior.experiment.observation_keys, - [gen_kw_name], + gen_kw_names, 123456789, ) diff --git a/tests/ert/ui_tests/cli/analysis/test_design_matrix.py b/tests/ert/ui_tests/cli/analysis/test_design_matrix.py index a476144c6be..2523c783ce3 100644 --- a/tests/ert/ui_tests/cli/analysis/test_design_matrix.py +++ b/tests/ert/ui_tests/cli/analysis/test_design_matrix.py @@ -12,7 +12,7 @@ import pytest from ert.cli.main import ErtCliError -from ert.config import ConfigValidationError, ConfigWarning, ErtConfig +from ert.config import ConfigWarning, ErtConfig from ert.config.design_matrix import DESIGN_MATRIX_GROUP from ert.mode_definitions import ( ENSEMBLE_EXPERIMENT_MODE, @@ -76,16 +76,12 @@ def test_run_poly_example_with_design_matrix(copy_poly_case_with_design_matrix, @pytest.mark.usefixtures("copy_poly_case") @pytest.mark.parametrize( - "default_values, error_msg", + "default_values", [ - ([["b", 1], ["c", 2]], None), - ( - [["b", 1]], - "Only full overlaps of design matrix and one genkw group are supported.", - ), + ([["b", 1], ["c", 2]]), ], ) -def test_run_poly_example_with_design_matrix_and_genkw_merge(default_values, error_msg): +def test_run_poly_example_with_design_matrix_and_genkw_merge(default_values): num_realizations = 10 a_values = list(range(num_realizations)) _create_design_matrix( @@ -162,16 +158,6 @@ def _evaluate(coeffs, x): os.stat("poly_eval.py").st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH, ) - if error_msg: - with pytest.raises(ConfigValidationError, match=error_msg): - run_cli( - ENSEMBLE_EXPERIMENT_MODE, - "--disable-monitoring", - "poly.ert", - "--experiment-name", - "test-experiment", - ) - return with warnings.catch_warnings(): warnings.simplefilter("ignore", category=ConfigWarning) # Expected warning: diff --git a/tests/ert/ui_tests/cli/analysis/test_es_update.py b/tests/ert/ui_tests/cli/analysis/test_es_update.py index 14cf38442a1..d9bde76a29e 100644 --- a/tests/ert/ui_tests/cli/analysis/test_es_update.py +++ 
b/tests/ert/ui_tests/cli/analysis/test_es_update.py @@ -11,7 +11,6 @@ from ert.analysis._update_commons import _all_parameters from ert.config import ErtConfig, GenKwConfig -from ert.config.gen_kw_config import TransformFunctionDefinition from ert.mode_definitions import ENSEMBLE_SMOOTHER_MODE from ert.storage import RealizationStorageState, open_storage from tests.ert.ui_tests.cli.run_cli import run_cli @@ -20,12 +19,9 @@ @pytest.fixture def uniform_parameter(): return GenKwConfig( - name="PARAMETER", - forward_init=False, - transform_function_definitions=[ - TransformFunctionDefinition("KEY1", "UNIFORM", [0, 1]), - ], - update=True, + name="KEY1", + group="PARAMETER", + distribution={"name": "uniform", "min": 0, "max": 1}, ) diff --git a/tests/ert/ui_tests/cli/test_field_parameter.py b/tests/ert/ui_tests/cli/test_field_parameter.py index 488b27688c4..e6a4cb94d8b 100644 --- a/tests/ert/ui_tests/cli/test_field_parameter.py +++ b/tests/ert/ui_tests/cli/test_field_parameter.py @@ -13,9 +13,7 @@ import xtgeo from polars import Float32 -from ert.analysis import ( - smoother_update, -) +from ert.analysis import smoother_update from ert.config import ErtConfig, ESSettings, ObservationSettings from ert.mode_definitions import ENSEMBLE_SMOOTHER_MODE from ert.storage import open_storage @@ -460,9 +458,9 @@ def test_field_param_update_using_heat_equation_zero_var_params_and_adaptive_loc ) cond["values"][:, :, :5, 0] = 1.0 for real in range(prior.ensemble_size): - new_prior.save_parameters("COND", real, cond) - new_prior.save_parameters("INIT_TEMP_SCALE", real, init_temp_scale) - new_prior.save_parameters("CORR_LENGTH", real, corr_length) + new_prior.save_parameters(cond, "COND", real) + new_prior.save_parameters(init_temp_scale, "INIT_TEMP_SCALE", real) + new_prior.save_parameters(corr_length, "CORR_LENGTH", real) # Copy responses from existing prior to new prior. 
# Note that we ideally should generate new responses by running the diff --git a/tests/ert/ui_tests/gui/test_manage_experiments_tool.py b/tests/ert/ui_tests/gui/test_manage_experiments_tool.py index 20779634319..dfb60a3cbd8 100644 --- a/tests/ert/ui_tests/gui/test_manage_experiments_tool.py +++ b/tests/ert/ui_tests/gui/test_manage_experiments_tool.py @@ -45,7 +45,9 @@ def test_design_matrix_in_manage_experiments_panel( with notifier.write_storage() as storage: storage.create_experiment( - parameters=[config.analysis_config.design_matrix.parameter_configuration], + parameters=list( + config.analysis_config.design_matrix.parameter_configurations + ), responses=config.ensemble_config.response_configuration, name="my-experiment", ).create_ensemble( @@ -94,13 +96,9 @@ def test_design_matrix_in_manage_experiments_panel( assert {e.name for e in experiments} == {"my-experiment", "my-experiment-2"} exp2 = notifier.storage.get_experiment_by_name("my-experiment-2") ensemble = exp2.get_ensemble_by_name("my-design-2") - assert "DESIGN_MATRIX" in exp2.parameter_configuration - assert { - t.name - for t in exp2.parameter_configuration[ - "DESIGN_MATRIX" - ].transform_function_definitions - } == {"a", "b", "c"} + for param in exp2.parameter_configuration.values(): + assert param.group_name == "DESIGN_MATRIX" + assert {p.name for p in exp2.parameter_configuration.values()} == {"a", "b", "c"} assert all( RealizationStorageState.UNDEFINED in s for s in ensemble.get_ensemble_state() ) @@ -535,7 +533,16 @@ def test_realization_view( assert {"gen_data - RESPONSES_LOADED", "summary - RESPONSES_LOADED"}.issubset( set(realization_widget._response_text_edit.toPlainText().splitlines()) ) - assert ( - realization_widget._parameter_text_edit.toPlainText() - == "\nSNAKE_OIL_PARAM - PARAMETERS_LOADED\n" - ) + + assert { + "OP1_PERSISTENCE - PARAMETERS_LOADED", + "OP1_OCTAVES - PARAMETERS_LOADED", + "OP1_DIVERGENCE_SCALE - PARAMETERS_LOADED", + "OP1_OFFSET - PARAMETERS_LOADED", + "OP2_PERSISTENCE 
- PARAMETERS_LOADED", + "OP2_OCTAVES - PARAMETERS_LOADED", + "OP2_DIVERGENCE_SCALE - PARAMETERS_LOADED", + "OP2_OFFSET - PARAMETERS_LOADED", + "BPR_555_PERSISTENCE - PARAMETERS_LOADED", + "BPR_138_PERSISTENCE - PARAMETERS_LOADED", + } == set(realization_widget._parameter_text_edit.toPlainText().strip().splitlines()) diff --git a/tests/ert/unit_tests/analysis/test_es_update.py b/tests/ert/unit_tests/analysis/test_es_update.py index c04b0f032e7..104aa3e233e 100644 --- a/tests/ert/unit_tests/analysis/test_es_update.py +++ b/tests/ert/unit_tests/analysis/test_es_update.py @@ -23,7 +23,6 @@ ObservationSettings, OutlierSettings, ) -from ert.config.gen_kw_config import TransformFunctionDefinition from ert.field_utils import Shape from ert.storage import Ensemble, open_storage @@ -31,14 +30,9 @@ @pytest.fixture def uniform_parameter(): return GenKwConfig( - name="PARAMETER", - forward_init=False, - transform_function_definitions=[ - TransformFunctionDefinition( - name="KEY1", param_name="UNIFORM", values=[0, 1] - ), - ], - update=True, + name="KEY_1", + group="PARAMETER", + distribution={"name": "uniform", "min": 0, "max": 1}, ) @@ -227,14 +221,9 @@ def test_update_handles_precision_loss_in_std_dev(tmp_path): standard deviation. 
""" gen_kw = GenKwConfig( - name="COEFFS", - forward_init=False, - update=True, - transform_function_definitions=[ - TransformFunctionDefinition( - name="coeff_0", param_name="CONST", values=["0.1"] - ), - ], + name="coeff_0", + group="COEFFS", + distribution={"name": "const", "value": 0.1}, ) # The values given here are chosen so that when computing # `ens_std = S.std(ddof=0, axis=1)`, ens_std[0] is not zero even though @@ -273,13 +262,15 @@ def test_update_handles_precision_loss_in_std_dev(tmp_path): ], ) prior = storage.create_ensemble(experiment.id, ensemble_size=23, name="prior") - for realization_nr in range(prior.ensemble_size): - ds = Ensemble.sample_parameter( + datasets = [ + Ensemble.sample_parameter( gen_kw, realization_nr, random_seed=1234, ) - prior.save_parameters("COEFFS", realization_nr, ds) + for realization_nr in range(prior.ensemble_size) + ] + prior.save_parameters(pl.concat(datasets, how="vertical")) prior.save_response( "gen_data", @@ -326,7 +317,7 @@ def test_update_handles_precision_loss_in_std_dev(tmp_path): prior, posterior, experiment.observation_keys, - ["COEFFS"], + ["coeff_0"], ObservationSettings(auto_scale_observations=[["OBS*"]]), ESSettings(), progress_callback=events.append, @@ -346,14 +337,9 @@ def test_update_raises_on_singular_matrix(tmp_path): standard deviation. 
""" gen_kw = GenKwConfig( - name="COEFFS", - forward_init=False, - update=True, - transform_function_definitions=[ - TransformFunctionDefinition( - name="coeff_0", param_name="CONST", values=["0.1"] - ), - ], + name="coeff_0", + group="COEFFS", + distribution={"name": "const", "value": 0.1}, ) # The values given here are chosen so that when computing # `ens_std = S.std(ddof=0, axis=1)`, ens_std[0] is not zero even though @@ -392,13 +378,15 @@ def test_update_raises_on_singular_matrix(tmp_path): ], ) prior = storage.create_ensemble(experiment.id, ensemble_size=2, name="prior") - for realization_nr in range(prior.ensemble_size): - ds = Ensemble.sample_parameter( + datasets = [ + Ensemble.sample_parameter( gen_kw, realization_nr, random_seed=1234, ) - prior.save_parameters("COEFFS", realization_nr, ds) + for realization_nr in range(prior.ensemble_size) + ] + prior.save_parameters(pl.concat(datasets, how="vertical")) for i, v in enumerate( [ @@ -441,7 +429,7 @@ def test_update_raises_on_singular_matrix(tmp_path): prior, posterior, experiment.observation_keys, - ["COEFFS"], + ["coeff_0"], ObservationSettings(auto_scale_observations=[["OBS*"]]), ESSettings(), rng=np.random.default_rng(1234), @@ -573,17 +561,16 @@ def test_smoother_snapshot_alpha( name="prior", ) rng = np.random.default_rng(1234) + dataset = [] for iens in range(prior_storage.ensemble_size): data = rng.uniform(0, 1) - prior_storage.save_parameters( - "PARAMETER", - iens, + dataset.append( pl.DataFrame( { "KEY_1": [data], "realization": iens, } - ), + ) ) data = rng.uniform(0.8, 1, 3) prior_storage.save_response( @@ -598,6 +585,8 @@ def test_smoother_snapshot_alpha( ), iens, ) + prior_storage.save_parameters(dataset=pl.concat(dataset, how="vertical")) + posterior_storage = storage.create_ensemble( prior_storage.experiment_id, ensemble_size=prior_storage.ensemble_size, @@ -611,7 +600,7 @@ def test_smoother_snapshot_alpha( prior_storage, posterior_storage, observations=["OBSERVATION"], - 
parameters=["PARAMETER"], + parameters=["KEY_1"], update_settings=ObservationSettings( outlier_settings=OutlierSettings(alpha=alpha) ), @@ -728,7 +717,7 @@ def test_temporary_parameter_storage_with_inactive_fields( ] for iens in range(ensemble_size): - prior_ensemble.save_parameters(param_group, iens, fields[iens]) + prior_ensemble.save_parameters(fields[iens], param_group, iens) realization_list = list(range(ensemble_size)) param_ensemble_array = prior_ensemble.load_parameters_numpy( @@ -1060,17 +1049,16 @@ def test_gen_data_obs_data_mismatch(storage, uniform_parameter): name="prior", ) rng = np.random.default_rng(1234) + dataset = [] for iens in range(prior.ensemble_size): data = rng.uniform(0, 1) - prior.save_parameters( - "PARAMETER", - iens, + dataset.append( pl.DataFrame( { "KEY_1": [data], "realization": iens, } - ), + ) ) data = rng.uniform(0.8, 1, 3) @@ -1086,6 +1074,8 @@ def test_gen_data_obs_data_mismatch(storage, uniform_parameter): ), iens, ) + + prior.save_parameters(dataset=pl.concat(dataset, how="vertical")) posterior_ens = storage.create_ensemble( prior.experiment_id, ensemble_size=prior.ensemble_size, @@ -1101,7 +1091,7 @@ def test_gen_data_obs_data_mismatch(storage, uniform_parameter): prior, posterior_ens, ["OBSERVATION"], - ["PARAMETER"], + ["KEY_1"], ObservationSettings(), ESSettings(), ) @@ -1122,17 +1112,16 @@ def test_gen_data_missing(storage, uniform_parameter, obs): name="prior", ) rng = np.random.default_rng(1234) + dataset = [] for iens in range(prior.ensemble_size): data = rng.uniform(0, 1) - prior.save_parameters( - "PARAMETER", - iens, + dataset.append( pl.DataFrame( { "KEY_1": [data], "realization": iens, } - ), + ) ) data = rng.uniform(0.8, 1, 2) # Importantly, shorter than obs prior.save_response( @@ -1147,6 +1136,7 @@ def test_gen_data_missing(storage, uniform_parameter, obs): ), iens, ) + prior.save_parameters(dataset=pl.concat(dataset, how="vertical")) posterior_ens = storage.create_ensemble( prior.experiment_id, 
ensemble_size=prior.ensemble_size, @@ -1160,7 +1150,7 @@ def test_gen_data_missing(storage, uniform_parameter, obs): prior, posterior_ens, ["OBSERVATION"], - ["PARAMETER"], + ["KEY_1"], ObservationSettings(), ESSettings(), progress_callback=events.append, @@ -1176,14 +1166,10 @@ def test_gen_data_missing(storage, uniform_parameter, obs): @pytest.mark.usefixtures("use_tmpdir") def test_update_subset_parameters(storage, uniform_parameter, obs): no_update_param = GenKwConfig( - name="EXTRA_PARAMETER", - forward_init=False, - transform_function_definitions=[ - TransformFunctionDefinition( - name="KEY1", param_name="UNIFORM", values=[0, 1] - ), - ], + name="KEY_2", + group="EXTRA_PARAMETER", update=False, + distribution={"name": "uniform", "min": 0, "max": 1}, ) resp = GenDataConfig(keys=["RESPONSE"]) experiment = storage.create_experiment( @@ -1198,27 +1184,25 @@ def test_update_subset_parameters(storage, uniform_parameter, obs): name="prior", ) rng = np.random.default_rng(1234) + dataset_key_1 = [] + dataset_key_2 = [] for iens in range(prior.ensemble_size): data = rng.uniform(0, 1) - prior.save_parameters( - "PARAMETER", - iens, + dataset_key_1.append( pl.DataFrame( { "KEY_1": [data], "realization": iens, } - ), + ) ) - prior.save_parameters( - "EXTRA_PARAMETER", - iens, + dataset_key_2.append( pl.DataFrame( { - "KEY_1": [data], + "KEY_2": [data], "realization": iens, } - ), + ) ) data = rng.uniform(0.8, 1, 10) @@ -1234,6 +1218,9 @@ def test_update_subset_parameters(storage, uniform_parameter, obs): ), iens, ) + + prior.save_parameters(dataset=pl.concat(dataset_key_1, how="vertical")) + prior.save_parameters(dataset=pl.concat(dataset_key_2, how="vertical")) posterior_ens = storage.create_ensemble( prior.experiment_id, ensemble_size=prior.ensemble_size, @@ -1245,7 +1232,7 @@ def test_update_subset_parameters(storage, uniform_parameter, obs): prior, posterior_ens, ["OBSERVATION"], - ["PARAMETER"], + ["KEY_1"], ObservationSettings(), ESSettings(), ) diff --git 
a/tests/ert/unit_tests/config/test_ensemble_config.py b/tests/ert/unit_tests/config/test_ensemble_config.py index 6cf9e66214d..6ad1dfccd56 100644 --- a/tests/ert/unit_tests/config/test_ensemble_config.py +++ b/tests/ert/unit_tests/config/test_ensemble_config.py @@ -122,7 +122,7 @@ def test_ensemble_config_duplicate_node_names(): duplicate_name, ("FAULT_TEMPLATE", ""), "MULTFLT.INC", - ("MULTFLT.TXT", "a UNIFORM 0 1"), + ("MULTFLT.TXT", f"{duplicate_name} UNIFORM 0 1"), {"FORWARD_INIT": "FALSE"}, ] ], diff --git a/tests/ert/unit_tests/config/test_ert_config.py b/tests/ert/unit_tests/config/test_ert_config.py index a491379038b..39ba44f19bb 100644 --- a/tests/ert/unit_tests/config/test_ert_config.py +++ b/tests/ert/unit_tests/config/test_ert_config.py @@ -2384,34 +2384,6 @@ def test_queue_options_are_joined_after_option_name(): ) -def test_validation_error_on_gen_kw_with_design_matrix_group_name(tmp_path): - design_matrix_file = tmp_path / "my_design_matrix.xlsx" - _create_design_matrix( - design_matrix_file, - pl.DataFrame( - { - "REAL": [0, 1], - "letters": ["x", "y"], - } - ), - pl.DataFrame([["a", 1], ["c", 2]], orient="row"), - ) - with open(tmp_path / "coeffs_priors", mode="w", encoding="utf-8") as fh: - fh.write("a CONST 0") - with pytest.raises( - ConfigValidationError, - match="Cannot have GEN_KW with group name DESIGN_MATRIX " - "when using DESIGN_MATRIX keyword\\.", - ): - ErtConfig.from_file_contents( - f"""\ - NUM_REALIZATIONS 1 - DESIGN_MATRIX {tmp_path}/my_design_matrix.xlsx - GEN_KW DESIGN_MATRIX {tmp_path}/coeffs_priors - """ - ) - - @pytest.mark.parametrize( "invalid_parameter_definition_name", [ diff --git a/tests/ert/unit_tests/config/test_gen_kw_config.py b/tests/ert/unit_tests/config/test_gen_kw_config.py index 3602c4c6034..f2c83b1ab64 100644 --- a/tests/ert/unit_tests/config/test_gen_kw_config.py +++ b/tests/ert/unit_tests/config/test_gen_kw_config.py @@ -6,64 +6,14 @@ import networkx as nx import pytest from lark import Token -from pydantic 
import ValidationError from ert.config import ConfigValidationError, ConfigWarning, ErtConfig, GenKwConfig -from ert.config.gen_kw_config import TransformFunctionDefinition from ert.config.parsing.file_context_token import FileContextToken from ert.run_models._create_run_path import create_run_path from ert.runpaths import Runpaths from ert.sample_prior import sample_prior -@pytest.mark.usefixtures("use_tmpdir") -def test_gen_kw_config(): - conf = GenKwConfig( - name="KEY", - forward_init=False, - transform_function_definitions=[ - TransformFunctionDefinition( - name="KEY1", param_name="UNIFORM", values=[0, 1] - ), - TransformFunctionDefinition( - name="KEY2", param_name="UNIFORM", values=[0, 1] - ), - TransformFunctionDefinition( - name="KEY3", param_name="UNIFORM", values=[0, 1] - ), - ], - update=True, - ) - assert len(conf.transform_functions) == 3 - - -@pytest.mark.usefixtures("use_tmpdir") -def test_gen_kw_config_duplicate_keys_raises(): - with pytest.raises( - ValidationError, - match="Duplicate GEN_KW keys 'KEY2' found, keys must be unique\\.", - ): - GenKwConfig( - name="KEY", - forward_init=False, - transform_function_definitions=[ - TransformFunctionDefinition( - name="KEY1", param_name="UNIFORM", values=[0, 1] - ), - TransformFunctionDefinition( - name="KEY2", param_name="UNIFORM", values=[0, 1] - ), - TransformFunctionDefinition( - name="KEY2", param_name="UNIFORM", values=[0, 1] - ), - TransformFunctionDefinition( - name="KEY3", param_name="UNIFORM", values=[0, 1] - ), - ], - update=True, - ) - - def test_short_definition_raises_config_error(tmp_path): parameter_file = tmp_path / "parameter.txt" parameter_file.write_text("incorrect", encoding="utf-8") @@ -78,106 +28,122 @@ def test_short_definition_raises_config_error(tmp_path): ) -def test_gen_kw_config_get_priors(): - conf = GenKwConfig( - name="KW_NAME", - forward_init=False, - transform_function_definitions=[ - TransformFunctionDefinition( - name="KEY1", param_name="NORMAL", values=["0", "1"] - ), 
- TransformFunctionDefinition( - name="KEY2", param_name="LOGNORMAL", values=["2", "3"] - ), - TransformFunctionDefinition( - name="KEY3", param_name="TRUNCATED_NORMAL", values=["4", "5", "6", "7"] - ), - TransformFunctionDefinition( - name="KEY4", param_name="TRIANGULAR", values=["0", "1", "2"] - ), - TransformFunctionDefinition( - name="KEY5", param_name="UNIFORM", values=["2", "3"] - ), - TransformFunctionDefinition( - name="KEY6", param_name="DUNIF", values=["3", "0", "1"] - ), - TransformFunctionDefinition( - name="KEY7", param_name="ERRF", values=["0", "1", "2", "3"] - ), - TransformFunctionDefinition( - name="KEY8", param_name="DERRF", values=["1", "1", "2", "3", "4"] - ), - TransformFunctionDefinition( - name="KEY9", param_name="LOGUNIF", values=["1", "2"] - ), - TransformFunctionDefinition( - name="KEY10", param_name="CONST", values=["10"] - ), - ], - update=True, - ) - priors = conf.get_priors() - assert len(conf.transform_functions) == 10 - - assert { - "key": "KEY1", - "function": "NORMAL", - "parameters": {"MEAN": 0, "STD": 1}, - } in priors - - assert { - "key": "KEY2", - "function": "LOGNORMAL", - "parameters": {"MEAN": 2, "STD": 3}, - } in priors - - assert { - "key": "KEY3", - "function": "TRUNCATED_NORMAL", - "parameters": {"MEAN": 4, "STD": 5, "MIN": 6, "MAX": 7}, - } in priors - - assert { - "key": "KEY4", - "function": "TRIANGULAR", - "parameters": {"MIN": 0, "MODE": 1, "MAX": 2}, - } in priors - - assert { - "key": "KEY5", - "function": "UNIFORM", - "parameters": {"MIN": 2, "MAX": 3}, - } in priors - - assert { - "key": "KEY6", - "function": "DUNIF", - "parameters": {"STEPS": 3, "MIN": 0, "MAX": 1}, - } in priors - - assert { - "key": "KEY7", - "function": "ERRF", - "parameters": {"MIN": 0, "MAX": 1, "SKEWNESS": 2, "WIDTH": 3}, - } in priors - - assert { - "key": "KEY8", - "function": "DERRF", - "parameters": {"STEPS": 1, "MIN": 1, "MAX": 2, "SKEWNESS": 3, "WIDTH": 4}, - } in priors - - assert { - "key": "KEY9", - "function": "LOGUNIF", - 
"parameters": {"MIN": 1, "MAX": 2}, - } in priors - - assert { - "key": "KEY10", - "function": "CONST", - "parameters": {"VALUE": 10}, - } in priors +@pytest.mark.parametrize( + "spec, expected", + [ + ( + {"name": "KEY1", "distribution": {"name": "normal", "mean": 0, "std": 1}}, + {"key": "KEY1", "function": "NORMAL", "parameters": {"MEAN": 0, "STD": 1}}, + ), + ( + { + "name": "KEY2", + "distribution": {"name": "lognormal", "mean": 2, "std": 3}, + }, + { + "key": "KEY2", + "function": "LOGNORMAL", + "parameters": {"MEAN": 2, "STD": 3}, + }, + ), + ( + { + "name": "KEY3", + "distribution": { + "name": "truncated_normal", + "mean": 4, + "std": 5, + "min": 6, + "max": 7, + }, + }, + { + "key": "KEY3", + "function": "TRUNCATED_NORMAL", + "parameters": {"MEAN": 4, "STD": 5, "MIN": 6, "MAX": 7}, + }, + ), + ( + { + "name": "KEY4", + "distribution": {"name": "triangular", "min": 0, "mode": 1, "max": 2}, + }, + { + "key": "KEY4", + "function": "TRIANGULAR", + "parameters": {"MIN": 0, "MODE": 1, "MAX": 2}, + }, + ), + ( + {"name": "KEY5", "distribution": {"name": "uniform", "min": 2, "max": 3}}, + {"key": "KEY5", "function": "UNIFORM", "parameters": {"MIN": 2, "MAX": 3}}, + ), + ( + { + "name": "KEY6", + "distribution": {"name": "dunif", "steps": 3, "min": 0, "max": 1}, + }, + { + "key": "KEY6", + "function": "DUNIF", + "parameters": {"STEPS": 3, "MIN": 0, "MAX": 1}, + }, + ), + ( + { + "name": "KEY7", + "distribution": { + "name": "errf", + "min": 0, + "max": 1, + "skewness": 2, + "width": 3, + }, + }, + { + "key": "KEY7", + "function": "ERRF", + "parameters": {"MIN": 0, "MAX": 1, "SKEWNESS": 2, "WIDTH": 3}, + }, + ), + ( + { + "name": "KEY8", + "distribution": { + "name": "derrf", + "steps": 1, + "min": 1, + "max": 2, + "skewness": 3, + "width": 4, + }, + }, + { + "key": "KEY8", + "function": "DERRF", + "parameters": { + "STEPS": 1, + "MIN": 1, + "MAX": 2, + "SKEWNESS": 3, + "WIDTH": 4, + }, + }, + ), + ( + {"name": "KEY9", "distribution": {"name": "logunif", "min": 1, 
"max": 2}}, + {"key": "KEY9", "function": "LOGUNIF", "parameters": {"MIN": 1, "MAX": 2}}, + ), + ( + {"name": "KEY10", "distribution": {"name": "const", "value": 10}}, + {"key": "KEY10", "function": "CONST", "parameters": {"VALUE": 10}}, + ), + ], + ids=[f"KEY{i}" for i in range(1, 11)], +) +def test_gen_kw_config_get_priors(spec, expected): + cfg = GenKwConfig(**spec) + assert expected in cfg.get_priors() number_regex = r"[-+]?(?:\d*\.\d+|\d+)" @@ -222,7 +188,7 @@ def test_gen_kw_is_log_or_not( ert_config = ErtConfig.from_file("config.ert") - gen_kw_config = ert_config.ensemble_config.parameter_configs["KW_NAME"] + gen_kw_config = ert_config.ensemble_config.parameter_configs["MY_KEYWORD"] assert isinstance(gen_kw_config, GenKwConfig) experiment_id = storage.create_experiment( parameters=ert_config.ensemble_config.parameter_configuration @@ -353,27 +319,19 @@ def test_gen_kw_distribution_errors(tmpdir, distribution, mean, std, error): ) def test_gen_kw_params_parsing(tmpdir, params, error): with tmpdir.as_cwd(): - ss = params.split() + parts = params.split() + name, dist_name, values = parts[0], parts[1], parts[2:] - tfd = TransformFunctionDefinition( - name=ss[0], - param_name=ss[1], - values=ss[2:], - ) if error: - with pytest.raises(ValidationError, match=error): - GenKwConfig( - name="MY_PARAM", - forward_init=False, - update=False, - transform_function_definitions=[tfd], - ) + with pytest.raises(ConfigValidationError, match=error): + GenKwConfig._parse_distribution(name, dist_name, values) else: + dist = GenKwConfig._parse_distribution(name, dist_name, values) GenKwConfig( - name="MY_PARAM", + name=name, forward_init=False, update=False, - transform_function_definitions=[tfd], + distribution=dist, ) @@ -430,26 +388,15 @@ def test_gen_kw_params_parsing(tmpdir, params, error): ], ) def test_gen_kw_trans_func(tmpdir, params, xinput, expected): - args = params.split()[2:] - float_args = [] - for a in args: - float_args.append(float(a)) - - tfd = 
TransformFunctionDefinition( - name=params.split()[0], - param_name=params.split()[1], - values=params.split()[2:], - ) - + name, dist_name, *values = params.split() with tmpdir.as_cwd(): - gkw = GenKwConfig( - name="MY_PARAM", + cfg = GenKwConfig( + name=name, forward_init=False, update=False, - transform_function_definitions=[tfd], + distribution=GenKwConfig._parse_distribution(name, dist_name, values), ) - tf = gkw.transform_functions[0] - assert abs(tf.distribution.transform(xinput) - expected) < 10**-15 + assert abs(cfg.distribution.transform(xinput) - expected) < 10**-15 def test_gen_kw_objects_equal(tmpdir): @@ -463,23 +410,19 @@ def test_gen_kw_objects_equal(tmpdir): ("prior.txt", "MY_KEYWORD UNIFORM 1 2"), {}, ] - ) - assert g1.transform_functions[0].name == "MY_KEYWORD" - - tfd = TransformFunctionDefinition( - name="MY_KEYWORD", param_name="UNIFORM", values=["1", "2"] - ) + )[0] + assert g1.name == "MY_KEYWORD" + assert g1.group == "KW_NAME" g2 = GenKwConfig( - name="KW_NAME", - forward_init=False, - transform_function_definitions=[tfd], - update=True, + name="MY_KEYWORD", + group="KW_NAME", + distribution={"name": "uniform", "min": 1, "max": 2}, ) + assert g1.name == g2.name - assert ( - g1.transform_function_definitions[0] == g2.transform_function_definitions[0] - ) + assert g1.group == g2.group + assert g1.distribution == g2.distribution @pytest.mark.usefixtures("use_tmpdir") @@ -789,41 +732,11 @@ def test_validation_derrf_distribution( GenKwConfig.from_config_list(config_list) -@pytest.mark.parametrize( - "transform_fns", - [ - [], - [{"name": "dummy", "param_name": "NORMAL", "values": [0, 0.1]}], - [ - {"name": f"dummy_{i}", "param_name": "NORMAL", "values": [0, 0.1]} - for i in range(100) - ], - ], -) -def test_that_transfer_function_names_are_reflected_as_parameter_keys(transform_fns): +def test_genkw_paramgraph_transformfn_node_correspondence(): config = GenKwConfig( - name="a_group", - forward_init=False, - update=True, - 
transform_function_definitions=[ - TransformFunctionDefinition(**tf) for tf in transform_fns - ], - ) - assert config.parameter_keys == [tf["name"] for tf in transform_fns] - - -@pytest.mark.parametrize("num_tfs", [0, 1, 3, 8]) -def test_genkw_paramgraph_transformfn_node_correspondence(num_tfs): - config = GenKwConfig( - name="COEFFS", - forward_init=True, - update=True, - transform_function_definitions=[ - TransformFunctionDefinition( - name=f"tf_{i}", param_name="UNIFORM", values=[0, 1] - ) - for i in range(num_tfs) - ], + name="param", + group="COEFFS", + distribution={"name": "uniform", "min": 1, "max": 2}, ) graph = config.load_parameter_graph() @@ -831,4 +744,4 @@ def test_genkw_paramgraph_transformfn_node_correspondence(num_tfs): data = nx.node_link_data(graph) assert data["links"] == [] - assert data["nodes"] == [{"id": i} for i in range(num_tfs)] + assert data["nodes"] == [{"id": 0}] diff --git a/tests/ert/unit_tests/config/test_surface_config.py b/tests/ert/unit_tests/config/test_surface_config.py index ad33a2a4ab2..e3158b0e791 100644 --- a/tests/ert/unit_tests/config/test_surface_config.py +++ b/tests/ert/unit_tests/config/test_surface_config.py @@ -51,7 +51,7 @@ def test_runpath_roundtrip(tmp_path, storage, surface): # run_path -> storage ds = config.read_from_runpath(tmp_path, 0, 0) - ensemble.save_parameters(config.name, 0, ds) + ensemble.save_parameters(ds, config.name, 0) # storage -> run_path config.forward_init_file = "output_%d" diff --git a/tests/ert/unit_tests/dark_storage/test_http_endpoints.py b/tests/ert/unit_tests/dark_storage/test_http_endpoints.py index 58e3eb6f50d..8191c268be2 100644 --- a/tests/ert/unit_tests/dark_storage/test_http_endpoints.py +++ b/tests/ert/unit_tests/dark_storage/test_http_endpoints.py @@ -169,26 +169,30 @@ def test_get_ensemble_parameters(poly_example_tmp_dir, dark_storage_client): experiment_json = resp.json()[0] assert experiment_json["parameters"] == { - "COEFFS": [ + "a": [ { "key": "COEFFS:a", 
"transformation": "UNIFORM", "dimensionality": 1, "userdata": {"data_origin": "GEN_KW"}, - }, + } + ], + "b": [ { "key": "COEFFS:b", "transformation": "UNIFORM", "dimensionality": 1, "userdata": {"data_origin": "GEN_KW"}, - }, + } + ], + "c": [ { "key": "COEFFS:c", "transformation": "UNIFORM", "dimensionality": 1, "userdata": {"data_origin": "GEN_KW"}, - }, - ] + } + ], } diff --git a/tests/ert/unit_tests/gui/ertwidgets/models/test_ertsummary.py b/tests/ert/unit_tests/gui/ertwidgets/models/test_ertsummary.py index bfa067830d5..3eca9396034 100644 --- a/tests/ert/unit_tests/gui/ertwidgets/models/test_ertsummary.py +++ b/tests/ert/unit_tests/gui/ertwidgets/models/test_ertsummary.py @@ -3,7 +3,6 @@ import pytest from ert.config import Field, GenKwConfig, SurfaceConfig -from ert.config.gen_kw_config import TransformFunctionDefinition from ert.field_utils import FieldFileFormat from ert.gui.ertwidgets.models.ertsummary import ErtSummary @@ -17,22 +16,20 @@ def mock_ert(monkeypatch): "forward_model_2", ] - gen_kw = GenKwConfig( - name="KEY", - forward_init=False, - transform_function_definitions=[ - TransformFunctionDefinition( - name="KEY1", param_name="UNIFORM", values=[0, 1] - ), - TransformFunctionDefinition( - name="KEY2", param_name="NORMAL", values=[0, 1] - ), - TransformFunctionDefinition( - name="KEY3", param_name="LOGNORMAL", values=[0, 1] - ), - ], - update=True, - ) + gen_kws = { + "KEY_1": GenKwConfig( + name="KEY_1", + distribution={"name": "uniform", "min": 0, "max": 1}, + ), + "KEY_2": GenKwConfig( + name="KEY_2", + distribution={"name": "normal", "mean": 0, "std": 1}, + ), + "KEY_3": GenKwConfig( + name="KEY_3", + distribution={"name": "lognormal", "mean": 0, "std": 1}, + ), + } surface = SurfaceConfig( name="some_name", @@ -70,9 +67,8 @@ def mock_ert(monkeypatch): ert_mock.ensemble_config.parameter_configs = { "surface": surface, - "gen_kw": gen_kw, "field": field, - } + } | gen_kws yield ert_mock @@ -84,7 +80,7 @@ def test_getForwardModels(mock_ert): 
def test_getParameters(mock_ert): - expected_list = ["field (10, 5, 3)", "gen_kw (3)", "surface (10, 7)"] + expected_list = ["DEFAULT (3)", "field (10, 5, 3)", "surface (10, 7)"] parameter_list, parameter_count = ErtSummary(mock_ert).getParameters() assert parameter_list == expected_list assert parameter_count == 223 diff --git a/tests/ert/unit_tests/gui/ertwidgets/test_ensembleselector.py b/tests/ert/unit_tests/gui/ertwidgets/test_ensembleselector.py index 910d764db8e..a8063b280d6 100644 --- a/tests/ert/unit_tests/gui/ertwidgets/test_ensembleselector.py +++ b/tests/ert/unit_tests/gui/ertwidgets/test_ensembleselector.py @@ -1,7 +1,6 @@ import pytest from ert.config import GenDataConfig, GenKwConfig -from ert.config.gen_kw_config import TransformFunctionDefinition from ert.gui.ertnotifier import ErtNotifier from ert.gui.ertwidgets.ensembleselector import EnsembleSelector from ert.storage.realization_storage_state import RealizationStorageState @@ -10,14 +9,8 @@ @pytest.fixture def uniform_parameter(): return GenKwConfig( - name="parameter", - forward_init=False, - transform_function_definitions=[ - TransformFunctionDefinition( - name="KEY1", param_name="UNIFORM", values=[0, 1] - ), - ], - update=True, + name="KEY_1", + distribution={"name": "uniform", "min": 0, "max": 1}, ) diff --git a/tests/ert/unit_tests/gui/tools/plot/test_plot_api.py b/tests/ert/unit_tests/gui/tools/plot/test_plot_api.py index ae714bddf76..8025c854979 100644 --- a/tests/ert/unit_tests/gui/tools/plot/test_plot_api.py +++ b/tests/ert/unit_tests/gui/tools/plot/test_plot_api.py @@ -9,12 +9,10 @@ import pandas as pd import polars as pl import pytest -import xarray as xr from pandas.testing import assert_frame_equal from starlette.testclient import TestClient from ert.config import GenKwConfig, SummaryConfig -from ert.config.gen_kw_config import TransformFunctionDefinition from ert.config.parameter_config import ParameterMetadata from ert.config.response_config import ResponseMetadata from 
ert.dark_storage import common @@ -279,14 +277,10 @@ def test_plot_api_handles_empty_gen_kw(api_and_storage): experiment = storage.create_experiment( parameters=[ GenKwConfig( - name=key, - forward_init=False, + name=name, + group=key, update=False, - transform_function_definitions=[ - TransformFunctionDefinition( - name=name, param_name="NORMAL", values=[0, 0.1] - ) - ], + distribution={"name": "normal", "mean": 0, "std": 0.1}, ), ], responses=[], @@ -295,8 +289,6 @@ def test_plot_api_handles_empty_gen_kw(api_and_storage): ensemble = storage.create_ensemble(experiment.id, ensemble_size=10) assert api.data_for_parameter(str(ensemble.id), key).empty ensemble.save_parameters( - key, - realization=None, dataset=pl.DataFrame( { name: [1.0], @@ -319,26 +311,15 @@ def test_plot_api_handles_non_existant_gen_kw(api_and_storage): experiment = storage.create_experiment( parameters=[ GenKwConfig( - name="gen_kw", - forward_init=False, - update=False, - transform_function_definitions=[], + name="KEY_1", + group="gen_kw", + distribution={"name": "normal", "mean": 0, "std": 1}, ), ], responses=[], observations={}, ) ensemble = storage.create_ensemble(experiment.id, ensemble_size=10) - ensemble.save_parameters( - "gen_kw", - 1, - xr.Dataset( - { - "values": ("names", [1.0]), - "names": ["key"], - } - ), - ) assert api.data_for_parameter(str(ensemble.id), "gen_kw").empty assert api.data_for_parameter(str(ensemble.id), "gen_kw:does_not_exist").empty @@ -348,14 +329,10 @@ def test_plot_api_handles_colons_in_parameter_keys(api_and_storage): experiment = storage.create_experiment( parameters=[ GenKwConfig( - name="group", - forward_init=False, + name="subgroup:1:2:2", + group="group", update=False, - transform_function_definitions=[ - TransformFunctionDefinition( - name="subgroup:1:2:2", param_name="RAW", values=[] - ), - ], + distribution={"name": "raw"}, ), ], responses=[], @@ -363,8 +340,6 @@ def test_plot_api_handles_colons_in_parameter_keys(api_and_storage): ) ensemble = 
storage.create_ensemble(experiment.id, ensemble_size=10) ensemble.save_parameters( - "group", - 0, pl.DataFrame( { "subgroup:1:2:2": pl.Series([10], dtype=pl.Float32), diff --git a/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_enif_matches_snapshot/heat_equationconfig.ert/config.json b/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_enif_matches_snapshot/heat_equationconfig.ert/config.json index 7b865ce2373..c86de6496f8 100644 --- a/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_enif_matches_snapshot/heat_equationconfig.ert/config.json +++ b/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_enif_matches_snapshot/heat_equationconfig.ert/config.json @@ -186,35 +186,29 @@ "parameter_configuration": [ { "type": "gen_kw", - "name": "INIT_TEMP_SCALE", + "name": "t", "forward_init": false, "update": true, - "transform_function_definitions": [ - { - "name": "t", - "param_name": "UNIFORM", - "values": [ - "0", - "1" - ] - } - ] + "distribution": { + "name": "uniform", + "min": 0.0, + "max": 1.0 + }, + "group": "INIT_TEMP_SCALE", + "input_source": "sampled" }, { "type": "gen_kw", - "name": "CORR_LENGTH", + "name": "x", "forward_init": false, "update": true, - "transform_function_definitions": [ - { - "name": "x", - "param_name": "NORMAL", - "values": [ - "0.8", - "0.1" - ] - } - ] + "distribution": { + "name": "normal", + "mean": 0.8, + "std": 0.1 + }, + "group": "CORR_LENGTH", + "input_source": "sampled" }, { "type": "field", diff --git a/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_enif_matches_snapshot/poly_examplepoly.ert/poly.json b/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_enif_matches_snapshot/poly_examplepoly.ert/poly.json index 0a9992c7c41..7c78783b5eb 100644 --- 
a/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_enif_matches_snapshot/poly_examplepoly.ert/poly.json +++ b/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_enif_matches_snapshot/poly_examplepoly.ert/poly.json @@ -176,35 +176,42 @@ "parameter_configuration": [ { "type": "gen_kw", - "name": "COEFFS", + "name": "a", "forward_init": false, "update": true, - "transform_function_definitions": [ - { - "name": "a", - "param_name": "UNIFORM", - "values": [ - "0", - "1" - ] - }, - { - "name": "b", - "param_name": "UNIFORM", - "values": [ - "0", - "2" - ] - }, - { - "name": "c", - "param_name": "UNIFORM", - "values": [ - "0", - "5" - ] - } - ] + "distribution": { + "name": "uniform", + "min": 0.0, + "max": 1.0 + }, + "group": "COEFFS", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "b", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 0.0, + "max": 2.0 + }, + "group": "COEFFS", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "c", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 0.0, + "max": 5.0 + }, + "group": "COEFFS", + "input_source": "sampled" } ], "response_configuration": [ diff --git a/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_enif_matches_snapshot/snake_oilsnake_oil.ert/snake_oil.json b/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_enif_matches_snapshot/snake_oilsnake_oil.ert/snake_oil.json index 5957efb46cf..9aae73172e1 100644 --- a/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_enif_matches_snapshot/snake_oilsnake_oil.ert/snake_oil.json +++ b/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_enif_matches_snapshot/snake_oilsnake_oil.ert/snake_oil.json @@ -148,91 +148,133 @@ 
"parameter_configuration": [ { "type": "gen_kw", - "name": "SNAKE_OIL_PARAM", + "name": "OP1_PERSISTENCE", "forward_init": false, "update": true, - "transform_function_definitions": [ - { - "name": "OP1_PERSISTENCE", - "param_name": "UNIFORM", - "values": [ - "0.01", - "0.4" - ] - }, - { - "name": "OP1_OCTAVES", - "param_name": "UNIFORM", - "values": [ - "3", - "5" - ] - }, - { - "name": "OP1_DIVERGENCE_SCALE", - "param_name": "UNIFORM", - "values": [ - "0.25", - "1.25" - ] - }, - { - "name": "OP1_OFFSET", - "param_name": "UNIFORM", - "values": [ - "-0.1", - "0.1" - ] - }, - { - "name": "OP2_PERSISTENCE", - "param_name": "UNIFORM", - "values": [ - "0.1", - "0.6" - ] - }, - { - "name": "OP2_OCTAVES", - "param_name": "UNIFORM", - "values": [ - "5", - "12" - ] - }, - { - "name": "OP2_DIVERGENCE_SCALE", - "param_name": "UNIFORM", - "values": [ - "0.5", - "1.5" - ] - }, - { - "name": "OP2_OFFSET", - "param_name": "UNIFORM", - "values": [ - "-0.2", - "0.2" - ] - }, - { - "name": "BPR_555_PERSISTENCE", - "param_name": "UNIFORM", - "values": [ - "0.1", - "0.5" - ] - }, - { - "name": "BPR_138_PERSISTENCE", - "param_name": "UNIFORM", - "values": [ - "0.2", - "0.7" - ] - } - ] + "distribution": { + "name": "uniform", + "min": 0.01, + "max": 0.4 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "OP1_OCTAVES", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 3.0, + "max": 5.0 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "OP1_DIVERGENCE_SCALE", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 0.25, + "max": 1.25 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "OP1_OFFSET", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": -0.1, + "max": 0.1 + }, + "group": "SNAKE_OIL_PARAM", + 
"input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "OP2_PERSISTENCE", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 0.1, + "max": 0.6 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "OP2_OCTAVES", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 5.0, + "max": 12.0 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "OP2_DIVERGENCE_SCALE", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 0.5, + "max": 1.5 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "OP2_OFFSET", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": -0.2, + "max": 0.2 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "BPR_555_PERSISTENCE", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 0.1, + "max": 0.5 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "BPR_138_PERSISTENCE", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 0.2, + "max": 0.7 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" } ], "response_configuration": [ diff --git a/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_ensemble_experiment_matches_snapshot/heat_equationconfig.ert/config.json b/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_ensemble_experiment_matches_snapshot/heat_equationconfig.ert/config.json index 3744ee159e7..4cab0c1c9f5 100644 --- 
a/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_ensemble_experiment_matches_snapshot/heat_equationconfig.ert/config.json +++ b/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_ensemble_experiment_matches_snapshot/heat_equationconfig.ert/config.json @@ -186,35 +186,29 @@ "parameter_configuration": [ { "type": "gen_kw", - "name": "INIT_TEMP_SCALE", + "name": "t", "forward_init": false, "update": true, - "transform_function_definitions": [ - { - "name": "t", - "param_name": "UNIFORM", - "values": [ - "0", - "1" - ] - } - ] + "distribution": { + "name": "uniform", + "min": 0.0, + "max": 1.0 + }, + "group": "INIT_TEMP_SCALE", + "input_source": "sampled" }, { "type": "gen_kw", - "name": "CORR_LENGTH", + "name": "x", "forward_init": false, "update": true, - "transform_function_definitions": [ - { - "name": "x", - "param_name": "NORMAL", - "values": [ - "0.8", - "0.1" - ] - } - ] + "distribution": { + "name": "normal", + "mean": 0.8, + "std": 0.1 + }, + "group": "CORR_LENGTH", + "input_source": "sampled" }, { "type": "field", diff --git a/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_ensemble_experiment_matches_snapshot/poly_examplepoly.ert/poly.json b/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_ensemble_experiment_matches_snapshot/poly_examplepoly.ert/poly.json index 5c85e89b624..d2c71313dbf 100644 --- a/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_ensemble_experiment_matches_snapshot/poly_examplepoly.ert/poly.json +++ b/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_ensemble_experiment_matches_snapshot/poly_examplepoly.ert/poly.json @@ -176,35 +176,42 @@ "parameter_configuration": [ { "type": "gen_kw", - "name": "COEFFS", + "name": "a", "forward_init": false, "update": true, - "transform_function_definitions": [ - { - 
"name": "a", - "param_name": "UNIFORM", - "values": [ - "0", - "1" - ] - }, - { - "name": "b", - "param_name": "UNIFORM", - "values": [ - "0", - "2" - ] - }, - { - "name": "c", - "param_name": "UNIFORM", - "values": [ - "0", - "5" - ] - } - ] + "distribution": { + "name": "uniform", + "min": 0.0, + "max": 1.0 + }, + "group": "COEFFS", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "b", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 0.0, + "max": 2.0 + }, + "group": "COEFFS", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "c", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 0.0, + "max": 5.0 + }, + "group": "COEFFS", + "input_source": "sampled" } ], "response_configuration": [ diff --git a/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_ensemble_experiment_matches_snapshot/snake_oilsnake_oil.ert/snake_oil.json b/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_ensemble_experiment_matches_snapshot/snake_oilsnake_oil.ert/snake_oil.json index f79f00c5aac..4ed4229579f 100644 --- a/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_ensemble_experiment_matches_snapshot/snake_oilsnake_oil.ert/snake_oil.json +++ b/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_ensemble_experiment_matches_snapshot/snake_oilsnake_oil.ert/snake_oil.json @@ -148,91 +148,133 @@ "parameter_configuration": [ { "type": "gen_kw", - "name": "SNAKE_OIL_PARAM", + "name": "OP1_PERSISTENCE", "forward_init": false, "update": true, - "transform_function_definitions": [ - { - "name": "OP1_PERSISTENCE", - "param_name": "UNIFORM", - "values": [ - "0.01", - "0.4" - ] - }, - { - "name": "OP1_OCTAVES", - "param_name": "UNIFORM", - "values": [ - "3", - "5" - ] - }, - { - "name": "OP1_DIVERGENCE_SCALE", - "param_name": 
"UNIFORM", - "values": [ - "0.25", - "1.25" - ] - }, - { - "name": "OP1_OFFSET", - "param_name": "UNIFORM", - "values": [ - "-0.1", - "0.1" - ] - }, - { - "name": "OP2_PERSISTENCE", - "param_name": "UNIFORM", - "values": [ - "0.1", - "0.6" - ] - }, - { - "name": "OP2_OCTAVES", - "param_name": "UNIFORM", - "values": [ - "5", - "12" - ] - }, - { - "name": "OP2_DIVERGENCE_SCALE", - "param_name": "UNIFORM", - "values": [ - "0.5", - "1.5" - ] - }, - { - "name": "OP2_OFFSET", - "param_name": "UNIFORM", - "values": [ - "-0.2", - "0.2" - ] - }, - { - "name": "BPR_555_PERSISTENCE", - "param_name": "UNIFORM", - "values": [ - "0.1", - "0.5" - ] - }, - { - "name": "BPR_138_PERSISTENCE", - "param_name": "UNIFORM", - "values": [ - "0.2", - "0.7" - ] - } - ] + "distribution": { + "name": "uniform", + "min": 0.01, + "max": 0.4 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "OP1_OCTAVES", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 3.0, + "max": 5.0 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "OP1_DIVERGENCE_SCALE", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 0.25, + "max": 1.25 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "OP1_OFFSET", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": -0.1, + "max": 0.1 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "OP2_PERSISTENCE", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 0.1, + "max": 0.6 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "OP2_OCTAVES", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 5.0, + "max": 12.0 + }, + "group": 
"SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "OP2_DIVERGENCE_SCALE", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 0.5, + "max": 1.5 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "OP2_OFFSET", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": -0.2, + "max": 0.2 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "BPR_555_PERSISTENCE", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 0.1, + "max": 0.5 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "BPR_138_PERSISTENCE", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 0.2, + "max": 0.7 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" } ], "response_configuration": [ diff --git a/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_ensemble_smoother_matches_snapshot/heat_equationconfig.ert/config.json b/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_ensemble_smoother_matches_snapshot/heat_equationconfig.ert/config.json index 7b865ce2373..c86de6496f8 100644 --- a/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_ensemble_smoother_matches_snapshot/heat_equationconfig.ert/config.json +++ b/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_ensemble_smoother_matches_snapshot/heat_equationconfig.ert/config.json @@ -186,35 +186,29 @@ "parameter_configuration": [ { "type": "gen_kw", - "name": "INIT_TEMP_SCALE", + "name": "t", "forward_init": false, "update": true, - "transform_function_definitions": [ - { - "name": "t", - "param_name": "UNIFORM", - "values": [ - "0", - "1" - ] - } - ] 
+ "distribution": { + "name": "uniform", + "min": 0.0, + "max": 1.0 + }, + "group": "INIT_TEMP_SCALE", + "input_source": "sampled" }, { "type": "gen_kw", - "name": "CORR_LENGTH", + "name": "x", "forward_init": false, "update": true, - "transform_function_definitions": [ - { - "name": "x", - "param_name": "NORMAL", - "values": [ - "0.8", - "0.1" - ] - } - ] + "distribution": { + "name": "normal", + "mean": 0.8, + "std": 0.1 + }, + "group": "CORR_LENGTH", + "input_source": "sampled" }, { "type": "field", diff --git a/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_ensemble_smoother_matches_snapshot/poly_examplepoly.ert/poly.json b/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_ensemble_smoother_matches_snapshot/poly_examplepoly.ert/poly.json index 0a9992c7c41..7c78783b5eb 100644 --- a/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_ensemble_smoother_matches_snapshot/poly_examplepoly.ert/poly.json +++ b/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_ensemble_smoother_matches_snapshot/poly_examplepoly.ert/poly.json @@ -176,35 +176,42 @@ "parameter_configuration": [ { "type": "gen_kw", - "name": "COEFFS", + "name": "a", "forward_init": false, "update": true, - "transform_function_definitions": [ - { - "name": "a", - "param_name": "UNIFORM", - "values": [ - "0", - "1" - ] - }, - { - "name": "b", - "param_name": "UNIFORM", - "values": [ - "0", - "2" - ] - }, - { - "name": "c", - "param_name": "UNIFORM", - "values": [ - "0", - "5" - ] - } - ] + "distribution": { + "name": "uniform", + "min": 0.0, + "max": 1.0 + }, + "group": "COEFFS", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "b", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 0.0, + "max": 2.0 + }, + "group": "COEFFS", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "c", 
+ "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 0.0, + "max": 5.0 + }, + "group": "COEFFS", + "input_source": "sampled" } ], "response_configuration": [ diff --git a/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_ensemble_smoother_matches_snapshot/snake_oilsnake_oil.ert/snake_oil.json b/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_ensemble_smoother_matches_snapshot/snake_oilsnake_oil.ert/snake_oil.json index 5957efb46cf..9aae73172e1 100644 --- a/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_ensemble_smoother_matches_snapshot/snake_oilsnake_oil.ert/snake_oil.json +++ b/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_ensemble_smoother_matches_snapshot/snake_oilsnake_oil.ert/snake_oil.json @@ -148,91 +148,133 @@ "parameter_configuration": [ { "type": "gen_kw", - "name": "SNAKE_OIL_PARAM", + "name": "OP1_PERSISTENCE", "forward_init": false, "update": true, - "transform_function_definitions": [ - { - "name": "OP1_PERSISTENCE", - "param_name": "UNIFORM", - "values": [ - "0.01", - "0.4" - ] - }, - { - "name": "OP1_OCTAVES", - "param_name": "UNIFORM", - "values": [ - "3", - "5" - ] - }, - { - "name": "OP1_DIVERGENCE_SCALE", - "param_name": "UNIFORM", - "values": [ - "0.25", - "1.25" - ] - }, - { - "name": "OP1_OFFSET", - "param_name": "UNIFORM", - "values": [ - "-0.1", - "0.1" - ] - }, - { - "name": "OP2_PERSISTENCE", - "param_name": "UNIFORM", - "values": [ - "0.1", - "0.6" - ] - }, - { - "name": "OP2_OCTAVES", - "param_name": "UNIFORM", - "values": [ - "5", - "12" - ] - }, - { - "name": "OP2_DIVERGENCE_SCALE", - "param_name": "UNIFORM", - "values": [ - "0.5", - "1.5" - ] - }, - { - "name": "OP2_OFFSET", - "param_name": "UNIFORM", - "values": [ - "-0.2", - "0.2" - ] - }, - { - "name": "BPR_555_PERSISTENCE", - "param_name": "UNIFORM", - "values": [ - "0.1", - 
"0.5" - ] - }, - { - "name": "BPR_138_PERSISTENCE", - "param_name": "UNIFORM", - "values": [ - "0.2", - "0.7" - ] - } - ] + "distribution": { + "name": "uniform", + "min": 0.01, + "max": 0.4 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "OP1_OCTAVES", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 3.0, + "max": 5.0 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "OP1_DIVERGENCE_SCALE", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 0.25, + "max": 1.25 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "OP1_OFFSET", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": -0.1, + "max": 0.1 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "OP2_PERSISTENCE", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 0.1, + "max": 0.6 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "OP2_OCTAVES", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 5.0, + "max": 12.0 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "OP2_DIVERGENCE_SCALE", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 0.5, + "max": 1.5 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "OP2_OFFSET", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": -0.2, + "max": 0.2 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "BPR_555_PERSISTENCE", + "forward_init": false, + "update": true, + "distribution": { 
+ "name": "uniform", + "min": 0.1, + "max": 0.5 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "BPR_138_PERSISTENCE", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 0.2, + "max": 0.7 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" } ], "response_configuration": [ diff --git a/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_esmda_matches_snapshot/heat_equationconfig.ert/config.json b/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_esmda_matches_snapshot/heat_equationconfig.ert/config.json index 84080e1dd82..c539d50937e 100644 --- a/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_esmda_matches_snapshot/heat_equationconfig.ert/config.json +++ b/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_esmda_matches_snapshot/heat_equationconfig.ert/config.json @@ -186,35 +186,29 @@ "parameter_configuration": [ { "type": "gen_kw", - "name": "INIT_TEMP_SCALE", + "name": "t", "forward_init": false, "update": true, - "transform_function_definitions": [ - { - "name": "t", - "param_name": "UNIFORM", - "values": [ - "0", - "1" - ] - } - ] + "distribution": { + "name": "uniform", + "min": 0.0, + "max": 1.0 + }, + "group": "INIT_TEMP_SCALE", + "input_source": "sampled" }, { "type": "gen_kw", - "name": "CORR_LENGTH", + "name": "x", "forward_init": false, "update": true, - "transform_function_definitions": [ - { - "name": "x", - "param_name": "NORMAL", - "values": [ - "0.8", - "0.1" - ] - } - ] + "distribution": { + "name": "normal", + "mean": 0.8, + "std": 0.1 + }, + "group": "CORR_LENGTH", + "input_source": "sampled" }, { "type": "field", diff --git a/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_esmda_matches_snapshot/poly_examplepoly.ert/poly.json 
b/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_esmda_matches_snapshot/poly_examplepoly.ert/poly.json index f44830769ae..38211c07166 100644 --- a/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_esmda_matches_snapshot/poly_examplepoly.ert/poly.json +++ b/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_esmda_matches_snapshot/poly_examplepoly.ert/poly.json @@ -176,35 +176,42 @@ "parameter_configuration": [ { "type": "gen_kw", - "name": "COEFFS", + "name": "a", "forward_init": false, "update": true, - "transform_function_definitions": [ - { - "name": "a", - "param_name": "UNIFORM", - "values": [ - "0", - "1" - ] - }, - { - "name": "b", - "param_name": "UNIFORM", - "values": [ - "0", - "2" - ] - }, - { - "name": "c", - "param_name": "UNIFORM", - "values": [ - "0", - "5" - ] - } - ] + "distribution": { + "name": "uniform", + "min": 0.0, + "max": 1.0 + }, + "group": "COEFFS", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "b", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 0.0, + "max": 2.0 + }, + "group": "COEFFS", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "c", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 0.0, + "max": 5.0 + }, + "group": "COEFFS", + "input_source": "sampled" } ], "response_configuration": [ diff --git a/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_esmda_matches_snapshot/snake_oilsnake_oil.ert/snake_oil.json b/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_esmda_matches_snapshot/snake_oilsnake_oil.ert/snake_oil.json index 4c85e8a5af4..6b9995add8c 100644 --- a/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_esmda_matches_snapshot/snake_oilsnake_oil.ert/snake_oil.json +++ 
b/tests/ert/unit_tests/run_models/snapshots/test_experiment_serialization/test_that_dumped_esmda_matches_snapshot/snake_oilsnake_oil.ert/snake_oil.json @@ -148,91 +148,133 @@ "parameter_configuration": [ { "type": "gen_kw", - "name": "SNAKE_OIL_PARAM", + "name": "OP1_PERSISTENCE", "forward_init": false, "update": true, - "transform_function_definitions": [ - { - "name": "OP1_PERSISTENCE", - "param_name": "UNIFORM", - "values": [ - "0.01", - "0.4" - ] - }, - { - "name": "OP1_OCTAVES", - "param_name": "UNIFORM", - "values": [ - "3", - "5" - ] - }, - { - "name": "OP1_DIVERGENCE_SCALE", - "param_name": "UNIFORM", - "values": [ - "0.25", - "1.25" - ] - }, - { - "name": "OP1_OFFSET", - "param_name": "UNIFORM", - "values": [ - "-0.1", - "0.1" - ] - }, - { - "name": "OP2_PERSISTENCE", - "param_name": "UNIFORM", - "values": [ - "0.1", - "0.6" - ] - }, - { - "name": "OP2_OCTAVES", - "param_name": "UNIFORM", - "values": [ - "5", - "12" - ] - }, - { - "name": "OP2_DIVERGENCE_SCALE", - "param_name": "UNIFORM", - "values": [ - "0.5", - "1.5" - ] - }, - { - "name": "OP2_OFFSET", - "param_name": "UNIFORM", - "values": [ - "-0.2", - "0.2" - ] - }, - { - "name": "BPR_555_PERSISTENCE", - "param_name": "UNIFORM", - "values": [ - "0.1", - "0.5" - ] - }, - { - "name": "BPR_138_PERSISTENCE", - "param_name": "UNIFORM", - "values": [ - "0.2", - "0.7" - ] - } - ] + "distribution": { + "name": "uniform", + "min": 0.01, + "max": 0.4 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "OP1_OCTAVES", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 3.0, + "max": 5.0 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "OP1_DIVERGENCE_SCALE", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 0.25, + "max": 1.25 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": 
"OP1_OFFSET", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": -0.1, + "max": 0.1 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "OP2_PERSISTENCE", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 0.1, + "max": 0.6 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "OP2_OCTAVES", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 5.0, + "max": 12.0 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "OP2_DIVERGENCE_SCALE", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 0.5, + "max": 1.5 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "OP2_OFFSET", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": -0.2, + "max": 0.2 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "BPR_555_PERSISTENCE", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 0.1, + "max": 0.5 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" + }, + { + "type": "gen_kw", + "name": "BPR_138_PERSISTENCE", + "forward_init": false, + "update": true, + "distribution": { + "name": "uniform", + "min": 0.2, + "max": 0.7 + }, + "group": "SNAKE_OIL_PARAM", + "input_source": "sampled" } ], "response_configuration": [ diff --git a/tests/ert/unit_tests/run_models/test_base_run_model.py b/tests/ert/unit_tests/run_models/test_base_run_model.py index 22d5301b446..9b44f2f1fce 100644 --- a/tests/ert/unit_tests/run_models/test_base_run_model.py +++ b/tests/ert/unit_tests/run_models/test_base_run_model.py @@ -12,13 +12,7 @@ import pytest from pydantic import ConfigDict -from ert.config import ( - 
ErtConfig, - GenKwConfig, - ModelConfig, - QueueConfig, -) -from ert.config.gen_kw_config import TransformFunctionDefinition +from ert.config import ErtConfig, GenKwConfig, ModelConfig, QueueConfig from ert.ensemble_evaluator import EndEvent, EvaluatorServerConfig from ert.ensemble_evaluator.snapshot import EnsembleSnapshot from ert.run_models.run_model import RunModel, UserCancelled @@ -605,17 +599,10 @@ def test_create_mask_from_failed_realizations_returns_initial_active_realization assert failed_realization_mask == initial_active_realizations +# TODO remove this test? def test_run_model_logs_number_of_parameters(use_tmpdir): - tfds = [ - TransformFunctionDefinition(name="a", param_name="NORMAL", values=[1, 2]), - TransformFunctionDefinition(name="b", param_name="NORMAL", values=[1, 2]), - TransformFunctionDefinition(name="c", param_name="NORMAL", values=[1, 2]), - TransformFunctionDefinition(name="d", param_name="NORMAL", values=[1, 2]), - TransformFunctionDefinition(name="e", param_name="NORMAL", values=[1, 2]), - ] - parameters = GenKwConfig( - transform_function_definitions=tfds, + distribution={"name": "normal", "mean": 0, "std": 1}, name="parameter_configuration", forward_init=False, update=True, @@ -628,7 +615,7 @@ def mock_logging(_, log_str): match = re.search(regex, log_str) num_param = int(match.group(1)) - assert num_param == len(tfds) + assert num_param == 1 with patch.object(Logger, "info", mock_logging): rm.log_at_startup() diff --git a/tests/ert/unit_tests/run_models/test_experiment_serialization.py b/tests/ert/unit_tests/run_models/test_experiment_serialization.py index 8c687cbe606..42659a288ef 100644 --- a/tests/ert/unit_tests/run_models/test_experiment_serialization.py +++ b/tests/ert/unit_tests/run_models/test_experiment_serialization.py @@ -27,6 +27,7 @@ Field, ForwardModelStep, GenDataConfig, + GenKwConfig, HookRuntime, ModelConfig, ObservationSettings, @@ -35,7 +36,6 @@ SurfaceConfig, Workflow, ) -from ert.config.gen_kw_config import 
GenKwConfig, TransformFunctionDefinition from ert.config.parsing import SchemaItemType from ert.config.queue_config import ( LocalQueueOptions, @@ -398,17 +398,18 @@ def multidass(_): } -transform_function_definitions = st.builds( - TransformFunctionDefinition, - param_name=st.just("NORMAL"), - values=st.just([0, 1]), +distribution_strategy = st.fixed_dictionaries( + { + "name": st.sampled_from(["normal", "lognormal"]), + "mean": st.floats(min_value=-100, max_value=100), + "std": st.floats(min_value=0.001, max_value=10), + } ) gen_kw_configs = st.builds( GenKwConfig, - transform_function_definitions=st.lists( - transform_function_definitions, unique_by=lambda tdf: tdf.name - ), + name=st.text(min_size=1, max_size=20), + distribution=distribution_strategy, ) diff --git a/tests/ert/unit_tests/sensitivity_analysis/test_design_matrix.py b/tests/ert/unit_tests/sensitivity_analysis/test_design_matrix.py index 57406aff9b5..732a814d54a 100644 --- a/tests/ert/unit_tests/sensitivity_analysis/test_design_matrix.py +++ b/tests/ert/unit_tests/sensitivity_analysis/test_design_matrix.py @@ -3,9 +3,8 @@ import pytest from xlsxwriter import Workbook -from ert.config import DesignMatrix, GenKwConfig +from ert.config import DataSource, DesignMatrix, GenKwConfig from ert.config.design_matrix import DESIGN_MATRIX_GROUP -from ert.config.gen_kw_config import TransformFunctionDefinition from tests.ert.conftest import _create_design_matrix @@ -95,7 +94,7 @@ def test_merge_multiple_occurrences( design_matrix_1.merge_with_other(design_matrix_2) else: design_matrix_1.merge_with_other(design_matrix_2) - design_params = design_matrix_1.parameter_configuration + design_params = [cfg.name for cfg in design_matrix_1.parameter_configurations] assert all(param in design_params for param in ("a", "b", "c", "d")) assert design_matrix_1.active_realizations == [True, True, True] df = design_matrix_1.design_matrix_df @@ -106,50 +105,60 @@ def test_merge_multiple_occurrences( @pytest.mark.parametrize( - 
"parameters, error_msg", + "parameters, num_configs, input_source, group_name", [ pytest.param( - {"COEFFS": ["a", "b"]}, - "", + ["a", "b"], + 2, + {"a": DataSource.DESIGN_MATRIX, "b": DataSource.DESIGN_MATRIX}, + {"a": "COEFFS", "b": "COEFFS"}, id="genkw_replaced", ), pytest.param( - {"COEFFS": ["a"]}, - "Overlapping parameter names found in design matrix!", - id="ValidationErrorOverlapping", - ), - pytest.param( - {"COEFFS": ["aa", "bb"], "COEFFS2": ["cc", "dd"]}, - "", - id="DESIGN_MATRIX_GROUP", + ["aa", "bb"], + 4, + { + "a": DataSource.DESIGN_MATRIX, + "b": DataSource.DESIGN_MATRIX, + "aa": DataSource.SAMPLED, + "bb": DataSource.SAMPLED, + }, + { + "a": DESIGN_MATRIX_GROUP, + "b": DESIGN_MATRIX_GROUP, + "aa": "COEFFS", + "bb": "COEFFS", + }, + id="genkw_added", ), pytest.param( - {"COEFFS": ["a", "b"], "COEFFS2": ["a", "b"]}, - ( - "Multiple overlapping groups with design matrix " - "found in existing parameters!" - ), - id="ValidationErrorMultipleGroups", + ["a", "bb"], + 3, + { + "a": DataSource.DESIGN_MATRIX, + "b": DataSource.DESIGN_MATRIX, + "bb": DataSource.SAMPLED, + }, + { + "a": "COEFFS", + "b": DESIGN_MATRIX_GROUP, + "bb": "COEFFS", + }, + id="genkw_added_and_replaced", ), ], ) -def test_read_and_merge_with_existing_parameters(tmp_path, parameters, error_msg): - extra_genkw_config = [] - if parameters: - for group_name in parameters: - extra_genkw_config.append( - GenKwConfig( - name=group_name, - forward_init=False, - transform_function_definitions=[ - TransformFunctionDefinition( - name=param, param_name="UNIFORM", values=[0, 1] - ) - for param in parameters[group_name] - ], - update=True, - ) - ) +def test_read_and_merge_with_existing_parameters( + tmp_path, parameters, num_configs, input_source, group_name +): + genkw_configs = [ + GenKwConfig( + name=param, + group="COEFFS", + distribution={"name": "uniform", "min": 0, "max": 1}, + ) + for param in parameters + ] realizations = [0, 1, 2] design_path = tmp_path / "design_matrix.xlsx" @@ -163,21 
+172,16 @@ def test_read_and_merge_with_existing_parameters(tmp_path, parameters, error_msg default_sheet_df = pl.DataFrame([["a", 1], ["b", 4]], orient="row") _create_design_matrix(design_path, design_matrix_df, default_sheet_df) design_matrix = DesignMatrix(design_path, "DesignSheet", "DefaultSheet") - if error_msg: - with pytest.raises(ValueError, match=error_msg): - design_matrix.merge_with_existing_parameters(extra_genkw_config) - elif len(parameters) == 1: - new_config_parameters, design_group = ( - design_matrix.merge_with_existing_parameters(extra_genkw_config) + new_config_parameters = design_matrix.merge_with_existing_parameters(genkw_configs) + assert len(new_config_parameters) == num_configs + for config in new_config_parameters: + assert config.name in input_source + assert config.input_source == input_source[config.name], ( + f"{config} mismatch in input source" ) - assert len(new_config_parameters) == 0 - assert design_group.name == "COEFFS" - elif len(parameters) == 2: - new_config_parameters, design_group = ( - design_matrix.merge_with_existing_parameters(extra_genkw_config) + assert config.group == group_name[config.name], ( + f"{config} mismatch in group name" ) - assert len(new_config_parameters) == 2 - assert design_group.name == DESIGN_MATRIX_GROUP def test_reading_design_matrix(tmp_path): @@ -196,7 +200,7 @@ def test_reading_design_matrix(tmp_path): ) _create_design_matrix(design_path, design_matrix_df, default_sheet_df) design_matrix = DesignMatrix(design_path, "DesignSheet", "DefaultSheet") - design_params = design_matrix.parameter_configuration + design_params = [cfg.name for cfg in design_matrix.parameter_configurations] assert all(param in design_params for param in ("a", "b", "c", "one", "d")) assert design_matrix.active_realizations == [True, True, False, False, True] diff --git a/tests/ert/unit_tests/storage/migration/test_to13.py b/tests/ert/unit_tests/storage/migration/test_to13.py new file mode 100644 index 00000000000..41b329933b4 
--- /dev/null +++ b/tests/ert/unit_tests/storage/migration/test_to13.py @@ -0,0 +1,48 @@ +from ert.storage.migration.to13 import migrate_gen_kw_param + + +def test_that_migrate_genkw_parameters_maps_tfds_to_single_param_instances(): + original_gen_kw = { + "COEFFS": { + "name": "COEFFS", + "forward_init": False, + "update": False, + "transform_function_definitions": [ + {"name": "a", "param_name": "UNIFORM", "values": ["0", "1"]}, + {"name": "b", "param_name": "RAW", "values": []}, + {"name": "c", "param_name": "LOGNORMAL", "values": ["0", "2"]}, + ], + "type": "gen_kw", + } + } + + migrated = migrate_gen_kw_param(original_gen_kw) + + assert set(migrated.keys()) == {"a", "b", "c"} + assert migrated["a"] == { + "name": "a", + "type": "gen_kw", + "group": "COEFFS", + "distribution": {"name": "uniform", "min": "0", "max": "1"}, + "forward_init": False, + "update": False, + "input_source": "sampled", + } + assert migrated["b"] == { + "name": "b", + "type": "gen_kw", + "group": "COEFFS", + "distribution": {"name": "raw"}, + "forward_init": False, + "update": False, + "input_source": "design_matrix", + } + assert migrated["c"] == { + "name": "c", + "type": "gen_kw", + "group": "COEFFS", + "distribution": {"name": "lognormal", "mean": "0", "std": "2"}, + "forward_init": False, + "update": False, + "input_source": "sampled", + } diff --git a/tests/ert/unit_tests/storage/migration/test_version_1.py b/tests/ert/unit_tests/storage/migration/test_version_1.py deleted file mode 100644 index 54dc3ec7d05..00000000000 --- a/tests/ert/unit_tests/storage/migration/test_version_1.py +++ /dev/null @@ -1,28 +0,0 @@ -import json - -import pytest - -from ert.config import ErtConfig -from ert.storage import open_storage -from ert.storage.local_storage import local_storage_set_ert_config - - -@pytest.fixture(scope="module", autouse=True) -def set_ert_config(block_storage_path): - ert_config = ErtConfig.from_file( - str(block_storage_path / "version-1/poly_example/poly.ert") - ) - yield 
local_storage_set_ert_config(ert_config) - local_storage_set_ert_config(None) - - -@pytest.mark.filterwarnings("ignore:.*The SIMULATION_JOB keyword has been removed") -def test_migrate_gen_kw(setup_case): - setup_case("block_storage/version-1/poly_example", "poly.ert") - with open_storage("storage", "w") as storage: - assert len(list(storage.experiments)) == 1 - experiment = next(iter(storage.experiments)) - param_info = json.loads( - (experiment._path / "parameter.json").read_text(encoding="utf-8") - ) - assert "COEFFS" in param_info diff --git a/tests/ert/unit_tests/storage/migration/test_version_2.py b/tests/ert/unit_tests/storage/migration/test_version_2.py deleted file mode 100644 index 288aebcdf7f..00000000000 --- a/tests/ert/unit_tests/storage/migration/test_version_2.py +++ /dev/null @@ -1,55 +0,0 @@ -import json - -import pytest - -from ert.config import ErtConfig -from ert.storage import open_storage -from ert.storage.local_storage import local_storage_set_ert_config - - -@pytest.fixture(scope="module", autouse=True) -def set_ert_config(block_storage_path): - ert_config = ErtConfig.from_file( - str(block_storage_path / "version-2/snake_oil/snake_oil.ert") - ) - yield local_storage_set_ert_config(ert_config) - local_storage_set_ert_config(None) - - -@pytest.mark.filterwarnings("ignore:Config contains a SUMMARY key") -@pytest.mark.filterwarnings("ignore:IES_ENKF has been removed and has no effect") -def test_migrate_responses(setup_case, set_ert_config): - ert_config = setup_case("block_storage/version-2/snake_oil", "snake_oil.ert") - with open_storage(ert_config.ens_path, "w") as storage: - assert len(list(storage.experiments)) == 1 - experiment = next(iter(storage.experiments)) - response_info = json.loads( - (experiment._path / "responses.json").read_text(encoding="utf-8") - ) - - response_config_exp = experiment.response_configuration - response_config_ens = ert_config.ensemble_config.response_configs - - # From storage v9 and onwards the response 
config is mutated - # when migrating an existing experiment, because we check that the - # keys in response.json are aligned with the dataset. - response_config_ens["summary"].has_finalized_keys = response_config_exp[ - "summary" - ].has_finalized_keys - response_config_ens["summary"].keys = response_config_exp["summary"].keys - - assert response_config_exp == response_config_ens - - assert set(response_info) == { - "gen_data", - "summary", - } - - -@pytest.mark.filterwarnings("ignore:Config contains a SUMMARY key") -@pytest.mark.filterwarnings("ignore:IES_ENKF has been removed and has no effect") -def test_migrate_gen_kw_config(setup_case, set_ert_config): - ert_config = setup_case("block_storage/version-2/snake_oil", "snake_oil.ert") - with open_storage(ert_config.ens_path, "w") as storage: - experiment = next(iter(storage.experiments)) - assert "template_file_path" not in experiment.parameter_configuration diff --git a/tests/ert/unit_tests/storage/migration/test_version_3.py b/tests/ert/unit_tests/storage/migration/test_version_3.py deleted file mode 100644 index 9af1b52000f..00000000000 --- a/tests/ert/unit_tests/storage/migration/test_version_3.py +++ /dev/null @@ -1,36 +0,0 @@ -import pytest - -from ert.config import ErtConfig -from ert.storage import open_storage -from ert.storage.local_storage import local_storage_set_ert_config - - -@pytest.fixture(scope="module", autouse=True) -def set_ert_config(block_storage_path): - ert_config = ErtConfig.from_file( - str(block_storage_path / "version-3/poly_example/poly.ert") - ) - yield local_storage_set_ert_config(ert_config) - local_storage_set_ert_config(None) - - -@pytest.mark.filterwarnings("ignore:.*The SIMULATION_JOB keyword has been removed") -def test_migrate_observations(setup_case, set_ert_config): - ert_config = setup_case("block_storage/version-3/poly_example", "poly.ert") - with open_storage(ert_config.ens_path, "w") as storage: - assert len(list(storage.experiments)) == 1 - experiment = 
next(iter(storage.experiments)) - - assert experiment.observations.keys() == ert_config.observations.keys() - assert all( - experiment.observations[k].equals(ert_config.observations[k]) - for k in experiment.observations - ) - - -@pytest.mark.filterwarnings("ignore:.*The SIMULATION_JOB keyword has been removed") -def test_migrate_gen_kw_config(setup_case, set_ert_config): - ert_config = setup_case("block_storage/version-3/poly_example", "poly.ert") - with open_storage(ert_config.ens_path, "w") as storage: - experiment = next(iter(storage.experiments)) - assert "template_file_path" not in experiment.parameter_configuration diff --git a/tests/ert/unit_tests/storage/snapshots/test_storage_migration/test_migration_to_genkw_with_polars_and_design_matrix/14.2/design_matrix_snapshot.json b/tests/ert/unit_tests/storage/snapshots/test_storage_migration/test_migration_to_genkw_with_polars_and_design_matrix/14.2/design_matrix_snapshot.json index 50bb087193b..86f7a0a769a 100644 --- a/tests/ert/unit_tests/storage/snapshots/test_storage_migration/test_migration_to_genkw_with_polars_and_design_matrix/14.2/design_matrix_snapshot.json +++ b/tests/ert/unit_tests/storage/snapshots/test_storage_migration/test_migration_to_genkw_with_polars_and_design_matrix/14.2/design_matrix_snapshot.json @@ -1,72 +1,72 @@ [ { + "realization": 0, "a": 0, "category": "cat1", "b": 1, - "c": 2, - "realization": 0 + "c": 2 }, { + "realization": 1, "a": 1, "category": "cat1", "b": 1, - "c": 2, - "realization": 1 + "c": 2 }, { + "realization": 2, "a": 2, "category": "cat1", "b": 1, - "c": 2, - "realization": 2 + "c": 2 }, { + "realization": 3, "a": 3, "category": "cat1", "b": 1, - "c": 2, - "realization": 3 + "c": 2 }, { + "realization": 4, "a": 4, "category": "cat1", "b": 1, - "c": 2, - "realization": 4 + "c": 2 }, { + "realization": 5, "a": 5, "category": "cat2", "b": 1, - "c": 2, - "realization": 5 + "c": 2 }, { + "realization": 6, "a": 6, "category": "cat2", "b": 1, - "c": 2, - "realization": 6 + "c": 
2 }, { + "realization": 7, "a": 7, "category": "cat2", "b": 1, - "c": 2, - "realization": 7 + "c": 2 }, { + "realization": 8, "a": 8, "category": "cat2", "b": 1, - "c": 2, - "realization": 8 + "c": 2 }, { + "realization": 9, "a": 9, "category": "cat2", "b": 1, - "c": 2, - "realization": 9 + "c": 2 } ] diff --git a/tests/ert/unit_tests/storage/snapshots/test_storage_migration/test_that_storage_matches/parameters b/tests/ert/unit_tests/storage/snapshots/test_storage_migration/test_that_storage_matches/parameters index a762b114f2b..78882b03c90 100644 --- a/tests/ert/unit_tests/storage/snapshots/test_storage_migration/test_that_storage_matches/parameters +++ b/tests/ert/unit_tests/storage/snapshots/test_storage_migration/test_that_storage_matches/parameters @@ -4,16 +4,13 @@ "name": "BPR", "forward_init": false, "update": true, - "transform_function_definitions": [ - { - "name": "BPR", - "param_name": "NORMAL", - "values": [ - "0", - "1" - ] - } - ] + "distribution": { + "name": "normal", + "mean": 0.0, + "std": 1.0 + }, + "group": "BPR", + "input_source": "sampled" }, "PORO": { "type": "field", diff --git a/tests/ert/unit_tests/storage/test_local_storage.py b/tests/ert/unit_tests/storage/test_local_storage.py index 9071d47ad8b..97bbc554021 100644 --- a/tests/ert/unit_tests/storage/test_local_storage.py +++ b/tests/ert/unit_tests/storage/test_local_storage.py @@ -35,7 +35,6 @@ SurfaceConfig, ) from ert.config.design_matrix import DESIGN_MATRIX_GROUP -from ert.config.gen_kw_config import TransformFunctionDefinition from ert.sample_prior import sample_prior from ert.storage import ( ErtStorageException, @@ -194,7 +193,7 @@ def test_that_saving_empty_parameters_fails_nicely(tmp_path): "must contain a 'values' variable" ), ): - prior.save_parameters("PARAMETER", 0, xr.Dataset()) + prior.save_parameters(xr.Dataset(), "PARAMETER", 0) # Test for dataset with 'values' and 'transformed_values' but no actual data empty_data = xr.Dataset( @@ -210,19 +209,13 @@ def 
test_that_saving_empty_parameters_fails_nicely(tmp_path): "Cannot proceed with saving to storage\\." ), ): - prior.save_parameters("PARAMETER", 0, empty_data) + prior.save_parameters(empty_data, "PARAMETER", 0) def test_that_loading_parameter_via_response_api_fails(tmp_path): uniform_parameter = GenKwConfig( - name="PARAMETER", - forward_init=False, - transform_function_definitions=[ - TransformFunctionDefinition( - name="KEY1", param_name="UNIFORM", values=[0, 1] - ), - ], - update=True, + name="KEY_1", + distribution={"name": "uniform", "min": 0, "max": 1}, ) with open_storage(tmp_path, mode="w") as storage: experiment = storage.create_experiment( @@ -236,17 +229,15 @@ def test_that_loading_parameter_via_response_api_fails(tmp_path): ) prior.save_parameters( - "PARAMETER", - 0, - xr.Dataset( + pl.DataFrame( { - "values": ("names", [1.0]), - "names": ["KEY_1"], + "realization": [0], + "KEY_1": [1.0], } - ), + ) ) - with pytest.raises(ValueError, match="PARAMETER is not a response"): - prior.load_responses("PARAMETER", (0,)) + with pytest.raises(ValueError, match="KEY_1 is not a response"): + prior.load_responses("KEY_1", (0,)) def test_that_load_responses_throws_exception(tmp_path): @@ -496,10 +487,10 @@ def _inner(params): st.one_of( st.builds( GenKwConfig, - name=words, + name=st.text(), + group_name=st.text(), update=st.booleans(), - forward_init=st.booleans(), - transform_function_definitions=st.just([]), + distribution=st.just({"name": "uniform", "min": 0, "max": 1}), ), st.builds(SurfaceConfig), ), @@ -867,7 +858,7 @@ def test_that_all_parameters_and_gen_data_consolidation_works( for batch, realization_info in enumerate(ensemble_realization_infos): failed_realizations = failed_realizations_per_batch.get(batch, {}) num_realizations = len(realization_info) - everest_realization_info = {i: v for i, v in enumerate(realization_info)} # noqa: C416 + everest_realization_info = dict(enumerate(realization_info)) ensemble = storage.create_ensemble( experiment, 
ensemble_size=num_realizations, iteration=batch ) @@ -884,7 +875,7 @@ def test_that_all_parameters_and_gen_data_consolidation_works( "names": param_keys, } ) - ensemble.save_parameters("point", realization, param_data) + ensemble.save_parameters(param_data, "point", realization) if realization in failed_realizations: ensemble.set_failure( @@ -945,9 +936,7 @@ def test_that_all_parameters_and_gen_data_consolidation_works( pytest.param([10, 11], True, id="incorrect_active_realizations"), ], ) -def test_save_parameters_to_storage_from_design_dataframe( - tmp_path, reals, expect_error -): +def test_sample_parameter_with_design_matrix(tmp_path, reals, expect_error): design_path = tmp_path / "design_matrix.xlsx" ensemble_size = 10 a_values = np.random.default_rng().uniform(-5, 5, 10) @@ -962,16 +951,32 @@ def test_save_parameters_to_storage_from_design_dataframe( design_matrix = DesignMatrix(design_path, "DesignSheet", "DefaultSheet") with open_storage(tmp_path / "storage", mode="w") as storage: experiment_id = storage.create_experiment( - parameters=[design_matrix.parameter_configuration] + parameters=list(design_matrix.parameter_configurations) ) ensemble = storage.create_ensemble( experiment_id, name="default", ensemble_size=ensemble_size ) if expect_error: with pytest.raises(KeyError): - design_matrix.save_to_ensemble(ensemble, reals) + sample_prior( + ensemble, + reals, + random_seed=123, + parameters=[ + param.name for param in design_matrix.parameter_configurations + ], + design_matrix_df=design_matrix.design_matrix_df, + ) else: - design_matrix.save_to_ensemble(ensemble, reals) + sample_prior( + ensemble, + reals, + random_seed=123, + parameters=[ + param.name for param in design_matrix.parameter_configurations + ], + design_matrix_df=design_matrix.design_matrix_df, + ) params = ensemble.load_parameters(DESIGN_MATRIX_GROUP).drop("realization") assert isinstance(params, pl.DataFrame) assert params.columns == ["a", "b", "c"] @@ -1229,13 +1234,13 @@ def 
save_field(self, model_ensemble: Ensemble, field_data): for f in fields: model_ensemble.parameter_values[f.name] = field_data storage_ensemble.save_parameters( - f.name, - 1, xr.DataArray( field_data, name="values", dims=["x", "y", "z"], # type: ignore ).to_dataset(), + f.name, + 1, ) @rule( @@ -1268,13 +1273,13 @@ def write_error_in_save_field(self, model_ensemble: Ensemble, field_data): pytest.raises(RuntimeError), ): storage_ensemble.save_parameters( - f.name, - self.iens_to_edit, xr.DataArray( field_data, name="values", dims=["x", "y", "z"], # type: ignore ).to_dataset(), + f.name, + self.iens_to_edit, ) assert temp_file.entered diff --git a/tests/ert/unit_tests/storage/test_parameter_sample_types.py b/tests/ert/unit_tests/storage/test_parameter_sample_types.py index e45ad47d66b..3141a5918af 100644 --- a/tests/ert/unit_tests/storage/test_parameter_sample_types.py +++ b/tests/ert/unit_tests/storage/test_parameter_sample_types.py @@ -11,7 +11,6 @@ from resdata.geometry import Surface from ert.config import ConfigValidationError, ErtConfig, GenKwConfig -from ert.config.gen_kw_config import TransformFunctionDefinition from ert.sample_prior import sample_prior from ert.storage import open_storage from ert.storage.local_ensemble import load_parameters_and_responses_from_runpath @@ -232,58 +231,62 @@ def test_that_first_three_parameters_sampled_snapshot(tmpdir, storage): [4, 5, 10], ) @pytest.mark.parametrize( - "template, prior", + "template, scalars", [ ( "MY_KEYWORD \nMY_SECOND_KEYWORD ", [ - TransformFunctionDefinition( - name="MY_KEYWORD", param_name="NORMAL", values=[0, 1] + GenKwConfig( + name="MY_KEYWORD", + group="KW_NAME", + distribution={"name": "normal", "mean": 0, "std": 1}, ), - TransformFunctionDefinition( - name="MY_SECOND_KEYWORD", param_name="NORMAL", values=[0, 1] + GenKwConfig( + name="MY_SECOND_KEYWORD", + group="KW_NAME", + distribution={"name": "normal", "mean": 0, "std": 1}, ), ], ), ( "MY_KEYWORD ", [ - TransformFunctionDefinition( - 
name="MY_KEYWORD", param_name="NORMAL", values=[0, 1] - ) + GenKwConfig( + name="MY_KEYWORD", + group="KW_NAME", + distribution={"name": "normal", "mean": 0, "std": 1}, + ), ], ), ( "MY_FIRST_KEYWORD \nMY_KEYWORD ", [ - TransformFunctionDefinition( - name="MY_FIRST_KEYWORD", param_name="NORMAL", values=[0, 1] + GenKwConfig( + name="MY_FIRST_KEYWORD", + group="KW_NAME", + distribution={"name": "normal", "mean": 0, "std": 1}, ), - TransformFunctionDefinition( - name="MY_KEYWORD", param_name="NORMAL", values=[0, 1] + GenKwConfig( + name="MY_KEYWORD", + group="KW_NAME", + distribution={"name": "normal", "mean": 0, "std": 1}, ), ], ), ], ) def test_that_sampling_is_fixed_from_name( - tmpdir, storage, template, prior, num_realisations + tmpdir, storage, template, scalars, num_realisations ): """ Testing that the order and number of parameters is not relevant for the values, only that name of the parameter and the global seed determine the values. """ with tmpdir.as_cwd(): - conf = GenKwConfig( - name="KW_NAME", - forward_init=False, - transform_function_definitions=prior, - update=True, - ) with open("template.txt", "w", encoding="utf-8") as fh: fh.writelines(template) fs = storage.create_ensemble( - storage.create_experiment(parameters=[conf]), + storage.create_experiment(parameters=scalars), name="prior", ensemble_size=num_realisations, ) @@ -441,13 +444,14 @@ def test_gen_kw(storage, tmpdir, config_str, expected, extra_files, expectation) ( [False, False], pytest.raises( - KeyError, match="No KW_NAME dataset in storage for ensemble default" + KeyError, match="No SCALAR dataset in storage for ensemble default" ), ), ( [False, True], pytest.raises( - IndexError, match=r"No matching realizations \[0\] found for KW_NAME" + IndexError, + match=r"No matching realizations \[0\] found for \['MY_KEYWORD'\]", ), ), ], diff --git a/tests/ert/unit_tests/storage/test_storage_migration.py b/tests/ert/unit_tests/storage/test_storage_migration.py index e3ee893bc09..50cb2eb1f73 100644 
--- a/tests/ert/unit_tests/storage/test_storage_migration.py +++ b/tests/ert/unit_tests/storage/test_storage_migration.py @@ -64,14 +64,16 @@ def test_migration_to_genkw_with_polars_and_design_matrix( ensemble = ensembles[0] df = ensemble.load_parameters("DESIGN_MATRIX") assert isinstance(df, pl.DataFrame) - assert df.schema == pl.Schema( - { - "a": pl.Int64, - "category": pl.String, - "b": pl.Int64, - "c": pl.Int64, - "realization": pl.Int64, - } + assert dict(df.schema) == dict( + pl.Schema( + { + "a": pl.Int64, + "category": pl.String, + "b": pl.Int64, + "c": pl.Int64, + "realization": pl.Int64, + } + ) ) snapshot.assert_match( orjson.dumps(df.to_dicts(), option=orjson.OPT_INDENT_2) @@ -137,18 +139,6 @@ def test_migration_to_genkw_with_polars_and_design_matrix( "6.0.2", "6.0.1", "6.0.0", - "5.0.12", - "5.0.11", - "5.0.10", - "5.0.9", - "5.0.8", - "5.0.7", - "5.0.6", - "5.0.5", - "5.0.4", - "5.0.2", - "5.0.1", - "5.0.0", ], ) def test_that_storage_matches( @@ -199,7 +189,7 @@ def test_that_storage_matches( assert experiment.templates_configuration == [("\nBPR:\n", "params.txt")] df = ensemble.load_parameters("BPR") assert isinstance(df, pl.DataFrame) - assert df.schema == pl.Schema({"BPR": pl.Float64, "realization": pl.Int64}) + assert dict(df.schema) == {"BPR": pl.Float64, "realization": pl.Int64} assert df["realization"].to_list() == list(range(ensemble.ensemble_size)) snapshot.assert_match( json.dumps( @@ -319,18 +309,6 @@ def test_that_storage_matches( "6.0.2", "6.0.1", "6.0.0", - "5.0.12", - "5.0.11", - "5.0.10", - "5.0.9", - "5.0.8", - "5.0.7", - "5.0.6", - "5.0.5", - "5.0.4", - "5.0.2", - "5.0.1", - "5.0.0", ], ) def test_that_storage_works_with_missing_parameters_and_responses( @@ -440,7 +418,6 @@ def test_that_migrate_blockfs_creates_backup_folder(tmp_path, caplog): "8.4.5", "8.0.11", "6.0.5", - "5.0.0", ], ) def test_that_manual_update_from_migrated_storage_works( @@ -579,16 +556,6 @@ def test_that_manual_update_from_migrated_storage_works( "6.0.3", 
"6.0.1", "6.0.0", - "5.0.11", - "5.0.9", - "5.0.8", - "5.0.7", - "5.0.6", - "5.0.5", - "5.0.4", - "5.0.2", - "5.0.1", - "5.0.0", ], ) def test_migrate_storage_with_no_responses( diff --git a/tests/ert/unit_tests/test_run_path_creation.py b/tests/ert/unit_tests/test_run_path_creation.py index d5e6bb14cb0..1fd831344f5 100644 --- a/tests/ert/unit_tests/test_run_path_creation.py +++ b/tests/ert/unit_tests/test_run_path_creation.py @@ -823,27 +823,28 @@ def test_that_ertcase_is_replaced_in_runpath(placeholder, make_run_path): def save_zeros(prior_ensemble, num_realizations, dim_size): parameter_configs = prior_ensemble.experiment.parameter_configuration for config_node in parameter_configs.values(): - for realization_nr in range(num_realizations): - if isinstance(config_node, SurfaceConfig): + if isinstance(config_node, SurfaceConfig): + for realization_nr in range(num_realizations): prior_ensemble.save_parameters_numpy( np.zeros(dim_size**2).reshape(-1, 1), config_node.name, np.array([realization_nr]), ) - elif isinstance(config_node, Field): + elif isinstance(config_node, Field): + for realization_nr in range(num_realizations): prior_ensemble.save_parameters_numpy( np.zeros(dim_size**3).reshape(-1, 1), config_node.name, np.array([realization_nr]), ) - elif isinstance(config_node, GenKwConfig): - prior_ensemble.save_parameters_numpy( - np.zeros(1).reshape(-1, 1), - config_node.name, - np.array([realization_nr]), - ) - else: - raise ValueError(f"unexpected {config_node}") + elif isinstance(config_node, GenKwConfig): + prior_ensemble.save_parameters_numpy( + np.zeros(2), + config_node.name, + np.array(range(num_realizations)), + ) + else: + raise ValueError(f"unexpected {config_node}") @pytest.mark.usefixtures("use_tmpdir") From 7092e04a0d3e47e031391e313a679b7c3de596d0 Mon Sep 17 00:00:00 2001 From: "Yngve S. 
Kristiansen" Date: Tue, 23 Sep 2025 09:23:24 +0200 Subject: [PATCH 2/2] fixup webviz-ert --- src/ert/shared/storage/extraction.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/ert/shared/storage/extraction.py b/src/ert/shared/storage/extraction.py index 8273f3e5650..2305ebb22a6 100644 --- a/src/ert/shared/storage/extraction.py +++ b/src/ert/shared/storage/extraction.py @@ -27,9 +27,12 @@ def create_priors( for param in experiment.parameter_configuration.values(): if isinstance(param, GenKwConfig): + dist_dict = param.distribution.model_dump(mode="json") + dist_dict.pop("name", None) + prior: dict[str, str | float] = { "function": _PRIOR_NAME_MAP[param.distribution.name.upper()], - } + } | dist_dict priors_dict[f"{param.group}:{param.name}"] = prior return priors_dict