From b3923453c667ab53a185afb52d4e4e8e87f219cd Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Tue, 22 Jul 2025 11:06:15 +0100 Subject: [PATCH 1/8] Plotting tweaks --- .../compressor/plotting/error_dist_plotter.py | 6 +- .../compressor/plotting/plot_metrics.py | 108 ++++++++++-------- 2 files changed, 62 insertions(+), 52 deletions(-) diff --git a/src/climatebenchpress/compressor/plotting/error_dist_plotter.py b/src/climatebenchpress/compressor/plotting/error_dist_plotter.py index ab79802..b32c77b 100644 --- a/src/climatebenchpress/compressor/plotting/error_dist_plotter.py +++ b/src/climatebenchpress/compressor/plotting/error_dist_plotter.py @@ -53,7 +53,6 @@ def plot_error_bound_histograms( # does not change the error plot distribution. Hence, we ignore the PCO # compressors here. compressors = [comp for comp in compressors if "-pco" not in comp] - for j, var in enumerate(variables): for comp in compressors: color, linestyle = get_line_info(comp) @@ -63,8 +62,11 @@ def plot_error_bound_histograms( label = "BitRound" elif label.startswith("StochRound"): label = "StochRound" + # Filter out inf values + error_data = self.errors[var][comp] + error_data = error_data[~np.isinf(error_data) & ~np.isnan(error_data)] self.axes[j, col_index].hist( - self.errors[var][comp], + np.float64(error_data), bins=100, density=True, histtype="step", diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index 0b0f468..905c288 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -46,6 +46,13 @@ def get_lineinfo(compressor: str) -> tuple[str, str]: ("tthresh", "TTHRESH"), ] +DISTORTION2LEGEND_NAME = { + "Relative MAE": "Mean Absolute Error", + "Relative DSSIM": "DSSIM", + "Relative MaxAbsError": "Max Absolute Error", + "Spectral Error": "Spectral Error", +} + def get_legend_name(compressor: str) -> str: """Get the legend name for a given compressor.""" @@ -85,6 +92,7 @@ def plot_metrics( compressed_datasets=compressed_datasets, plots_path=plots_path, all_results=df, + rd_curves_metrics=["Max Absolute Error", "MAE", "DSSIM", "Spectral Error"], ) df = rename_compressors(df) @@ -95,16 +103,23 @@ def plot_metrics( plot_throughput(df, plots_path / "throughput.pdf") plot_instruction_count(df, plots_path / "instruction_count.pdf") - for metric in ["Relative MAE", "Relative DSSIM", "Relative MaxAbsError"]: + for metric in [ + "Relative MAE", + "Relative DSSIM", + "Relative MaxAbsError", + "Spectral Error", + ]: with plt.rc_context(rc={"text.usetex": use_latex}): plot_aggregated_rd_curve( normalized_df, normalizer=normalizer, compression_metric="Relative CR", distortion_metric=metric, - outfile=plots_path / f"rd_curve_{metric.lower().replace(' ', '_')}.pdf", + outfile=plots_path + / f"rd_curve_{metric.lower().replace(' ', '_')}_exclude=ta_tos_pr_rlut.pdf", agg="median", bound_names=bound_names, + exclude_vars=["ta", "tos", "pr", "rlut"], ) @@ -192,6 +207,7 @@ def plot_per_variable_metrics( compressed_datasets: Path, plots_path: Path, all_results: pd.DataFrame, + rd_curves_metrics: list[str] = ["Max Absolute Error", "MAE"], ): """Creates all the plots which only depend on a single variable.""" for dataset in all_results["Dataset"].unique(): @@ -202,7 +218,7 @@ def plot_per_variable_metrics( # For each variable and compressor, plot the input, output, and error fields. variables = df["Variable"].unique() for var in variables: - for dist_metric in ["Max Absolute Error", "MAE"]: + for dist_metric in rd_curves_metrics: metric_name = dist_metric.lower().replace(" ", "_") if df[df["Variable"] == var][dist_metric].isnull().all(): continue @@ -361,12 +377,12 @@ def plot_aggregated_rd_curve( outfile: None | Path = None, agg="median", bound_names=["low", "mid", "high"], + exclude_vars=None, ): plt.figure(figsize=(8, 6)) - if distortion_metric == "DSSIM": - # For fields with large number of NaNs, the DSSIM values are unreliable - # which is why we exclude them here. - normalized_df = normalized_df[~normalized_df["Variable"].isin(["ta", "tos"])] + if exclude_vars: + # Exclude variables that are not relevant for the distortion metric. + normalized_df = normalized_df[~normalized_df["Variable"].isin(exclude_vars)] compressors = normalized_df["Compressor"].unique() agg_distortion = normalized_df.groupby(["Error Bound Name", "Compressor"])[ @@ -400,11 +416,6 @@ def plot_aggregated_rd_curve( plt.yscale("log") plt.ylabel(f"{agg.title()} {distortion_metric}", fontsize=14) - plt.legend( - title="Compressor", - fontsize=10, - title_fontsize=12, - ) plt.tick_params( axis="both", which="major", @@ -422,31 +433,32 @@ def plot_aggregated_rd_curve( top=True, right=True, ) - normalizer_label = get_legend_name(normalizer) - if "MAE" in distortion_metric: - plt.legend( - title="Compressor", - loc="upper right", - bbox_to_anchor=(0.95, 0.7), - fontsize=12, - title_fontsize=14, - ) - plt.xlabel( - rf"Median Compression Ratio Relative to {normalizer_label} ($\uparrow$)", - fontsize=16, - ) - plt.ylabel( - rf"Median Mean Absolute Error Relative to {normalizer_label} ($\downarrow$)", - fontsize=16, - ) - arrow_color = "black" - # Add an arrow pointing into the lower right corner + plt.xlabel( + rf"Median Compression Ratio Relative to {normalizer_label} ($\uparrow$)", + fontsize=16, + ) + metric_name = DISTORTION2LEGEND_NAME.get(distortion_metric, distortion_metric) + plt.ylabel( + rf"Median {metric_name} Relative to {normalizer_label} ($\downarrow$)", + fontsize=16, + ) + plt.legend( + title="Compressor", + loc="upper right", + bbox_to_anchor=(0.95, 0.7), + fontsize=12, + title_fontsize=14, + ) + + arrow_color = "black" + if "DSSIM" in distortion_metric: + # Add an arrow pointing into the top right corner plt.annotate( "", - xy=(0.95, 0.05), + xy=(0.95, 0.95), xycoords="axes fraction", - xytext=(-60, 50), + xytext=(-60, -50), textcoords="offset points", arrowprops=dict( arrowstyle="-|>, head_length=0.5, head_width=0.5", @@ -454,32 +466,25 @@ def plot_aggregated_rd_curve( lw=5, ), ) + # Attach the text to the lower left of the arrow plt.text( 0.83, - 0.08, + 0.92, "Better", transform=plt.gca().transAxes, fontsize=16, fontweight="bold", color=arrow_color, ha="center", + va="center", ) - elif "DSSIM" in distortion_metric: - plt.xlabel( - rf"Median Compression Ratio Relative to {normalizer_label} ($\uparrow$)", - fontsize=16, - ) - plt.ylabel( - rf"Median DSSIM to {normalizer_label} ($\downarrow$)", - fontsize=16, - ) - arrow_color = "black" - # Add an arrow pointing into the top right corner + else: + # Add an arrow pointing into the lower right corner plt.annotate( "", - xy=(0.95, 0.95), + xy=(0.95, 0.05), xycoords="axes fraction", - xytext=(-60, -50), + xytext=(-60, 50), textcoords="offset points", arrowprops=dict( arrowstyle="-|>, head_length=0.5, head_width=0.5", @@ -487,18 +492,21 @@ def plot_aggregated_rd_curve( lw=5, ), ) - # Attach the text to the lower left of the arrow plt.text( 0.83, - 0.92, + 0.08, "Better", transform=plt.gca().transAxes, fontsize=16, fontweight="bold", color=arrow_color, ha="center", - va="center", ) + if ( + "DSSIM" in distortion_metric + or "MaxAbsError" in distortion_metric + or "Spectral Error" in distortion_metric + ): plt.legend().remove() plt.tight_layout() From a95e26d55e2725b9bcc48f2575b4197e796507bc Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Thu, 24 Jul 2025 15:33:46 +0100 Subject: [PATCH 2/8] Ensure consistent color bar for sea surface temperature --- .../compressor/plotting/variable_plotters.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/climatebenchpress/compressor/plotting/variable_plotters.py b/src/climatebenchpress/compressor/plotting/variable_plotters.py index 8c76858..aaed144 100644 --- a/src/climatebenchpress/compressor/plotting/variable_plotters.py +++ b/src/climatebenchpress/compressor/plotting/variable_plotters.py @@ -58,13 +58,21 @@ class CmipOceanPlotter(Plotter): datasets = ["cmip6-access-tos-tiny", "cmip6-access-tos"] def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): + # Calculate shared vmin and vmax for consistent color ranges + data_orig = ds.isel(time=0).values.squeeze() + data_new = ds_new.isel(time=0).values.squeeze() + vmin = min(np.nanmin(data_orig), np.nanmin(data_new)) + vmax = max(np.nanmax(data_orig), np.nanmax(data_new)) + pcm0 = ax[0].pcolormesh( ds.longitude.values, ds.latitude.values, - ds.isel(time=0).values.squeeze(), + data_orig, transform=ccrs.PlateCarree(), shading="auto", cmap="coolwarm", + vmin=vmin, + vmax=vmax, rasterized=True, ) fig.colorbar( @@ -74,10 +82,12 @@ def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): pcm1 = ax[1].pcolormesh( ds_new.longitude.values, ds_new.latitude.values, - ds_new.isel(time=0).values.squeeze(), + data_new, transform=ccrs.PlateCarree(), shading="auto", cmap="coolwarm", + vmin=vmin, + vmax=vmax, rasterized=True, ) fig.colorbar( From 506619fd724eb395b9550c89cf82e179b8659982 Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Thu, 31 Jul 2025 15:00:17 +0100 Subject: [PATCH 3/8] Improve compressor plots --- .../compressor/plotting/plot_metrics.py | 10 +- .../compressor/plotting/variable_plotters.py | 272 +++++++++++++----- 2 files changed, 217 insertions(+), 65 deletions(-) diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index 905c288..2b7d7e8 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -272,6 +272,7 @@ def plot_per_variable_metrics( dataset, comp, var, + error_bound_vals[var], outfile=err_bound_path / f"{var}_{comp}.png", ) @@ -295,6 +296,7 @@ def plot_variable_error( dataset_name: str, compressor: str, var: str, + err_bound: tuple[str, float], outfile: None | Path = None, ): if outfile is not None and outfile.exists(): @@ -304,7 +306,13 @@ def plot_variable_error( plotter = PLOTTERS.get(dataset_name, None) if plotter: plotter().plot( - uncompressed_data, compressed_data, dataset_name, compressor, var, outfile + uncompressed_data, + compressed_data, + dataset_name, + compressor, + var, + err_bound, + outfile, ) else: print(f"No plotter found for dataset {dataset_name}") diff --git a/src/climatebenchpress/compressor/plotting/variable_plotters.py b/src/climatebenchpress/compressor/plotting/variable_plotters.py index aaed144..6fb4f3c 100644 --- a/src/climatebenchpress/compressor/plotting/variable_plotters.py +++ b/src/climatebenchpress/compressor/plotting/variable_plotters.py @@ -5,6 +5,8 @@ import matplotlib.colors as mcolors import matplotlib.pyplot as plt import numpy as np +import xarray as xr +import xarray.plot.utils as xplot_utils class Plotter(ABC): @@ -12,13 +14,21 @@ class Plotter(ABC): def __init__(self): self.projection = ccrs.Robinson() + self.error_title = "Error" @abstractmethod - def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var, err_bound): pass def plot( - self, ds, ds_new, dataset_name, compressor, var, outfile: None | Path = None + self, + ds, + ds_new, + dataset_name, + compressor, + var, + err_bound, + outfile: None | Path = None, ): fig, ax = plt.subplots( nrows=1, @@ -26,14 +36,14 @@ def plot( figsize=(20, 7), subplot_kw={"projection": self.projection}, ) - self.plot_fields(fig, ax, ds, ds_new, dataset_name, var) + self.plot_fields(fig, ax, ds, ds_new, dataset_name, var, err_bound) ax[0].coastlines() ax[1].coastlines() ax[2].coastlines() ax[0].set_title("Original Dataset") ax[1].set_title("Compressed Dataset") - ax[2].set_title("Error") - fig.suptitle(f"{var} Error for {dataset_name} ({compressor})") + ax[2].set_title(self.error_title) + # fig.suptitle(f"{var} Error for {dataset_name} ({compressor})") fig.tight_layout() if outfile is not None: with outfile.open("wb") as f: @@ -44,77 +54,109 @@ def plot( class CmipAtmosPlotter(Plotter): datasets = ["cmip6-access-ta-tiny", "cmip6-access-ta"] - def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var, err_bound): selector = dict(time=0, plev=3) - ds.isel(**selector).plot(ax=ax[0], transform=ccrs.PlateCarree()) - ds_new.isel(**selector).plot( - ax=ax[1], transform=ccrs.PlateCarree(), robust=True + # Calculate shared vmin and vmax for consistent color ranges + data_orig = ds.isel(**selector) + data_new = ds_new.isel(**selector) + vmin = np.nanmin(data_orig.values.squeeze()) + vmax = np.nanmax(data_orig.values.squeeze()) + + data_orig.plot(ax=ax[0], transform=ccrs.PlateCarree(), vmin=vmin, vmax=vmax) + data_new.plot( + ax=ax[1], transform=ccrs.PlateCarree(), robust=True, vmin=vmin, vmax=vmax ) - error = ds.isel(**selector) - ds_new.isel(**selector) - error.plot(ax=ax[2], transform=ccrs.PlateCarree(), rasterized=True) + error = data_orig - data_new + error.attrs["long_name"] = data_orig.attrs.get("long_name", "") + error.attrs["units"] = data_orig.attrs.get("units", "") + + _, bound_value = err_bound + vmin_error, vmax_error = -bound_value, bound_value + error.plot( + ax=ax[2], + transform=ccrs.PlateCarree(), + rasterized=True, + vmin=vmin_error, + vmax=vmax_error, + cbar_kwargs={"ticks": [-bound_value, 0, bound_value]}, + cmap="seismic", + ) + self.error_title = "Absolute Error" class CmipOceanPlotter(Plotter): datasets = ["cmip6-access-tos-tiny", "cmip6-access-tos"] - def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var, err_bound): # Calculate shared vmin and vmax for consistent color ranges data_orig = ds.isel(time=0).values.squeeze() - data_new = ds_new.isel(time=0).values.squeeze() - vmin = min(np.nanmin(data_orig), np.nanmin(data_new)) - vmax = max(np.nanmax(data_orig), np.nanmax(data_new)) - - pcm0 = ax[0].pcolormesh( - ds.longitude.values, - ds.latitude.values, - data_orig, + vmin, vmax = np.nanmin(data_orig), np.nanmax(data_orig) + + ds.isel(time=0).plot( + ax=ax[0], + x="longitude", + y="latitude", transform=ccrs.PlateCarree(), - shading="auto", cmap="coolwarm", vmin=vmin, vmax=vmax, rasterized=True, + cbar_kwargs={ + "orientation": "vertical", + "fraction": 0.046, + "pad": 0.04, + "label": "degC", + }, ) - fig.colorbar( - pcm0, ax=ax[0], orientation="vertical", fraction=0.046, pad=0.04 - ).set_label("degC") - pcm1 = ax[1].pcolormesh( - ds_new.longitude.values, - ds_new.latitude.values, - data_new, + ds_new.isel(time=0).plot( + ax=ax[1], + x="longitude", + y="latitude", transform=ccrs.PlateCarree(), - shading="auto", cmap="coolwarm", vmin=vmin, vmax=vmax, rasterized=True, + cbar_kwargs={ + "orientation": "vertical", + "fraction": 0.046, + "pad": 0.04, + "label": "degC", + }, ) - fig.colorbar( - pcm1, ax=ax[1], orientation="vertical", fraction=0.046, pad=0.04 - ).set_label("degC") error = ds.isel(time=0) - ds_new.isel(time=0) - pcm2 = ax[2].pcolormesh( - ds.longitude.values, - ds.latitude.values, - error.values.squeeze(), + _, bound_value = err_bound + vmin_error, vmax_error = -bound_value, bound_value + error.plot( + ax=ax[2], + x="longitude", + y="latitude", transform=ccrs.PlateCarree(), - shading="auto", - cmap="coolwarm", + vmin=vmin_error, + vmax=vmax_error, + cmap="seismic", rasterized=True, + cbar_kwargs={ + "orientation": "vertical", + "fraction": 0.046, + "pad": 0.04, + "label": "degC", + "ticks": [-bound_value, 0, bound_value], + }, ) - fig.colorbar( - pcm2, ax=ax[2], orientation="vertical", fraction=0.046, pad=0.04 - ).set_label("degC") + self.error_title = "Absolute Error" class Era5Plotter(Plotter): datasets = ["era5-tiny", "era5"] - def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var, err_bound): selector = dict(time=0) - error = ds.isel(**selector) - ds_new.isel(**selector) + # Calculate shared vmin and vmax for consistent color ranges + data_orig = ds.isel(**selector).values.squeeze() + vmin, vmax = np.nanmin(data_orig), np.nanmax(data_orig) # Instead of using the inbuilt xarray plot method, we are manually doing # the projection and calling pcolormesh. By doing so we can avoid having @@ -128,25 +170,57 @@ def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): # Wind variable plots coolwarm because they lie around 0 and change in sign # signifies change in wind direction. cmap = "coolwarm" if var.startswith("10m") else "viridis" - c1 = ax[0].pcolormesh(x, y, ds.isel(**selector).values.squeeze(), cmap=cmap) + c1 = ax[0].pcolormesh( + x, y, ds.isel(**selector).values.squeeze(), cmap=cmap, vmin=vmin, vmax=vmax + ) c2 = ax[1].pcolormesh( x, y, ds_new.isel(**selector).values.squeeze(), cmap=cmap, rasterized=True, + vmin=vmin, + vmax=vmax, ) - c3 = ax[2].pcolormesh(x, y, error.values.squeeze(), cmap="coolwarm") + + error = ds.isel(**selector) - ds_new.isel(**selector) + error.attrs["long_name"] = ds.isel(**selector).attrs.get("long_name", "") + error.attrs["units"] = ds.isel(**selector).attrs.get("units", "") + + _, bound_value = err_bound + c3 = ax[2].pcolormesh( + x, + y, + error.values.squeeze(), + cmap="seismic", + vmin=-bound_value, + vmax=bound_value, + ) + self.error_title = "Absolute Error" for i, c in enumerate([c1, c2, c3]): - fig.colorbar(c, ax=ax[i], shrink=0.6) + if i == 0: + extend = xplot_utils._determine_extend(ds.isel(**selector), vmin, vmax) + label = xplot_utils.label_from_attrs(ds) + elif i == 1: + extend = xplot_utils._determine_extend( + ds_new.isel(**selector), vmin, vmax + ) + label = xplot_utils.label_from_attrs(ds_new) + else: + extend = xplot_utils._determine_extend(error, -bound_value, bound_value) + label = xplot_utils.label_from_attrs(error) + cbar = fig.colorbar(c, ax=ax[i], shrink=0.6, extend=extend, label=label) + if i == 2: + cbar.ax.set_yticks([-bound_value, 0, bound_value]) class NextGEMSPlotter(Plotter): datasets = ["nextgems-icon-tiny", "nextgems-icon"] - def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var, err_bound): selector = dict(time=0) - error = ds.isel(**selector) - ds_new.isel(**selector) + data_orig = ds.isel(**selector).values.squeeze() + vmin, vmax = np.nanmin(data_orig), np.nanmax(data_orig) lons = ds.isel(**selector).lon.values lats = ds.isel(**selector).lat.values @@ -155,13 +229,13 @@ def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): x, y = xys[..., 0], xys[..., 1] cmap = "Blues" - max_val = max( - ds.isel(**selector).max().values.item(), - ds_new.isel(**selector).max().values.item(), - ) - color_norm = mcolors.LogNorm(vmin=1e-12, vmax=max_val) if var == "pr" else None # Avoid zero values for log transformation for precipitation offset = 1e-12 if var == "pr" else 0 + color_norm = ( + mcolors.LogNorm(vmin=1e-12, vmax=vmax + offset) + if var == "pr" + else mcolors.Normalize(vmin=vmin + offset, vmax=vmax + offset) + ) c1 = ax[0].pcolormesh( x, y, @@ -178,15 +252,49 @@ def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): cmap=cmap, rasterized=True, ) - c3 = ax[2].pcolormesh(x, y, error.values.squeeze(), cmap="coolwarm") + + bound_type, bound_value = err_bound + error = ds.isel(**selector) - ds_new.isel(**selector) + self.error_title = "Absolute Error" + if bound_type == "rel_error": + error = error / np.abs(ds.isel(**selector)) + self.error_title = "Relative Error" + + c3 = ax[2].pcolormesh( + x, + y, + error.values.squeeze(), + cmap="seismic", + vmin=-bound_value, + vmax=bound_value, + ) for i, c in enumerate([c1, c2, c3]): - fig.colorbar(c, ax=ax[i], shrink=0.6) + if i == 0: + extend = xplot_utils._determine_extend( + ds.isel(**selector), vmin + offset, vmax + offset + ) + label = xplot_utils.label_from_attrs(ds) + elif i == 1: + extend = xplot_utils._determine_extend( + ds_new.isel(**selector), vmin + offset, vmax + offset + ) + label = xplot_utils.label_from_attrs(ds_new) + elif i == 2 and bound_type == "rel_error": + extend = xplot_utils._determine_extend(error, -bound_value, bound_value) + label = "" + elif i == 2 and bound_type == "abs_error": + extend = xplot_utils._determine_extend(error, -bound_value, bound_value) + # Error has same label as original dataset. + label = xplot_utils.label_from_attrs(ds) + cbar = fig.colorbar(c, ax=ax[i], shrink=0.6, label=label, extend=extend) + if i == 2: + cbar.ax.set_yticks([-bound_value, 0, bound_value]) class CamsPlotter(Plotter): datasets = ["cams-nitrogen-dioxide-tiny", "cams-nitrogen-dioxide"] - def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var, err_bound): selector = dict(valid_time=0, pressure_level=3) in_min = ds.isel(**selector).min().values.item() in_max = ds.isel(**selector).max().values.item() @@ -209,22 +317,58 @@ def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): cmap="gist_earth", rasterized=True, ) + error = ds.isel(**selector) - ds_new.isel(**selector) - error.plot(ax=ax[2], transform=ccrs.PlateCarree()) + rel_error = error / np.abs(ds.isel(**selector)) + # Sets the colorbar label. + rel_error.attrs["long_name"] = "" + + _, bound_value = err_bound + vmin_error, vmax_error = -bound_value, bound_value + rel_error.plot( + ax=ax[2], + transform=ccrs.PlateCarree(), + vmin=vmin_error, + vmax=vmax_error, + cmap="seismic", + cbar_kwargs={"ticks": [-bound_value, 0, bound_value]}, + ) + self.error_title = "Relative Error" class EsaBiomassPlotter(Plotter): datasets = ["esa-biomass-cci-tiny", "esa-biomass-cci"] - def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var, err_bound): selector = dict(time=0) - ds.isel(**selector).plot(ax=ax[0]) - ds_new.isel(**selector).plot(ax=ax[1]) + data_orig = ds.isel(**selector).values.squeeze() + vmin, vmax = np.nanmin(data_orig), np.nanmax(data_orig) + + ds.isel(**selector).plot(ax=ax[0], cmap="Greens", vmin=vmin, vmax=vmax) + ds_new.isel(**selector).plot(ax=ax[1], cmap="Greens", vmin=vmin, vmax=vmax) + + _, bound_value = err_bound error = ds.isel(**selector) - ds_new.isel(**selector) - error.plot(ax=ax[2], rasterized=True) - ax[0].set_title("Original Dataset") - ax[1].set_title("Compressed Dataset") - ax[2].set_title("Error") + non_zero_mask = np.abs(ds.isel(**selector)) > 0.0 + # Check where both original and new data are zero + both_zero_mask = (np.abs(ds.isel(**selector)) == 0.0) & ( + np.abs(ds_new.isel(**selector)) == 0.0 + ) + rel_error = xr.where( + both_zero_mask, + 0.0, + xr.where(non_zero_mask, error / np.abs(ds.isel(**selector)), 1e12), + ) + rel_error.attrs["long_name"] = "" + rel_error.plot( + ax=ax[2], + rasterized=True, + cmap="seismic", + vmin=-bound_value, + vmax=bound_value, + cbar_kwargs={"ticks": [-bound_value, 0, bound_value]}, + ) + self.error_title = "Relative Error" plotter_clss: list[type[Plotter]] = [ From f319735e0d717c4d4164daa8ddc3302dc1fb8e51 Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Wed, 6 Aug 2025 16:32:50 +0100 Subject: [PATCH 4/8] Fine-tune plots for paper --- .../compressor/plotting/error_dist_plotter.py | 76 +++++++++++++------ .../compressor/plotting/plot_metrics.py | 40 +++++----- .../compressor/plotting/variable_plotters.py | 11 ++- 3 files changed, 83 insertions(+), 44 deletions(-) diff --git a/src/climatebenchpress/compressor/plotting/error_dist_plotter.py b/src/climatebenchpress/compressor/plotting/error_dist_plotter.py index b32c77b..ab64461 100644 --- a/src/climatebenchpress/compressor/plotting/error_dist_plotter.py +++ b/src/climatebenchpress/compressor/plotting/error_dist_plotter.py @@ -5,12 +5,20 @@ class ErrorDistPlotter: def __init__(self, variables, error_bounds): - self.fig, self.axes = plt.subplots( - len(variables), - len(error_bounds), - figsize=(17, 5 * len(variables)), - squeeze=False, - ) + self.variables = variables + self.error_bounds = error_bounds + self.figs = {} + self.axes = {} + # Create a separate figure for each variable + for var in variables: + self.figs[var], self.axes[var] = plt.subplots( + 1, + len(error_bounds), + figsize=(17, 5), + squeeze=False, + ) + # Remove the first dimension since we only have 1 row + self.axes[var] = self.axes[var][0] self.errors = {var: dict() for var in variables} @@ -53,7 +61,7 @@ def plot_error_bound_histograms( # does not change the error plot distribution. Hence, we ignore the PCO # compressors here. compressors = [comp for comp in compressors if "-pco" not in comp] - for j, var in enumerate(variables): + for var in variables: for comp in compressors: color, linestyle = get_line_info(comp) label = get_legend_name(comp) @@ -65,7 +73,7 @@ def plot_error_bound_histograms( # Filter out inf values error_data = self.errors[var][comp] error_data = error_data[~np.isinf(error_data) & ~np.isnan(error_data)] - self.axes[j, col_index].hist( + self.axes[var][col_index].hist( np.float64(error_data), bins=100, density=True, @@ -73,26 +81,27 @@ def plot_error_bound_histograms( label=label, color=color, linestyle=linestyle, - linewidth=2, - alpha=0.8, + linewidth=4, + alpha=0.6, ) error_bound_name, error_bound_value = error_bound_vals[var] - self.axes[j, col_index].set_xlabel("Error Value") - self.axes[j, col_index].set_ylabel("Log Probability Density") - self.axes[j, col_index].set_yscale("log") + self.axes[var][col_index].set_xlabel("Error Value", fontsize=14) + self.axes[var][col_index].set_ylabel("Log Probability Density", fontsize=14) + self.axes[var][col_index].set_yscale("log") xticks = np.linspace(-2 * error_bound_value, 2 * error_bound_value, num=5) - xlabels = ["0.0" if x == 0.0 else f"{x:.2e}" for x in xticks] - self.axes[j, col_index].set_xticks(xticks, labels=xlabels) - self.axes[j, col_index].set_xlim( + xlabels = [_format_number(x) for x in xticks] + self.axes[var][col_index].set_xticks(xticks, labels=xlabels, fontsize=12) + self.axes[var][col_index].set_xlim( -2 * error_bound_value, 2 * error_bound_value ) - self.axes[j, col_index].set_title( - f"{var}\n{error_bound_name} = {error_bound_value:.2e}" - if col_index == 1 - else f"{error_bound_name} = {error_bound_value:.2e}" + err_bound_type = ( + "Abs. Error" if error_bound_name == "abs_error" else "Rel. Error" + ) + self.axes[var][col_index].set_title( + f"{err_bound_type} = {_format_number(error_bound_value)}", fontsize=14 ) # Reset errors for the next iteration. Ensures we don't plot the wrong errors @@ -100,9 +109,12 @@ def plot_error_bound_histograms( self.errors = {var: dict() for var in variables} def get_final_figure(self): - self.axes[0, -1].legend(loc="center left", bbox_to_anchor=(1, 0.5)) - self.fig.tight_layout() - return self.fig, self.axes + for var in self.variables: + self.axes[var][-1].legend( + loc="center left", bbox_to_anchor=(1, 0.5), fontsize=14 + ) + self.figs[var].tight_layout() + return self.figs, self.axes def robust_error(x, y): @@ -127,3 +139,21 @@ def robust_error(x, y): result = xr.where(both_nan | (both_inf & ~inf_sign_mismatch), 0.0, x - y) return result + + +def _format_number(value: float) -> str: + if abs(value) > 10_000: + text = f"{value:.1e}" + elif abs(value) > 10: + text = f"{value:.0f}" + elif abs(value) > 1: + text = f"{value:.1f}" + elif value == 0: + text = "0.0" + elif abs(value) < 0.001: + text = f"{value:.1e}" + else: + # 0.001 <= abs(value) < 1 + text = f"{value:.3f}" + + return text diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index 2b7d7e8..13b9fe6 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -115,11 +115,10 @@ def plot_metrics( normalizer=normalizer, compression_metric="Relative CR", distortion_metric=metric, - outfile=plots_path - / f"rd_curve_{metric.lower().replace(' ', '_')}_exclude=ta_tos_pr_rlut.pdf", + outfile=plots_path / f"rd_curve_{metric.lower().replace(' ', '_')}.pdf", agg="median", bound_names=bound_names, - exclude_vars=["ta", "tos", "pr", "rlut"], + # exclude_vars=["ta", "tos", "pr", "rlut"], ) @@ -285,9 +284,12 @@ def plot_per_variable_metrics( get_lineinfo, ) - fig, _ = error_dist_plotter.get_final_figure() - savefig(dataset_plots_path / f"error_histograms_{dataset}.pdf") - plt.close(fig) + figs, _ = error_dist_plotter.get_final_figure() + for var, fig in figs.items(): + savefig( + dataset_plots_path / f"error_histograms_{dataset}_{var}.pdf", fig=fig + ) + plt.close(fig) def plot_variable_error( @@ -540,6 +542,7 @@ def plot_throughput(df, outfile: None | Path = None): grouped_df, title="", ylabel="Throughput [s / MB]", + logy=True, outfile=outfile, ) @@ -552,6 +555,7 @@ def plot_instruction_count(df, outfile: None | Path = None): grouped_df, title="", ylabel="Instructions [# / raw B]", + logy=True, outfile=outfile, ) @@ -581,7 +585,7 @@ def get_median_and_quantiles(df, encode_column, decode_column): ) -def plot_grouped_df(grouped_df, title, ylabel, outfile: None | Path = None): +def plot_grouped_df(grouped_df, title, ylabel, outfile: None | Path = None, logy=False): fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharex=True, sharey=True) # Bar width @@ -626,20 +630,21 @@ def plot_grouped_df(grouped_df, title, ylabel, outfile: None | Path = None): # Add labels and title ax.set_xticks([p + bar_width / 2 for p in x_positions]) - ax.set_xticklabels(x_labels, rotation=45, ha="right") - ax.set_title(f"{error_bound.capitalize()} Error Bound") + ax.set_xticklabels(x_labels, rotation=45, ha="right", fontsize=14) + ax.set_yscale("log" if logy else "linear") + ax.set_title(f"{error_bound.capitalize()} Error Bound", fontsize=14) ax.grid(axis="y", linestyle="--", alpha=0.7) if i == 0: - ax.legend() - ax.set_ylabel(ylabel) + ax.legend(fontsize=14) + ax.set_ylabel(ylabel, fontsize=14) ax.annotate( "Better", - xy=(0.05, 0.8), + xy=(0.1, 0.8), xycoords="axes fraction", - xytext=(0.05, 0.95), + xytext=(0.1, 0.95), textcoords="axes fraction", arrowprops=dict(arrowstyle="->", lw=3, color="black"), - fontsize=12, + fontsize=14, ha="center", va="bottom", ) @@ -694,16 +699,17 @@ def plot_bound_violations(df, bound_names, outfile: None | Path = None): plt.close() -def savefig(outfile: Path): +def savefig(outfile: Path, fig=None): ispdf = outfile.suffix == ".pdf" + fig = fig if fig is not None else plt.gcf() if ispdf: # Saving a PDF with the alternative code below leads to a corrupted file. # Hence, we use the default savefig method. # NOTE: This means passing a virtual UPath is only supported for non-PDF files. - plt.savefig(outfile, dpi=300) + fig.savefig(outfile, dpi=300) else: with outfile.open("wb") as f: - plt.savefig(f, dpi=300) + fig.savefig(f, dpi=300) if __name__ == "__main__": diff --git a/src/climatebenchpress/compressor/plotting/variable_plotters.py b/src/climatebenchpress/compressor/plotting/variable_plotters.py index 6fb4f3c..8b849de 100644 --- a/src/climatebenchpress/compressor/plotting/variable_plotters.py +++ b/src/climatebenchpress/compressor/plotting/variable_plotters.py @@ -33,16 +33,16 @@ def plot( fig, ax = plt.subplots( nrows=1, ncols=3, - figsize=(20, 7), + figsize=(18, 6), subplot_kw={"projection": self.projection}, ) self.plot_fields(fig, ax, ds, ds_new, dataset_name, var, err_bound) ax[0].coastlines() ax[1].coastlines() ax[2].coastlines() - ax[0].set_title("Original Dataset") - ax[1].set_title("Compressed Dataset") - ax[2].set_title(self.error_title) + ax[0].set_title("Original Dataset", fontsize=14) + ax[1].set_title("Compressed Dataset", fontsize=14) + ax[2].set_title(self.error_title, fontsize=14) # fig.suptitle(f"{var} Error for {dataset_name} ({compressor})") fig.tight_layout() if outfile is not None: @@ -106,6 +106,7 @@ def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var, err_bound): "fraction": 0.046, "pad": 0.04, "label": "degC", + "shrink": 0.6, }, ) @@ -123,6 +124,7 @@ def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var, err_bound): "fraction": 0.046, "pad": 0.04, "label": "degC", + "shrink": 0.6, }, ) @@ -143,6 +145,7 @@ def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var, err_bound): "fraction": 0.046, "pad": 0.04, "label": "degC", + "shrink": 0.6, "ticks": [-bound_value, 0, bound_value], }, ) From a7104ccfe8928e5b6d3eb7acd7a2814a129f0f8a Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Tue, 12 Aug 2025 08:18:55 +0100 Subject: [PATCH 5/8] Normalize by mean, std adjustment --- .../compressor/plotting/plot_metrics.py | 133 ++++++++++-------- 1 file changed, 78 insertions(+), 55 deletions(-) diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index 13b9fe6..315a802 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -67,7 +67,6 @@ def plot_metrics( basepath: Path = Path(), data_loader_basepath: None | Path = None, bound_names: list[str] = ["low", "mid", "high"], - normalizer: str = "sz3", exclude_dataset: list[str] = [], exclude_compressor: list[str] = [], tiny_datasets: bool = False, @@ -96,7 +95,7 @@ def plot_metrics( ) df = rename_compressors(df) - normalized_df = normalize(df, bound_normalize="mid", normalizer=normalizer) + normalized_df = normalize(df) plot_bound_violations( normalized_df, bound_names, plots_path / "bound_violations.pdf" ) @@ -107,18 +106,27 @@ def plot_metrics( "Relative MAE", "Relative DSSIM", "Relative MaxAbsError", - "Spectral Error", + "Relative SpectralError", ]: with plt.rc_context(rc={"text.usetex": use_latex}): plot_aggregated_rd_curve( normalized_df, - normalizer=normalizer, compression_metric="Relative CR", distortion_metric=metric, outfile=plots_path / f"rd_curve_{metric.lower().replace(' ', '_')}.pdf", - agg="median", + agg="mean", bound_names=bound_names, - # exclude_vars=["ta", "tos", "pr", "rlut"], + ) + + plot_aggregated_rd_curve( + normalized_df, + compression_metric="Relative CR", + distortion_metric=metric, + outfile=plots_path + / f"full_rd_curve_{metric.lower().replace(' ', '_')}.pdf", + agg="mean", + bound_names=bound_names, + remove_outliers=False, ) @@ -150,51 +158,32 @@ def sort_error_bounds(error_bounds: list[str]) -> list[str]: ) -def normalize(data, bound_normalize="mid", normalizer=None): - """Generate normalized metrics for each compressor and variable. The normalization - is done either with respect to either a user provided compressor or the - compressor with the highest average rank over all variables (ranked by - compression ratio). - - For each metric, the normalization is done by dividing the metric by the value of the - normalizer for the same variable and error bound, i.e.: - normalized_metric = metric[compressor, variable] / metric[normalizer, variable]. - """ - if normalizer is None: - # Group by Variable and rank compressors within each variable - ranked = data.copy() - ranked = ranked[ranked["Error Bound Name"] == bound_normalize] - ranked["CompRatio_Rank"] = ranked.groupby("Variable")[ - "Compression Ratio [raw B / enc B]" - ].rank(ascending=False) - - # Calculate average rank for each compressor across all variables - avg_ranks = ranked.groupby("Compressor")["CompRatio_Rank"].mean().reset_index() - avg_ranks.columns = ["Compressor", "Average_Rank"] - avg_ranks = avg_ranks.sort_values("Average_Rank") - - normalizer = avg_ranks.iloc[0]["Compressor"] - +def normalize(data): normalized = data.copy() normalize_vars = [ ("Compression Ratio [raw B / enc B]", "Relative CR"), ("MAE", "Relative MAE"), ("DSSIM", "Relative DSSIM"), ("Max Absolute Error", "Relative MaxAbsError"), + ("Spectral Error", "Relative SpectralError"), ] - # Avoid negative values. By default, DSSIM is in the range [-1, 1]. - normalized["DSSIM"] = normalized["DSSIM"] + 1.0 - def get_normalizer(row): - return normalized[ - (data["Compressor"] == normalizer) - & (data["Variable"] == row["Variable"]) - & (data["Error Bound Name"] == bound_normalize) - ][col].item() + variables = normalized["Variable"].unique() + + dssim_unreliable = normalized["Variable"].isin(["ta", "tos"]) + normalized.loc[dssim_unreliable, "DSSIM"] = np.nan for col, new_col in normalize_vars: + mean_std = dict() + for var in variables: + mean = normalized[normalized["Variable"] == var][col].mean() + std = normalized[normalized["Variable"] == var][col].std() + mean_std[var] = (mean, std) + + # Normalize each variable by its mean and std normalized[new_col] = normalized.apply( - lambda x: x[col] / get_normalizer(x), + lambda x: (x[col] - mean_std[x["Variable"]][0]) + / mean_std[x["Variable"]][1], axis=1, ) @@ -320,17 +309,23 @@ def plot_variable_error( print(f"No plotter found for dataset {dataset_name}") -def plot_variable_rd_curve(df, distortion_metric, outfile: None | Path = None): +def plot_variable_rd_curve( + df, distortion_metric, bounds=["low", "mid", "high"], outfile: None | Path = None +): plt.figure(figsize=(8, 6)) compressors = df["Compressor"].unique() for comp in compressors: compressor_data = df[df["Compressor"] == comp] - sorting_ixs = np.argsort(compressor_data["Compression Ratio [raw B / enc B]"]) + assert len(compressor_data) == len(bounds) + bound_ixs = [ + compressor_data[compressor_data["Error Bound Name"] == bound].index[0] + for bound in bounds + ] compr_ratio = [ - compressor_data["Compression Ratio [raw B / enc B]"].iloc[i] - for i in sorting_ixs + compressor_data["Compression Ratio [raw B / enc B]"].loc[i] + for i in bound_ixs ] - distortion = [compressor_data[distortion_metric].iloc[i] for i in sorting_ixs] + distortion = [compressor_data[distortion_metric].loc[i] for i in bound_ixs] color, linestyle = get_lineinfo(comp) plt.plot( compr_ratio, @@ -381,13 +376,13 @@ def plot_variable_rd_curve(df, distortion_metric, outfile: None | Path = None): def plot_aggregated_rd_curve( normalized_df, - normalizer, compression_metric, distortion_metric, outfile: None | Path = None, agg="median", bound_names=["low", "mid", "high"], exclude_vars=None, + remove_outliers=True, ): plt.figure(figsize=(8, 6)) if exclude_vars: @@ -398,6 +393,7 @@ def plot_aggregated_rd_curve( agg_distortion = normalized_df.groupby(["Error Bound Name", "Compressor"])[ [compression_metric, distortion_metric] ].agg(agg) + for comp in compressors: compr_ratio = [ agg_distortion.loc[(bound, comp), compression_metric] @@ -419,11 +415,34 @@ def plot_aggregated_rd_curve( markersize=8, ) + if remove_outliers: + # SZ3 and JPEG2000 often give outlier values and violate the bounds. + exclude_compressors = ["sz3", "jpeg2000"] + filtered_agg = agg_distortion[ + ~agg_distortion.index.get_level_values("Compressor").isin( + exclude_compressors + ) + ] + cr_mean, cr_std = ( + filtered_agg[compression_metric].mean(), + filtered_agg[compression_metric].std(), + ) + distortion_mean, distortion_std = ( + filtered_agg[distortion_metric].mean(), + filtered_agg[distortion_metric].std(), + ) + + # Adjust the plot limits + xlims = plt.xlim() + xlims_min = max(xlims[0], cr_mean - 4 * cr_std) + xlims_max = min(xlims[1], cr_mean + 4 * cr_std) + plt.xlim(xlims_min, xlims_max) + ylims = plt.ylim() + ylims_min = max(ylims[0], distortion_mean - 4 * distortion_std) + ylims_max = min(ylims[1], distortion_mean + 4 * distortion_std) + plt.ylim(ylims_min, ylims_max) + plt.xlabel(f"{agg.title()} {compression_metric}", fontsize=14) - plt.xscale("log") - if "PSNR" not in distortion_metric: - # PSNR is already on log scale. - plt.yscale("log") plt.ylabel(f"{agg.title()} {distortion_metric}", fontsize=14) plt.tick_params( @@ -443,20 +462,19 @@ def plot_aggregated_rd_curve( top=True, right=True, ) - normalizer_label = get_legend_name(normalizer) plt.xlabel( - rf"Median Compression Ratio Relative to {normalizer_label} ($\uparrow$)", + r"Mean Normalized Compression Ratio ($\uparrow$)", fontsize=16, ) metric_name = DISTORTION2LEGEND_NAME.get(distortion_metric, distortion_metric) plt.ylabel( - rf"Median {metric_name} Relative to {normalizer_label} ($\downarrow$)", + rf"Mean Normalized {metric_name} ($\downarrow$)", fontsize=16, ) plt.legend( title="Compressor", loc="upper right", - bbox_to_anchor=(0.95, 0.7), + bbox_to_anchor=(0.8, 0.99), fontsize=12, title_fontsize=14, ) @@ -488,6 +506,11 @@ def plot_aggregated_rd_curve( ha="center", va="center", ) + # Correct the y-label to point upwards + plt.ylabel( + rf"Mean Normalized {metric_name} ($\uparrow$)", + fontsize=16, + ) else: # Add an arrow pointing into the lower right corner plt.annotate( @@ -515,7 +538,7 @@ def plot_aggregated_rd_curve( if ( "DSSIM" in distortion_metric or "MaxAbsError" in distortion_metric - or "Spectral Error" in distortion_metric + or "SpectralError" in distortion_metric ): plt.legend().remove() From de812dcf686e8a8c25cd522bdf9f0faf66900185 Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Fri, 17 Oct 2025 16:21:45 +0100 Subject: [PATCH 6/8] Change font sizes --- .../compressor/plotting/variable_plotters.py | 23 ++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/src/climatebenchpress/compressor/plotting/variable_plotters.py b/src/climatebenchpress/compressor/plotting/variable_plotters.py index 8b849de..f139f49 100644 --- a/src/climatebenchpress/compressor/plotting/variable_plotters.py +++ b/src/climatebenchpress/compressor/plotting/variable_plotters.py @@ -11,6 +11,7 @@ class Plotter(ABC): datasets: list[str] + title_fontsize = 22 def __init__(self): self.projection = ccrs.Robinson() @@ -40,9 +41,9 @@ def plot( ax[0].coastlines() ax[1].coastlines() ax[2].coastlines() - ax[0].set_title("Original Dataset", fontsize=14) - ax[1].set_title("Compressed Dataset", fontsize=14) - ax[2].set_title(self.error_title, fontsize=14) + ax[0].set_title("Original Dataset", fontsize=self.title_fontsize) + ax[1].set_title("Compressed Dataset", fontsize=self.title_fontsize) + ax[2].set_title(self.error_title, fontsize=self.title_fontsize) # fig.suptitle(f"{var} Error for {dataset_name} ({compressor})") fig.tight_layout() if outfile is not None: @@ -87,6 +88,9 @@ def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var, err_bound): class CmipOceanPlotter(Plotter): datasets = ["cmip6-access-tos-tiny", "cmip6-access-tos"] + cbar_label_fontsize = 20 + cbar_tick_fontsize = 16 + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var, err_bound): # Calculate shared vmin and vmax for consistent color ranges data_orig = ds.isel(time=0).values.squeeze() @@ -149,12 +153,23 @@ def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var, err_bound): "ticks": [-bound_value, 0, bound_value], }, ) + for a in ax: + a.collections[0].colorbar.ax.set_ylabel( + "degC", fontsize=self.cbar_label_fontsize + ) + a.collections[0].colorbar.ax.tick_params(labelsize=self.cbar_tick_fontsize) + ax[2].collections[0].colorbar.ax.yaxis.get_offset_text().set( + size=self.cbar_tick_fontsize + ) self.error_title = "Absolute Error" class Era5Plotter(Plotter): datasets = ["era5-tiny", "era5"] + cbar_label_fontsize = 18 + cbar_tick_fontsize = 14 + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var, err_bound): selector = dict(time=0) # Calculate shared vmin and vmax for consistent color ranges @@ -215,6 +230,8 @@ def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var, err_bound): cbar = fig.colorbar(c, ax=ax[i], shrink=0.6, extend=extend, label=label) if i == 2: cbar.ax.set_yticks([-bound_value, 0, bound_value]) + cbar.ax.tick_params(labelsize=self.cbar_tick_fontsize) + cbar.ax.set_ylabel(label, fontsize=self.cbar_label_fontsize) class NextGEMSPlotter(Plotter): From 285c012d3722058ec70d44b6373ff030ccc418b3 Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Fri, 28 Nov 2025 09:34:38 +0000 Subject: [PATCH 7/8] Chunking adjustment --- src/climatebenchpress/compressor/plotting/plot_metrics.py | 6 +++++- .../compressor/plotting/variable_plotters.py | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index 315a802..f8372ad 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -239,7 +239,10 @@ def plot_per_variable_metrics( / comp / "decompressed.zarr" ) - input = datasets / dataset / "standardized.zarr" + input_dataset_name = dataset + if dataset.endswith("-chunked"): + input_dataset_name = dataset.removesuffix("-chunked") + input = datasets / input_dataset_name / "standardized.zarr" ds = xr.open_dataset(input, chunks=dict(), engine="zarr") ds_new = xr.open_dataset(compressed, chunks=dict(), engine="zarr") @@ -294,6 +297,7 @@ def plot_variable_error( # These plots can be quite expensive to generate, so we skip if they already exist. return + dataset_name = dataset_name.removesuffix("-chunked") plotter = PLOTTERS.get(dataset_name, None) if plotter: plotter().plot( diff --git a/src/climatebenchpress/compressor/plotting/variable_plotters.py b/src/climatebenchpress/compressor/plotting/variable_plotters.py index f139f49..d59c6a2 100644 --- a/src/climatebenchpress/compressor/plotting/variable_plotters.py +++ b/src/climatebenchpress/compressor/plotting/variable_plotters.py @@ -165,7 +165,7 @@ def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var, err_bound): class Era5Plotter(Plotter): - datasets = ["era5-tiny", "era5"] + datasets = ["era5-tiny", "era5", "ifs-uncompressed"] cbar_label_fontsize = 18 cbar_tick_fontsize = 14 @@ -315,7 +315,7 @@ class CamsPlotter(Plotter): datasets = ["cams-nitrogen-dioxide-tiny", "cams-nitrogen-dioxide"] def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var, err_bound): - selector = dict(valid_time=0, pressure_level=3) + selector = dict(valid_time=0, hybrid=3) in_min = ds.isel(**selector).min().values.item() in_max = ds.isel(**selector).max().values.item() out_min = ds_new.isel(**selector).min().values.item() From c7c6005b83457059bb3914d3954c1da28c2cc155 Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Tue, 2 Dec 2025 18:19:16 +0000 Subject: [PATCH 8/8] Minor changes --- .../compressor/plotting/error_dist_plotter.py | 11 +++++ .../compressor/plotting/plot_metrics.py | 4 ++ .../compressor/plotting/variable_plotters.py | 45 +++++++++++++++++++ 3 files changed, 60 insertions(+) diff --git a/src/climatebenchpress/compressor/plotting/error_dist_plotter.py b/src/climatebenchpress/compressor/plotting/error_dist_plotter.py index ab64461..dfbb736 100644 --- a/src/climatebenchpress/compressor/plotting/error_dist_plotter.py +++ b/src/climatebenchpress/compressor/plotting/error_dist_plotter.py @@ -97,6 +97,17 @@ def plot_error_bound_histograms( -2 * error_bound_value, 2 * error_bound_value ) + self.axes[var][col_index].axvline( + -error_bound_value, + color="black", + linestyle="--", + linewidth=2, + alpha=0.7, + ) + self.axes[var][col_index].axvline( + error_bound_value, color="black", linestyle="--", linewidth=2, alpha=0.7 + ) + err_bound_type = ( "Abs. Error" if error_bound_name == "abs_error" else "Rel. Error" ) diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index f8372ad..c7566c0 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -70,6 +70,7 @@ def plot_metrics( exclude_dataset: list[str] = [], exclude_compressor: list[str] = [], tiny_datasets: bool = False, + chunked_datasets: bool = False, use_latex: bool = True, ): metrics_path = basepath / "metrics" @@ -85,6 +86,9 @@ def plot_metrics( is_tiny = df["Dataset"].str.endswith("-tiny") filter_tiny = is_tiny if tiny_datasets else ~is_tiny df = df[filter_tiny] + is_chunked = df["Dataset"].str.endswith("-chunked") + filter_chunked = is_chunked if chunked_datasets else ~is_chunked + df = df[filter_chunked] plot_per_variable_metrics( datasets=datasets, diff --git a/src/climatebenchpress/compressor/plotting/variable_plotters.py b/src/climatebenchpress/compressor/plotting/variable_plotters.py index d59c6a2..aac7a94 100644 --- a/src/climatebenchpress/compressor/plotting/variable_plotters.py +++ b/src/climatebenchpress/compressor/plotting/variable_plotters.py @@ -153,6 +153,18 @@ def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var, err_bound): "ticks": [-bound_value, 0, bound_value], }, ) + + bad_vals = np.isnan(ds.isel(time=0)) & ~np.isnan(ds_new.isel(time=0)) + bad_vals = bad_vals.where(bad_vals) + bad_vals.plot( + ax=ax[2], + x="longitude", + y="latitude", + transform=ccrs.PlateCarree(), + add_colorbar=False, + cmap=plt.cm.colors.ListedColormap(["yellow"]), + ) + for a in ax: a.collections[0].colorbar.ax.set_ylabel( "degC", fontsize=self.cbar_label_fontsize @@ -164,6 +176,39 @@ def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var, err_bound): self.error_title = "Absolute Error" +class IFSHumidityPlotter(Plotter): + datasets = ["ifs-humidity"] + + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var, err_bound): + selector = dict(time=0, level=3) + # Calculate shared vmin and vmax for consistent color ranges + data_orig = ds.isel(**selector) + data_new = ds_new.isel(**selector) + vmin = np.nanmin(data_orig.values.squeeze()) + vmax = np.nanmax(data_orig.values.squeeze()) + + data_orig.plot(ax=ax[0], transform=ccrs.PlateCarree(), vmin=vmin, vmax=vmax) + data_new.plot( + ax=ax[1], transform=ccrs.PlateCarree(), robust=True, vmin=vmin, vmax=vmax + ) + error = data_orig - data_new + error.attrs["long_name"] = data_orig.attrs.get("long_name", "") + error.attrs["units"] = data_orig.attrs.get("units", "") + + _, bound_value = err_bound + vmin_error, vmax_error = -bound_value, bound_value + error.plot( + ax=ax[2], + transform=ccrs.PlateCarree(), + rasterized=True, + vmin=vmin_error, + vmax=vmax_error, + cbar_kwargs={"ticks": [-bound_value, 0, bound_value]}, + cmap="seismic", + ) + self.error_title = "Absolute Error" + + class Era5Plotter(Plotter): datasets = ["era5-tiny", "era5", "ifs-uncompressed"]