Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/ert/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
ForwardModelStepWarning,
)
from .gen_data_config import GenDataConfig
from .gen_kw_config import GenKwConfig, PriorDict, TransformFunction
from .gen_kw_config import DataSource, GenKwConfig, PriorDict
from .lint_file import lint_file
from .model_config import ModelConfig
from .parameter_config import ParameterConfig, ParameterMetadata
from .parameter_config import ParameterCardinality, ParameterConfig, ParameterMetadata
from .parsing import (
ConfigValidationError,
ConfigWarning,
Expand Down Expand Up @@ -74,6 +74,7 @@
"AnalysisModule",
"ConfigValidationError",
"ConfigWarning",
"DataSource",
"DesignMatrix",
"ESSettings",
"EnsembleConfig",
Expand Down Expand Up @@ -109,6 +110,7 @@
"ObservationSettings",
"ObservationType",
"OutlierSettings",
"ParameterCardinality",
"ParameterConfig",
"ParameterMetadata",
"PostExperimentFixtures",
Expand All @@ -125,7 +127,6 @@
"ResponseMetadata",
"SummaryConfig",
"SurfaceConfig",
"TransformFunction",
"WarningInfo",
"Workflow",
"WorkflowConfigs",
Expand Down
110 changes: 34 additions & 76 deletions src/ert/config/design_matrix.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from collections import Counter
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, cast
Expand All @@ -10,12 +9,12 @@
import polars as pl
from polars.exceptions import InvalidOperationError

from .gen_kw_config import GenKwConfig, TransformFunctionDefinition
from .distribution import RawSettings
from .gen_kw_config import DataSource, GenKwConfig
from .parsing import ConfigValidationError, ErrorInfo

if TYPE_CHECKING:
from ert.config import ParameterConfig
from ert.storage import Ensemble


DESIGN_MATRIX_GROUP = "DESIGN_MATRIX"
Expand All @@ -32,7 +31,7 @@ def __post_init__(self) -> None:
(
self.active_realizations,
self.design_matrix_df,
self.parameter_configuration,
self.parameter_configurations,
) = self.read_and_validate_design_matrix()
except (ValueError, AttributeError) as exc:
raise ConfigValidationError.with_context(
Expand All @@ -42,26 +41,6 @@ def __post_init__(self) -> None:
str(self.xls_filename),
) from exc

def save_to_ensemble(
self,
ensemble: Ensemble,
active_realizations: Iterable[int],
design_group_name: str = DESIGN_MATRIX_GROUP,
) -> None:
design_matrix_df = self.design_matrix_df
assert not design_matrix_df.is_empty()
if not set(active_realizations) <= set(
design_matrix_df["realization"].to_list()
):
raise KeyError("Active realization mask is not in design matrix!")
ensemble.save_parameters(
design_group_name,
realization=None,
dataset=design_matrix_df.filter(
pl.col("realization").is_in(list(active_realizations))
),
)

@classmethod
def from_config_list(cls, config_list: list[str | dict[str, str]]) -> DesignMatrix:
filename = Path(cast(str, config_list[0]))
Expand Down Expand Up @@ -149,66 +128,44 @@ def merge_with_other(self, dm_other: DesignMatrix) -> None:
f"{dm_other.default_sheet or ''})': {exc}!"
) from exc

for tfd in dm_other.parameter_configuration.transform_function_definitions:
if tfd.name not in common_keys:
self.parameter_configuration.transform_function_definitions.append(tfd)
self.parameter_configurations.extend(
cfg
for cfg in dm_other.parameter_configurations
if cfg.name not in common_keys
)

def merge_with_existing_parameters(
    self, existing_parameters: list[ParameterConfig]
) -> list[ParameterConfig]:
    """
    Merge the design matrix parameters with the existing parameters and
    return the resulting list of parameters.

    Every existing GenKwConfig whose name also occurs among the design
    matrix parameters keeps its position in the list, but is modified in
    place: its input source is set to the design matrix, updating is turned
    off and its distribution is replaced by raw settings. Design matrix
    parameters that did not match any existing GenKwConfig are appended
    after the existing parameters, in the order they appear in
    ``self.parameter_configurations``.

    Args:
        existing_parameters (List[ParameterConfig]): List of existing parameters

    Returns:
        List[ParameterConfig]: List of new parameters after merge
    """
    # Design matrix parameters not (yet) matched by an existing parameter.
    unmatched_design_cfgs = {
        cfg.name: cfg for cfg in self.parameter_configurations
    }

    new_param_configs: list[ParameterConfig] = []
    for param_cfg in existing_parameters:
        if (
            isinstance(param_cfg, GenKwConfig)
            and param_cfg.name in unmatched_design_cfgs
        ):
            # The design matrix takes over this parameter: the existing
            # config object is reused, but redirected to the design matrix.
            param_cfg.input_source = DataSource.DESIGN_MATRIX
            param_cfg.update = False
            param_cfg.distribution = RawSettings()
            del unmatched_design_cfgs[param_cfg.name]
        new_param_configs.append(param_cfg)

    # Extending with an empty view is a no-op, so no emptiness guard is needed.
    new_param_configs.extend(unmatched_design_cfgs.values())
    return new_param_configs

def read_and_validate_design_matrix(
self,
) -> tuple[list[bool], pl.DataFrame, GenKwConfig]:
) -> tuple[list[bool], pl.DataFrame, list[GenKwConfig]]:
# Read the parameter names (first row) as strings to prevent polars from
# modifying them. This ensures that duplicate or empty column names are
# preserved exactly as they appear in the Excel sheet. By doing this, we
Expand Down Expand Up @@ -308,23 +265,24 @@ def read_and_validate_design_matrix(
design_matrix_df = design_matrix_df.with_row_index(name="realization")

design_matrix_df = convert_numeric_string_columns(design_matrix_df)
transform_function_definitions = [
TransformFunctionDefinition(name=col, param_name="RAW", values=[])

parameter_configurations: list[GenKwConfig] = [
GenKwConfig(
name=col,
update=False,
group=DESIGN_MATRIX_GROUP,
input_source=DataSource.DESIGN_MATRIX,
distribution={"name": "raw"},
)
for col in design_matrix_df.columns
if col != "realization"
]
parameter_configuration = GenKwConfig(
name=DESIGN_MATRIX_GROUP,
forward_init=False,
transform_function_definitions=transform_function_definitions,
update=False,
)

reals = design_matrix_df.get_column("realization").to_list()
return (
[x in reals for x in range(max(reals) + 1)],
design_matrix_df,
parameter_configuration,
parameter_configurations,
)

@staticmethod
Expand Down
14 changes: 14 additions & 0 deletions src/ert/config/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,20 @@ def transform(self, x: float) -> float:
return float(result)


# Union type alias over the individual distribution settings classes: a value
# annotated as DistributionSettings is an instance of exactly one of them.
# NOTE(review): member order is kept as written; it may be significant to
# whatever resolves the union (e.g. a validation library) - confirm before
# reordering.
DistributionSettings = (
UnifSettings
| LogNormalSettings
| LogUnifSettings
| DUnifSettings
| RawSettings
| ConstSettings
| NormalSettings
| TruncNormalSettings
| ErrfSettings
| DerrfSettings
| TriangularSettings
)

DISTRIBUTION_CLASSES: dict[str, type[TransSettingsValidation]] = {
"NORMAL": NormalSettings,
"LOGNORMAL": LogNormalSettings,
Expand Down
15 changes: 7 additions & 8 deletions src/ert/config/ensemble_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,6 @@ def set_derived_fields(self) -> Self:
[p.name for p in self.parameter_configs.values()],
[key for config in self.response_configs.values() for key in config.keys],
)
self._check_for_duplicate_gen_kw_param_names(
[p for p in self.parameter_configs.values() if isinstance(p, GenKwConfig)]
)

return self

Expand All @@ -69,9 +66,7 @@ def _check_for_duplicate_names(

@staticmethod
def _check_for_duplicate_gen_kw_param_names(gen_kw_list: list[GenKwConfig]) -> None:
gen_kw_param_count = Counter(
keyword.name for p in gen_kw_list for keyword in p.transform_functions
)
gen_kw_param_count = Counter(p.name for p in gen_kw_list)
duplicate_gen_kw_names = [
(n, c) for n, c in gen_kw_param_count.items() if c > 1
]
Expand Down Expand Up @@ -156,12 +151,16 @@ def make_field(field_list: list[str | dict[str, str]]) -> FieldConfig:

return FieldConfig.from_config_list(grid_file_path, dims, field_list)

gen_kw_cfgs = [
cfg for g in gen_kw_list for cfg in GenKwConfig.from_config_list(g)
]

parameter_configs = (
[GenKwConfig.from_config_list(g) for g in gen_kw_list]
gen_kw_cfgs
+ [SurfaceConfig.from_config_list(s) for s in surface_list]
+ [make_field(f) for f in field_list]
)

EnsembleConfig._check_for_duplicate_gen_kw_param_names(gen_kw_cfgs)
response_configs: list[KnownResponseTypes] = []

for config_cls in _KNOWN_RESPONSE_TYPES:
Expand Down
52 changes: 16 additions & 36 deletions src/ert/config/ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
make_observation_declarations,
)
from .analysis_config import AnalysisConfig
from .design_matrix import DESIGN_MATRIX_GROUP
from .ensemble_config import EnsembleConfig
from .forward_model_step import (
ForwardModelJSON,
Expand Down Expand Up @@ -759,10 +758,10 @@ def validate_genkw_parameter_name_overlap(self) -> Self:
def validate_dm_parameter_name_overlap(self) -> Self:
if not self.analysis_config.design_matrix:
return self
dm_param_config = self.analysis_config.design_matrix.parameter_configuration
dm_param_configs = self.analysis_config.design_matrix.parameter_configurations
overlapping_parameter_names = [
parameter_definition.name
for parameter_definition in dm_param_config.transform_function_definitions
for parameter_definition in dm_param_configs
if f"<{parameter_definition.name}>" in self.substitutions
or parameter_definition.name in ErtConfig.RESERVED_KEYWORDS
]
Expand Down Expand Up @@ -998,39 +997,20 @@ def from_dict(cls, config_dict: ConfigDict) -> Self:
raise ConfigValidationError.from_collected(errors)

if dm := analysis_config.design_matrix:
dm_errors = []
dm_params = {
x.name
for x in dm.parameter_configuration.transform_function_definitions
}
for group_name, config in ensemble_config.parameter_configs.items():
if not isinstance(config, GenKwConfig):
continue
group_params = {x.name for x in config.transform_function_definitions}
if group_name == DESIGN_MATRIX_GROUP:
dm_errors.append(
ConfigValidationError(
f"Cannot have GEN_KW with group name {DESIGN_MATRIX_GROUP} "
"when using DESIGN_MATRIX keyword."
)
)
if dm_params == group_params:
ConfigWarning.warn(
f"Parameters {group_params} from GEN_KW group '{group_name}' "
"will be overridden by design matrix. This will cause "
"updates to be turned off for these parameters."
)
elif intersection := dm_params & group_params:
dm_errors.append(
ConfigValidationError(
"Only full overlaps of design matrix and "
"one genkw group are supported.\n"
f"design matrix parameters: {dm_params}\n"
f"parameters in genkw group <{group_name}>: "
f"{group_params}\n"
f"overlap between them: {intersection}"
)
)
# Errors found while validating the design matrix against the ensemble config.
dm_errors: list[ErrorInfo | ConfigValidationError] = []
dm_params = {x.name for x in dm.parameter_configurations}
# Names of existing GEN_KW parameters that are also design matrix parameters.
overwrite_params = [
    cfg.name
    for cfg in ensemble_config.parameter_configs.values()
    if isinstance(cfg, GenKwConfig) and cfg.name in dm_params
]
if overwrite_params:
    # Name only the parameters that are actually overridden; the full set of
    # design matrix parameters would also list ones that collide with nothing.
    ConfigWarning.warn(
        f"Parameters {overwrite_params} "
        "will be overridden by design matrix. This will cause "
        "updates to be turned off for these parameters."
    )

if dm_errors:
raise ConfigValidationError.from_collected(dm_errors)

Expand Down
Loading