From a7428d508d124c1fa1f68e1a52a5eb2794b86cc0 Mon Sep 17 00:00:00 2001 From: Niru Maheswaranathan Date: Tue, 20 May 2025 10:57:15 -0700 Subject: [PATCH 1/3] Add type hints across project --- src/jetplot/chart_utils.py | 40 ++++++++++--- src/jetplot/colors.py | 10 ++-- src/jetplot/images.py | 64 ++++++++++++--------- src/jetplot/plots.py | 112 +++++++++++++++++++++++++------------ src/jetplot/signals.py | 4 +- src/jetplot/style.py | 10 ++-- src/jetplot/timepiece.py | 18 +++--- 7 files changed, 165 insertions(+), 93 deletions(-) diff --git a/src/jetplot/chart_utils.py b/src/jetplot/chart_utils.py index 6fe157d..712e808 100644 --- a/src/jetplot/chart_utils.py +++ b/src/jetplot/chart_utils.py @@ -2,6 +2,7 @@ from collections.abc import Callable from functools import partial, wraps +from typing import Any import numpy as np from matplotlib import pyplot as plt @@ -19,7 +20,7 @@ ] -def figwrapper(fun): +def figwrapper(fun: Callable[..., Any]) -> Callable[..., Any]: """Decorator that adds figure handles to the kwargs of a function.""" @wraps(fun) @@ -32,7 +33,7 @@ def wrapper(*args, **kwargs): return wrapper -def plotwrapper(fun): +def plotwrapper(fun: Callable[..., Any]) -> Callable[..., Any]: """Decorator that adds figure and axes handles to the kwargs of a function.""" @wraps(fun) @@ -51,7 +52,7 @@ def wrapper(*args, **kwargs): return wrapper -def axwrapper(fun): +def axwrapper(fun: Callable[..., Any]) -> Callable[..., Any]: """Decorator that adds an axes handle to kwargs.""" @wraps(fun) @@ -69,7 +70,7 @@ def wrapper(*args, **kwargs): @axwrapper -def noticks(**kwargs): +def noticks(**kwargs: Any) -> None: """ Clears tick marks (useful for images) """ @@ -80,7 +81,13 @@ def noticks(**kwargs): @axwrapper -def nospines(left=False, bottom=False, top=True, right=True, **kwargs): +def nospines( + left: bool = False, + bottom: bool = False, + top: bool = True, + right: bool = True, + **kwargs: Any, +) -> plt.Axes: """ Hides the specified axis spines (by default, right and top spines) """ @@ -114,7 +121,7 @@ def nospines(left=False, bottom=False, top=True, right=True, **kwargs): return ax -def get_bounds(axis, ax=None): +def get_bounds(axis: str, ax: plt.Axes | None = None) -> tuple[float, float]: if ax is None: ax = plt.gca() @@ -148,7 +155,12 @@ def get_bounds(axis, ax=None): @axwrapper -def breathe(xlims=None, ylims=None, padding_percent=0.05, **kwargs): +def breathe( + xlims: tuple[float, float] | None = None, + ylims: tuple[float, float] | None = None, + padding_percent: float = 0.05, + **kwargs: Any, +) -> plt.Axes: """Adds space between axes and plot.""" ax = kwargs["ax"] @@ -187,7 +199,12 @@ def identity(x): @axwrapper -def yclamp(y0=None, y1=None, dt=None, **kwargs): +def yclamp( + y0: float | None = None, + y1: float | None = None, + dt: float | None = None, + **kwargs: Any, +) -> plt.Axes: ax = kwargs["ax"] lims = ax.get_ylim() @@ -206,7 +223,12 @@ def yclamp(y0=None, y1=None, dt=None, **kwargs): @axwrapper -def xclamp(x0=None, x1=None, dt=None, **kwargs): +def xclamp( + x0: float | None = None, + x1: float | None = None, + dt: float | None = None, + **kwargs: Any, +) -> plt.Axes: ax = kwargs["ax"] lims = ax.get_xlim() diff --git a/src/jetplot/colors.py b/src/jetplot/colors.py index 040e2ad..026595c 100644 --- a/src/jetplot/colors.py +++ b/src/jetplot/colors.py @@ -15,14 +15,14 @@ class Palette(list[ColorType]): """Color palette based on a list of values.""" @property - def hex(self): + def hex(self) -> "Palette": return Palette([to_hex(rgb) for rgb in self]) @property - def cmap(self): + def cmap(self) -> LinearSegmentedColormap: return LinearSegmentedColormap.from_list("", self) - def plot(self, figsize=(5, 1)): + def plot(self, figsize: tuple[float, float] = (5, 1)) -> tuple[plt.Figure, np.ndarray]: fig, axs = plt.subplots(1, len(self), figsize=figsize) for c, ax in zip(self, axs, strict=True): # pyrefly: ignore ax.set_facecolor(c) @@ -40,7 +40,7 @@ def cubehelix( start: float = 0.0, rot: float = 0.4, hue: float = 0.8, -): +) -> Palette: """Cubehelix parameterized colormap.""" lambda_ = np.linspace(vmin, vmax, n) x = lambda_**gamma @@ -54,7 +54,7 @@ def cubehelix( return Palette(colors) -def cmap_colors(cmap: str, n: int, vmin: float = 0.0, vmax: float = 1.0): +def cmap_colors(cmap: str, n: int, vmin: float = 0.0, vmax: float = 1.0) -> Palette: return Palette(getattr(cm, cmap)(np.linspace(vmin, vmax, n))) diff --git a/src/jetplot/images.py b/src/jetplot/images.py index c464786..724cbf7 100644 --- a/src/jetplot/images.py +++ b/src/jetplot/images.py @@ -1,9 +1,12 @@ """Image visualization tools.""" from functools import partial +from typing import Any, Callable, Iterable import numpy as np from matplotlib import pyplot as plt +from matplotlib.axes import Axes +from matplotlib.image import AxesImage from matplotlib.ticker import FixedLocator from . import colors as c @@ -14,16 +17,16 @@ @plotwrapper def img( - data, - mode="div", - cmap=None, - aspect="equal", - vmin=None, - vmax=None, - cbar=True, - interpolation="none", - **kwargs, -): + data: np.ndarray, + mode: str = "div", + cmap: str | None = None, + aspect: str = "equal", + vmin: float | None = None, + vmax: float | None = None, + cbar: bool = True, + interpolation: str = "none", + **kwargs: Any, +) -> AxesImage: """Visualize a matrix as an image. Args: @@ -79,7 +82,14 @@ def img( @plotwrapper -def fsurface(func, xrng=None, yrng=None, n=100, nargs=2, **kwargs): +def fsurface( + func: Callable[..., np.ndarray], + xrng: tuple[float, float] | None = None, + yrng: tuple[float, float] | None = None, + n: int = 100, + nargs: int = 2, + **kwargs: Any, +) -> None: xrng = (-1, 1) if xrng is None else xrng yrng = xrng if yrng is None else yrng @@ -103,22 +113,22 @@ def fsurface(func, xrng=None, yrng=None, n=100, nargs=2, **kwargs): @plotwrapper def cmat( - arr, - labels=None, - annot=True, - cmap="gist_heat_r", - cbar=False, - fmt="0.0%", - dark_color="#222222", - light_color="#dddddd", - grid_color=c.gray[9], - theta=0.5, - label_fontsize=10.0, - fontsize=10.0, - vmin=0.0, - vmax=1.0, - **kwargs, -): + arr: np.ndarray, + labels: Iterable[str] | None = None, + annot: bool = True, + cmap: str = "gist_heat_r", + cbar: bool = False, + fmt: str = "0.0%", + dark_color: str = "#222222", + light_color: str = "#dddddd", + grid_color: str = c.gray[9], + theta: float = 0.5, + label_fontsize: float = 10.0, + fontsize: float = 10.0, + vmin: float = 0.0, + vmax: float = 1.0, + **kwargs: Any, +) -> tuple[AxesImage, Axes]: """Plot confusion matrix.""" num_rows, num_cols = arr.shape diff --git a/src/jetplot/plots.py b/src/jetplot/plots.py index e57f18b..edac078 100644 --- a/src/jetplot/plots.py +++ b/src/jetplot/plots.py @@ -1,12 +1,15 @@ """Common plots.""" import numpy as np +from matplotlib.axes import Axes +from matplotlib.figure import Figure from matplotlib.patches import Ellipse from matplotlib.transforms import Affine2D from matplotlib.typing import ColorType from numpy.typing import NDArray from scipy.stats import gaussian_kde from sklearn.covariance import EmpiricalCovariance, MinCovDet +from typing import Any, Iterable, Sequence from .chart_utils import figwrapper, nospines, plotwrapper from .colors import cmap_colors, neutral @@ -27,15 +30,15 @@ @plotwrapper def violinplot( data: NDArray[np.floating], - xs, - fc=neutral[3], - ec=neutral[9], - mc=neutral[1], - showmedians=True, - showmeans=False, - showquartiles=True, - **kwargs, -): + xs: Sequence[float] | float, + fc: ColorType = neutral[3], + ec: ColorType = neutral[9], + mc: ColorType = neutral[1], + showmedians: bool = True, + showmeans: bool = False, + showquartiles: bool = True, + **kwargs: Any, +) -> Axes: _ = kwargs.pop("fig") ax = kwargs.pop("ax") @@ -85,10 +88,10 @@ def violinplot( s=15, zorder=20, ) - + return ax @plotwrapper -def hist(*args, **kwargs): +def hist(*args: Any, **kwargs: Any) -> Any: """Wrapper for matplotlib.hist function.""" # remove kwargs that are filled in manually @@ -104,7 +107,14 @@ def hist(*args, **kwargs): @plotwrapper -def hist2d(x, y, bins=None, range=None, cmap="hot", **kwargs): +def hist2d( + x: NDArray[np.floating], + y: NDArray[np.floating], + bins: int | Sequence[float] | None = None, + range: NDArray[np.floating] | Sequence[Sequence[float]] | None = None, + cmap: str = "hot", + **kwargs: Any, +) -> None: """ Visualizes a 2D histogram by binning data. @@ -138,18 +148,18 @@ def hist2d(x, y, bins=None, range=None, cmap="hot", **kwargs): @plotwrapper def errorplot( - x, - y, - yerr, - method="patch", + x: NDArray[np.floating], + y: NDArray[np.floating], + yerr: NDArray[np.floating] | float | tuple[NDArray[np.floating], NDArray[np.floating]], + method: str = "patch", color: ColorType = "#222222", - xscale="linear", - fmt="-", + xscale: str = "linear", + fmt: str = "-", err_color: ColorType = "#cccccc", - alpha_fill=1.0, - clip_on=True, - **kwargs, -): + alpha_fill: float = 1.0, + clip_on: bool = True, + **kwargs: Any, +) -> None: """Plot a line with error bars.""" ax = kwargs["ax"] @@ -196,16 +206,16 @@ def errorplot( @plotwrapper def bar( - labels, - data, - color="#888888", - width=0.7, - offset=0.0, - err=None, - capsize=5, - capthick=2, - **kwargs, -): + labels: Sequence[str], + data: Sequence[float], + color: ColorType = "#888888", + width: float = 0.7, + offset: float = 0.0, + err: Sequence[float] | None = None, + capsize: float = 5, + capthick: float = 2, + **kwargs: Any, +) -> Axes: """Bar chart. Args: @@ -249,7 +259,12 @@ def bar( @plotwrapper -def lines(x, lines=None, cmap="viridis", **kwargs): +def lines( + x: NDArray[np.floating] | Sequence[float], + lines: Iterable[Sequence[float]] | None = None, + cmap: str = "viridis", + **kwargs: Any, +) -> None: ax = kwargs["ax"] if lines is None: @@ -265,7 +280,16 @@ def lines(x, lines=None, cmap="viridis", **kwargs): @plotwrapper -def waterfall(x, ys, dy=1.0, pad=0.1, color="#444444", ec="#cccccc", ew=2.0, **kwargs): +def waterfall( + x: NDArray[np.floating], + ys: Iterable[NDArray[np.floating]], + dy: float = 1.0, + pad: float = 0.1, + color: ColorType = "#444444", + ec: ColorType = "#cccccc", + ew: float = 2.0, + **kwargs: Any, +) -> None: """Waterfall plot.""" ax = kwargs["ax"] total = len(ys) @@ -281,7 +305,14 @@ def waterfall(x, ys, dy=1.0, pad=0.1, color="#444444", ec="#cccccc", ew=2.0, **k @figwrapper -def ridgeline(t, xs, colors, edgecolor="#ffffff", ymax=0.6, **kwargs): +def ridgeline( + t: NDArray[np.floating], + xs: Iterable[NDArray[np.floating]], + colors: Iterable[ColorType], + edgecolor: ColorType = "#ffffff", + ymax: float = 0.6, + **kwargs: Any, +) -> tuple[Figure, list[Axes]]: fig = kwargs["fig"] axs = [] @@ -307,7 +338,7 @@ def ridgeline(t, xs, colors, edgecolor="#ffffff", ymax=0.6, **kwargs): @plotwrapper -def circle(radius=1.0, **kwargs): +def circle(radius: float = 1.0, **kwargs: Any) -> None: """Plots a unit circle.""" ax = kwargs["ax"] theta = np.linspace(0, 2 * np.pi, 1001) @@ -315,7 +346,14 @@ def circle(radius=1.0, **kwargs): @plotwrapper -def ellipse(x, y, n_std=3.0, facecolor="none", estimator="empirical", **kwargs): +def ellipse( + x: NDArray[np.floating], + y: NDArray[np.floating], + n_std: float = 3.0, + facecolor: str = "none", + estimator: str = "empirical", + **kwargs: Any, +) -> Ellipse: """ Create a plot of the covariance confidence ellipse of *x* and *y*. diff --git a/src/jetplot/signals.py b/src/jetplot/signals.py index e0b9d49..e204c32 100644 --- a/src/jetplot/signals.py +++ b/src/jetplot/signals.py @@ -13,7 +13,7 @@ FloatArray = NDArray[np.floating] -def smooth(x, sigma=1.0, axis=0): +def smooth(x: ArrayLike, sigma: float = 1.0, axis: int = 0) -> NDArray[np.floating]: """Smooths a 1D signal with a gaussian filter. Args: @@ -26,7 +26,7 @@ def smooth(x, sigma=1.0, axis=0): return gaussian_filter1d(x, sigma, axis=axis) -def stable_rank(X): +def stable_rank(X: NDArray[np.floating]) -> float: """Computes the stable rank of a matrix""" assert X.ndim == 2, "X must be a matrix" svals_sq = np.linalg.svd(X, compute_uv=False, full_matrices=False) ** 2 diff --git a/src/jetplot/style.py b/src/jetplot/style.py index 4cd5eba..a9888a7 100644 --- a/src/jetplot/style.py +++ b/src/jetplot/style.py @@ -76,7 +76,7 @@ } -def set_colors(bg, fg, text): +def set_colors(bg: ColorType, fg: ColorType, text: ColorType) -> None: """Set background/foreground colorscheme.""" rcParams.update( { @@ -96,7 +96,7 @@ def set_colors(bg, fg, text): ) -def set_font(fontname: str): +def set_font(fontname: str) -> None: """Specifies the matplotlib default font.""" if fontname not in available_fonts(): @@ -105,7 +105,7 @@ def set_font(fontname: str): rcParams["font.family"] = fontname -def set_dpi(dpi: int): +def set_dpi(dpi: int) -> None: """Sets the figure DPI.""" rcParams["figure.dpi"] = dpi @@ -118,7 +118,7 @@ def set_defaults( cycler_colors: c.Palette, defaults: Mapping[str, Any] = STYLE_DEFAULTS, font: str = "Helvetica", -): +) -> None: """Sets matplotlib defaults.""" rcParams.update(defaults) set_colors(bg, fg, text) @@ -143,7 +143,7 @@ def available_fonts() -> list[str]: return sorted(set([f.name for f in fm.fontManager.ttflist])) # pyrefly: ignore -def install_fonts(filepath: str): +def install_fonts(filepath: str) -> None: """Installs .ttf fonts in the given folder.""" original_fonts = set(available_fonts()) diff --git a/src/jetplot/timepiece.py b/src/jetplot/timepiece.py index dd86946..21dc71d 100644 --- a/src/jetplot/timepiece.py +++ b/src/jetplot/timepiece.py @@ -4,38 +4,39 @@ from functools import wraps import numpy as np +from typing import Any, Callable __all__ = ["hrtime", "Stopwatch", "profile"] class Stopwatch: - def __init__(self, name=""): + def __init__(self, name: str = "") -> None: self.name = name self.start = time.perf_counter() self.absolute_start = time.perf_counter() - def __str__(self): + def __str__(self) -> str: return "\u231a Stopwatch for: " + self.name @property - def elapsed(self): + def elapsed(self) -> float: current = time.perf_counter() elapsed = current - self.start self.start = time.perf_counter() return elapsed - def checkpoint(self, name=""): + def checkpoint(self, name: str = "") -> None: print(f"{self.name} {name} took {hrtime(self.elapsed)}".strip()) - def __enter__(self): + def __enter__(self) -> "Stopwatch": return self - def __exit__(self, *_): + def __exit__(self, *_: object) -> None: total = hrtime(time.perf_counter() - self.absolute_start) print(f"{self.name} Finished! \u2714\nTotal elapsed time: {total}") -def hrtime(t: float): +def hrtime(t: float) -> str: """Converts a time in seconds to a reasonable human readable time. Args: @@ -84,7 +85,8 @@ def hrtime(t: float): return timestr -def profile(func): + +def profile(func: Callable[..., Any]) -> Callable[..., Any]: """Timing (profile) decorator for a function.""" calls = list() From 7d6de2f6fb35ecd8ef7c2cc5e1b2f13064975cf7 Mon Sep 17 00:00:00 2001 From: Niru Nahesh Date: Tue, 20 May 2025 21:48:01 -0700 Subject: [PATCH 2/3] [WIP] fixing type checks. --- src/jetplot/chart_utils.py | 1 + src/jetplot/images.py | 4 ++-- src/jetplot/plots.py | 17 ++++++++++------- src/jetplot/timepiece.py | 4 ++-- tests/test_images.py | 1 + tests/test_utils.py | 1 + 6 files changed, 17 insertions(+), 11 deletions(-) diff --git a/src/jetplot/chart_utils.py b/src/jetplot/chart_utils.py index d233e27..903bbb8 100644 --- a/src/jetplot/chart_utils.py +++ b/src/jetplot/chart_utils.py @@ -1,5 +1,6 @@ """Plotting utils.""" +from collections.abc import Callable from functools import partial, wraps from typing import Any, Literal diff --git a/src/jetplot/images.py b/src/jetplot/images.py index 596cea2..0f3f347 100644 --- a/src/jetplot/images.py +++ b/src/jetplot/images.py @@ -1,8 +1,8 @@ """Image visualization tools.""" -from collections.abc import Callable +from collections.abc import Callable, Iterable from functools import partial -from typing import Any, Callable, Iterable +from typing import Any import numpy as np from matplotlib import pyplot as plt diff --git a/src/jetplot/plots.py b/src/jetplot/plots.py index e39a7ef..7a28aed 100644 --- a/src/jetplot/plots.py +++ b/src/jetplot/plots.py @@ -1,18 +1,17 @@ """Common plots.""" +from collections.abc import Iterable, Sequence +from typing import Any + import numpy as np from matplotlib.axes import Axes from matplotlib.figure import Figure from matplotlib.patches import Ellipse from matplotlib.transforms import Affine2D from matplotlib.typing import ColorType -from matplotlib.figure import Figure -from matplotlib.axes import Axes -from collections.abc import Sequence from numpy.typing import NDArray from scipy.stats import gaussian_kde from sklearn.covariance import EmpiricalCovariance, MinCovDet -from typing import Any, Iterable, Sequence from .chart_utils import figwrapper, nospines, plotwrapper from .colors import cmap_colors, neutral @@ -95,9 +94,11 @@ def violinplot( return ax - + @plotwrapper -def hist(*args: Any, histtype="stepfilled", alpha=0.85, density=True, **kwargs: Any) -> Any: +def hist( + *args: Any, histtype="stepfilled", alpha=0.85, density=True, **kwargs: Any +) -> Any: """Wrapper for matplotlib.hist function.""" ax = kwargs.pop("ax") kwargs.pop("fig") @@ -147,7 +148,9 @@ def hist2d( def errorplot( x: NDArray[np.floating], y: NDArray[np.floating], - yerr: NDArray[np.floating] | float | tuple[NDArray[np.floating], NDArray[np.floating]], + yerr: NDArray[np.floating] + | float + | tuple[NDArray[np.floating], NDArray[np.floating]], method: str = "patch", color: ColorType = "#222222", xscale: str = "linear", diff --git a/src/jetplot/timepiece.py b/src/jetplot/timepiece.py index d1b538f..9d7b6c4 100644 --- a/src/jetplot/timepiece.py +++ b/src/jetplot/timepiece.py @@ -1,10 +1,11 @@ """Utilities for dealing with time.""" import time +from collections.abc import Callable from functools import wraps +from typing import Any import numpy as np -from typing import Any, Callable __all__ = ["hrtime", "Stopwatch", "profile"] @@ -87,7 +88,6 @@ def hrtime(t: float) -> str: return timestr - def profile(func: Callable[..., Any]) -> Callable[..., Any]: """Timing (profile) decorator for a function.""" calls = list() diff --git a/tests/test_images.py b/tests/test_images.py index 8c6c05d..36ff869 100644 --- a/tests/test_images.py +++ b/tests/test_images.py @@ -17,6 +17,7 @@ def test_img_corr_mode(): assert len(fig.axes) == 2 plt.close(fig) + def test_cmat_labels_and_colorbar(): data = np.array([[0.0, 1.0], [1.0, 0.0]]) fig, ax = plt.subplots() diff --git a/tests/test_utils.py b/tests/test_utils.py index 5d3be0e..f54e6d9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -52,6 +52,7 @@ def test_noticks(): plt.close(fig) + def test_get_bounds_spines(): fig, ax = plt.subplots() ax.plot([0, 1], [0, 1]) From 17a55d159d18ce8d27b9ed821faca0908dad8f01 Mon Sep 17 00:00:00 2001 From: Niru Nahesh Date: Tue, 20 May 2025 22:27:11 -0700 Subject: [PATCH 3/3] Fixes type checks. --- justfile | 3 +++ src/jetplot/colors.py | 9 +++++---- src/jetplot/images.py | 8 ++++---- src/jetplot/plots.py | 28 +++++++++++++++------------- src/jetplot/signals.py | 7 +++++-- 5 files changed, 32 insertions(+), 23 deletions(-) diff --git a/justfile b/justfile index 56e0b43..dba1dba 100644 --- a/justfile +++ b/justfile @@ -17,3 +17,6 @@ typecheck: test: uv run pytest --cov=jetplot --cov-report=term + +loop: + find {src,tests} -name "*.py" | entr -c just test diff --git a/src/jetplot/colors.py b/src/jetplot/colors.py index 79e7bc6..ac91109 100644 --- a/src/jetplot/colors.py +++ b/src/jetplot/colors.py @@ -1,5 +1,7 @@ """Colorschemes""" +from typing import cast + import numpy as np from matplotlib import cm from matplotlib import pyplot as plt @@ -7,7 +9,6 @@ from matplotlib.colors import LinearSegmentedColormap, to_hex from matplotlib.figure import Figure from matplotlib.typing import ColorType -from numpy.typing import NDArray from .chart_utils import noticks @@ -20,14 +21,14 @@ class Palette(list[ColorType]): @property def hex(self) -> "Palette": """Return the palette colors as hexadecimal strings.""" - return Palette([to_hex(rgb) for rgb in self]) + return Palette([to_hex(rgb) for rgb in self]) # pyrefly: ignore @property def cmap(self) -> LinearSegmentedColormap: """Return the palette as a Matplotlib colormap.""" return LinearSegmentedColormap.from_list("", self) - def plot(self, figsize: tuple[int, int] = (5, 1)) -> tuple[Figure, NDArray[Axes]]: + def plot(self, figsize: tuple[int, int] = (5, 1)) -> tuple[Figure, list[Axes]]: """Visualize the colors in the palette.""" fig, axs = plt.subplots(1, len(self), figsize=figsize) for c, ax in zip(self, axs, strict=True): # pyrefly: ignore @@ -35,7 +36,7 @@ def plot(self, figsize: tuple[int, int] = (5, 1)) -> tuple[Figure, NDArray[Axes] ax.set_aspect("equal") noticks(ax=ax) - return fig, axs + return fig, cast(list[Axes], axs) def cubehelix( diff --git a/src/jetplot/images.py b/src/jetplot/images.py index 0f3f347..101c27b 100644 --- a/src/jetplot/images.py +++ b/src/jetplot/images.py @@ -2,7 +2,7 @@ from collections.abc import Callable, Iterable from functools import partial -from typing import Any +from typing import Any, cast import numpy as np from matplotlib import pyplot as plt @@ -123,7 +123,7 @@ def cmat( fmt: str = "0.0%", dark_color: str = "#222222", light_color: str = "#dddddd", - grid_color: str = c.gray[9], + grid_color: str = cast(str, c.gray[9]), theta: float = 0.5, label_fontsize: float = 10.0, fontsize: float = 10.0, @@ -141,8 +141,8 @@ def cmat( for x, y, value in zip(xs.flat, ys.flat, arr.flat, strict=True): # pyrefly: ignore color = dark_color if (value <= theta) else light_color - annot = f"{{:{fmt}}}".format(value) - ax.text(x, y, annot, ha="center", va="center", color=color, fontsize=fontsize) + label = f"{{:{fmt}}}".format(value) + ax.text(x, y, label, ha="center", va="center", color=color, fontsize=fontsize) if labels is not None: ax.set_xticks(np.arange(num_cols)) diff --git a/src/jetplot/plots.py b/src/jetplot/plots.py index 7a28aed..61b1b4d 100644 --- a/src/jetplot/plots.py +++ b/src/jetplot/plots.py @@ -1,7 +1,7 @@ """Common plots.""" from collections.abc import Iterable, Sequence -from typing import Any +from typing import Any, cast import numpy as np from matplotlib.axes import Axes @@ -163,11 +163,11 @@ def errorplot( """Plot a line with error bars.""" ax = kwargs["ax"] - if np.isscalar(yerr) or len(yerr) == len(y): - ymin = y - yerr - ymax = y + yerr + if np.isscalar(yerr) or len(yerr) == len(y): # pyrefly: ignore + ymin = y - yerr # pyrefly: ignore + ymax = y + yerr # pyrefly: ignore elif len(yerr) == 2: - ymin, ymax = yerr + ymin, ymax = yerr # pyrefly: ignore else: raise ValueError("Invalid yerr value: ", yerr) @@ -175,7 +175,9 @@ def errorplot( ax.plot(x, y, fmt, color=color, linewidth=4, clip_on=clip_on) ax.plot(x, ymax, "_", ms=20, color=err_color, clip_on=clip_on) ax.plot(x, ymin, "_", ms=20, color=err_color, clip_on=clip_on) - for i, xi in enumerate(x): + + # plot error bars + for i, xi in enumerate(x): # pyrefly: ignore ax.plot( np.array([xi, xi]), np.array([ymin[i], ymax[i]]), @@ -230,7 +232,7 @@ def bar( n = len(data) x = np.arange(n) + width if err is not None: - err = np.vstack((np.zeros_like(err), err)) + err = np.vstack((np.zeros_like(err), err)) # pyrefly: ignore ax.bar(x, data, width, color=color) @@ -295,7 +297,7 @@ def waterfall( ) -> None: """Waterfall plot.""" ax = kwargs["ax"] - total = len(ys) + total = cast(int, len(ys)) for index, y in enumerate(ys): zorder = total - index @@ -321,7 +323,7 @@ def ridgeline( axs = [] for k, (x, c) in enumerate(zip(xs, colors, strict=False)): - ax = fig.add_subplot(len(xs), 1, k + 1) + ax = fig.add_subplot(cast(int, len(xs)), 1, k + 1) y = gaussian_kde(x).evaluate(t) ax.fill_between(t, y, color=c, clip_on=False) ax.plot(t, y, color=edgecolor, clip_on=False) @@ -376,7 +378,7 @@ def ellipse( ------- matplotlib.patches.Ellipse """ - ax = kwargs.get("ax") + ax = cast(Axes, kwargs.get("ax")) if x.size != y.size: raise ValueError("x and y must be the same size") @@ -403,11 +405,11 @@ def ellipse( # the square root of the variance and multiplying # with the given number of standard deviations. scale_x = np.sqrt(cov[0, 0]) * n_std - mean_x = np.mean(x) + mean_x = np.mean(x) # pyrefly: ignore # calculating the standard deviation of y ... scale_y = np.sqrt(cov[1, 1]) * n_std - mean_y = np.mean(y) + mean_y = np.mean(y) # pyrefly: ignore transform = ( Affine2D() @@ -416,5 +418,5 @@ def ellipse( .translate(float(mean_x), float(mean_y)) ) - ellipse.set_transform(transform + ax.transData) + ellipse.set_transform(transform + ax.transData) # pyrefly: ignore return ax.add_patch(ellipse) diff --git a/src/jetplot/signals.py b/src/jetplot/signals.py index e204c32..ab9dcf4 100644 --- a/src/jetplot/signals.py +++ b/src/jetplot/signals.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Protocol, SupportsIndex +from typing import Any, Protocol, SupportsIndex import numpy as np from numpy.typing import ArrayLike, NDArray @@ -26,10 +26,13 @@ def smooth(x: ArrayLike, sigma: float = 1.0, axis: int = 0) -> NDArray[np.floati return gaussian_filter1d(x, sigma, axis=axis) -def stable_rank(X: NDArray[np.floating]) -> float: +def stable_rank(X: NDArray[np.floating[Any]]) -> float: """Computes the stable rank of a matrix""" assert X.ndim == 2, "X must be a matrix" + + # pyrefly: ignore svals_sq = np.linalg.svd(X, compute_uv=False, full_matrices=False) ** 2 + return svals_sq.sum() / svals_sq.max()