diff --git a/benchkit/benchmark.py b/benchkit/benchmark.py index 31876490..f4496eea 100644 --- a/benchkit/benchmark.py +++ b/benchkit/benchmark.py @@ -918,6 +918,24 @@ def _temp_record_data_dir(self, record_data_dir: pathlib.Path): # never an absolute path. return self._temp_record_prefix() / f"./{record_data_dir}" + def _update_pretty_variables(self, experiment_results: Dict[str, Any]): + if self._pretty_variables: + for var_name in self._pretty_variables: + ugly2pretty = self._pretty_variables[var_name] + ugly_var_value = experiment_results.get(var_name) + + if not isinstance(ugly2pretty, dict): + # If the pretty variable is not a dict, assume it's the pretty column name + experiment_results[ugly2pretty] = ugly_var_value + continue + + pretty_var_value = ugly2pretty.get(ugly_var_value, ugly_var_value) + experiment_results[f"{var_name}_pretty"] = f'"{pretty_var_value}"' + # If __category__ is defined, also create a column with that name + category = ugly2pretty.get("__category__") + if category is not None: + experiment_results[category] = f'"{pretty_var_value}"' + def _run_single_run( self, record_parameters: Dict[str, Any], @@ -960,22 +978,7 @@ def _run_single_run( experiment_results.update(run_variables) experiment_results.update(other_variables) - if self._pretty_variables: - for var_name in self._pretty_variables: - ugly2pretty = self._pretty_variables[var_name] - ugly_var_value = experiment_results.get(var_name) - - if not isinstance(ugly2pretty, dict): - # If the pretty variable is not a dict, assume it's the pretty column name - experiment_results[ugly2pretty] = ugly_var_value - continue - - pretty_var_value = ugly2pretty.get(ugly_var_value, ugly_var_value) - experiment_results[f"{var_name}_pretty"] = f'"{pretty_var_value}"' - # If __category__ is defined, also create a column with that name - category = ugly2pretty.get("__category__") - if category is not None: - experiment_results[category] = f'"{pretty_var_value}"' + self._update_pretty_variables(experiment_results=experiment_results) experiment_results.update({"rep": run_id})