Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 68 additions & 25 deletions src/climatebenchpress/compressor/plotting/error_dist_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,20 @@

class ErrorDistPlotter:
def __init__(self, variables, error_bounds):
self.fig, self.axes = plt.subplots(
len(variables),
len(error_bounds),
figsize=(17, 5 * len(variables)),
squeeze=False,
)
self.variables = variables
self.error_bounds = error_bounds
self.figs = {}
self.axes = {}
# Create a separate figure for each variable
for var in variables:
self.figs[var], self.axes[var] = plt.subplots(
1,
len(error_bounds),
figsize=(17, 5),
squeeze=False,
)
# Remove the first dimension since we only have 1 row
self.axes[var] = self.axes[var][0]

self.errors = {var: dict() for var in variables}

Expand Down Expand Up @@ -53,8 +61,7 @@ def plot_error_bound_histograms(
# does not change the error plot distribution. Hence, we ignore the PCO
# compressors here.
compressors = [comp for comp in compressors if "-pco" not in comp]

for j, var in enumerate(variables):
for var in variables:
for comp in compressors:
color, linestyle = get_line_info(comp)
label = get_legend_name(comp)
Expand All @@ -63,44 +70,62 @@ def plot_error_bound_histograms(
label = "BitRound"
elif label.startswith("StochRound"):
label = "StochRound"
self.axes[j, col_index].hist(
self.errors[var][comp],
# Filter out inf values
error_data = self.errors[var][comp]
error_data = error_data[~np.isinf(error_data) & ~np.isnan(error_data)]
self.axes[var][col_index].hist(
np.float64(error_data),
bins=100,
density=True,
histtype="step",
label=label,
color=color,
linestyle=linestyle,
linewidth=2,
alpha=0.8,
linewidth=4,
alpha=0.6,
)

error_bound_name, error_bound_value = error_bound_vals[var]
self.axes[j, col_index].set_xlabel("Error Value")
self.axes[j, col_index].set_ylabel("Log Probability Density")
self.axes[j, col_index].set_yscale("log")
self.axes[var][col_index].set_xlabel("Error Value", fontsize=14)
self.axes[var][col_index].set_ylabel("Log Probability Density", fontsize=14)
self.axes[var][col_index].set_yscale("log")

xticks = np.linspace(-2 * error_bound_value, 2 * error_bound_value, num=5)
xlabels = ["0.0" if x == 0.0 else f"{x:.2e}" for x in xticks]
self.axes[j, col_index].set_xticks(xticks, labels=xlabels)
self.axes[j, col_index].set_xlim(
xlabels = [_format_number(x) for x in xticks]
self.axes[var][col_index].set_xticks(xticks, labels=xlabels, fontsize=12)
self.axes[var][col_index].set_xlim(
-2 * error_bound_value, 2 * error_bound_value
)

self.axes[j, col_index].set_title(
f"{var}\n{error_bound_name} = {error_bound_value:.2e}"
if col_index == 1
else f"{error_bound_name} = {error_bound_value:.2e}"
self.axes[var][col_index].axvline(
-error_bound_value,
color="black",
linestyle="--",
linewidth=2,
alpha=0.7,
)
self.axes[var][col_index].axvline(
error_bound_value, color="black", linestyle="--", linewidth=2, alpha=0.7
)

err_bound_type = (
"Abs. Error" if error_bound_name == "abs_error" else "Rel. Error"
)
self.axes[var][col_index].set_title(
f"{err_bound_type} = {_format_number(error_bound_value)}", fontsize=14
)

# Reset errors for the next iteration. Ensures we don't plot the wrong errors
# for the next error bound.
self.errors = {var: dict() for var in variables}

def get_final_figure(self):
self.axes[0, -1].legend(loc="center left", bbox_to_anchor=(1, 0.5))
self.fig.tight_layout()
return self.fig, self.axes
for var in self.variables:
self.axes[var][-1].legend(
loc="center left", bbox_to_anchor=(1, 0.5), fontsize=14
)
self.figs[var].tight_layout()
return self.figs, self.axes


def robust_error(x, y):
Expand All @@ -125,3 +150,21 @@ def robust_error(x, y):
result = xr.where(both_nan | (both_inf & ~inf_sign_mismatch), 0.0, x - y)

return result


def _format_number(value: float) -> str:
if abs(value) > 10_000:
text = f"{value:.1e}"
elif abs(value) > 10:
text = f"{value:.0f}"
elif abs(value) > 1:
text = f"{value:.1f}"
elif value == 0:
text = "0.0"
elif abs(value) < 0.001:
text = f"{value:.1e}"
else:
# 0.001 <= abs(value) < 1
text = f"{value:.3f}"

return text
Loading