diff --git a/justfile b/justfile index 56e0b43..dba1dba 100644 --- a/justfile +++ b/justfile @@ -17,3 +17,6 @@ typecheck: test: uv run pytest --cov=jetplot --cov-report=term + +loop: + find {src,tests} -name "*.py" | entr -c just test diff --git a/src/jetplot/chart_utils.py b/src/jetplot/chart_utils.py index 2bf1da2..903bbb8 100644 --- a/src/jetplot/chart_utils.py +++ b/src/jetplot/chart_utils.py @@ -1,5 +1,6 @@ """Plotting utils.""" +from collections.abc import Callable from functools import partial, wraps from typing import Any, Literal @@ -20,7 +21,7 @@ ] -def figwrapper(fun): +def figwrapper(fun: Callable[..., Any]) -> Callable[..., Any]: """Decorator that adds figure handles to the kwargs of a function.""" @wraps(fun) @@ -33,7 +34,7 @@ def wrapper(*args, **kwargs): return wrapper -def plotwrapper(fun): +def plotwrapper(fun: Callable[..., Any]) -> Callable[..., Any]: """Decorator that adds figure and axes handles to the kwargs of a function.""" @wraps(fun) @@ -52,7 +53,7 @@ def wrapper(*args, **kwargs): return wrapper -def axwrapper(fun): +def axwrapper(fun: Callable[..., Any]) -> Callable[..., Any]: """Decorator that adds an axes handle to kwargs.""" @wraps(fun) @@ -70,7 +71,7 @@ def wrapper(*args, **kwargs): @axwrapper -def noticks(**kwargs): +def noticks(**kwargs: Any) -> None: """ Clears tick marks (useful for images) """ @@ -81,7 +82,13 @@ def noticks(**kwargs): @axwrapper -def nospines(left=False, bottom=False, top=True, right=True, **kwargs): +def nospines( + left: bool = False, + bottom: bool = False, + top: bool = True, + right: bool = True, + **kwargs: Any, +) -> plt.Axes: """ Hides the specified axis spines (by default, right and top spines) """ @@ -160,7 +167,12 @@ def get_bounds(axis: Literal["x", "y"], ax: Axes | None = None) -> tuple[float, @axwrapper -def breathe(xlims=None, ylims=None, padding_percent=0.05, **kwargs): +def breathe( + xlims: tuple[float, float] | None = None, + ylims: tuple[float, float] | None = None, + padding_percent: float = 0.05, + **kwargs: Any, +) -> plt.Axes: """Adds space between axes and plot.""" ax = kwargs["ax"] @@ -203,7 +215,7 @@ def yclamp( y0: float | None = None, y1: float | None = None, dt: float | None = None, - **kwargs, + **kwargs: Any, ) -> Axes: """Clamp the y-axis to evenly spaced tick marks.""" ax = kwargs["ax"] @@ -228,7 +240,7 @@ def xclamp( x0: float | None = None, x1: float | None = None, dt: float | None = None, - **kwargs, + **kwargs: Any, ) -> Axes: """Clamp the x-axis to evenly spaced tick marks.""" ax = kwargs["ax"] diff --git a/src/jetplot/colors.py b/src/jetplot/colors.py index 088d91f..ac91109 100644 --- a/src/jetplot/colors.py +++ b/src/jetplot/colors.py @@ -1,5 +1,7 @@ """Colorschemes""" +from typing import cast + import numpy as np from matplotlib import cm from matplotlib import pyplot as plt @@ -7,7 +9,6 @@ from matplotlib.colors import LinearSegmentedColormap, to_hex from matplotlib.figure import Figure from matplotlib.typing import ColorType -from numpy.typing import NDArray from .chart_utils import noticks @@ -18,16 +19,16 @@ class Palette(list[ColorType]): """Color palette based on a list of values.""" @property - def hex(self): + def hex(self) -> "Palette": """Return the palette colors as hexadecimal strings.""" - return Palette([to_hex(rgb) for rgb in self]) + return Palette([to_hex(rgb) for rgb in self]) # pyrefly: ignore @property def cmap(self) -> LinearSegmentedColormap: """Return the palette as a Matplotlib colormap.""" return LinearSegmentedColormap.from_list("", self) - def plot(self, figsize: tuple[int, int] = (5, 1)) -> tuple[Figure, NDArray[Axes]]: + def plot(self, figsize: tuple[int, int] = (5, 1)) -> tuple[Figure, list[Axes]]: """Visualize the colors in the palette.""" fig, axs = plt.subplots(1, len(self), figsize=figsize) for c, ax in zip(self, axs, strict=True): # pyrefly: ignore @@ -35,7 +36,7 @@ def plot(self, figsize: tuple[int, int] = (5, 1)) -> tuple[Figure, NDArray[Axes] ax.set_aspect("equal") noticks(ax=ax) - return fig, axs + return fig, cast(list[Axes], axs) def cubehelix( @@ -46,7 +47,7 @@ def cubehelix( start: float = 0.0, rot: float = 0.4, hue: float = 0.8, -): +) -> Palette: """Cubehelix parameterized colormap.""" lambda_ = np.linspace(vmin, vmax, n) x = lambda_**gamma diff --git a/src/jetplot/images.py b/src/jetplot/images.py index afb71a7..101c27b 100644 --- a/src/jetplot/images.py +++ b/src/jetplot/images.py @@ -1,10 +1,13 @@ """Image visualization tools.""" -from collections.abc import Callable +from collections.abc import Callable, Iterable from functools import partial +from typing import Any, cast import numpy as np from matplotlib import pyplot as plt +from matplotlib.axes import Axes +from matplotlib.image import AxesImage from matplotlib.ticker import FixedLocator from . import colors as c @@ -15,16 +18,16 @@ @plotwrapper def img( - data, - mode="div", - cmap=None, - aspect="equal", - vmin=None, - vmax=None, - cbar=True, - interpolation="none", - **kwargs, -): + data: np.ndarray, + mode: str = "div", + cmap: str | None = None, + aspect: str = "equal", + vmin: float | None = None, + vmax: float | None = None, + cbar: bool = True, + interpolation: str = "none", + **kwargs: Any, +) -> AxesImage: """Visualize a matrix as an image. Args: @@ -86,7 +89,7 @@ def fsurface( yrng: tuple[float, float] | None = None, n: int = 100, nargs: int = 2, - **kwargs, + **kwargs: Any, ) -> None: """Plot a 2‑D function as a filled surface.""" xrng = (-1, 1) if xrng is None else xrng @@ -112,22 +115,22 @@ def fsurface( @plotwrapper def cmat( - arr, - labels=None, - annot=True, - cmap="gist_heat_r", - cbar=False, - fmt="0.0%", - dark_color="#222222", - light_color="#dddddd", - grid_color=c.gray[9], - theta=0.5, - label_fontsize=10.0, - fontsize=10.0, - vmin=0.0, - vmax=1.0, - **kwargs, -): + arr: np.ndarray, + labels: Iterable[str] | None = None, + annot: bool = True, + cmap: str = "gist_heat_r", + cbar: bool = False, + fmt: str = "0.0%", + dark_color: str = "#222222", + light_color: str = "#dddddd", + grid_color: str = cast(str, c.gray[9]), + theta: float = 0.5, + label_fontsize: float = 10.0, + fontsize: float = 10.0, + vmin: float = 0.0, + vmax: float = 1.0, + **kwargs: Any, +) -> tuple[AxesImage, Axes]: """Plot confusion matrix.""" num_rows, num_cols = arr.shape @@ -138,8 +141,8 @@ def cmat( for x, y, value in zip(xs.flat, ys.flat, arr.flat, strict=True): # pyrefly: ignore color = dark_color if (value <= theta) else light_color - annot = f"{{:{fmt}}}".format(value) - ax.text(x, y, annot, ha="center", va="center", color=color, fontsize=fontsize) + label = f"{{:{fmt}}}".format(value) + ax.text(x, y, label, ha="center", va="center", color=color, fontsize=fontsize) if labels is not None: ax.set_xticks(np.arange(num_cols)) diff --git a/src/jetplot/plots.py b/src/jetplot/plots.py index 60ca23e..61b1b4d 100644 --- a/src/jetplot/plots.py +++ b/src/jetplot/plots.py @@ -1,12 +1,14 @@ """Common plots.""" +from collections.abc import Iterable, Sequence +from typing import Any, cast + import numpy as np +from matplotlib.axes import Axes +from matplotlib.figure import Figure from matplotlib.patches import Ellipse from matplotlib.transforms import Affine2D from matplotlib.typing import ColorType -from matplotlib.figure import Figure -from matplotlib.axes import Axes -from collections.abc import Sequence from numpy.typing import NDArray from scipy.stats import gaussian_kde from sklearn.covariance import EmpiricalCovariance, MinCovDet @@ -30,14 +32,14 @@ @plotwrapper def violinplot( data: NDArray[np.floating], - xs, - fc=neutral[3], - ec=neutral[9], - mc=neutral[1], - showmedians=True, - showmeans=False, - showquartiles=True, - **kwargs, + xs: Sequence[float] | float, + fc: ColorType = neutral[3], + ec: ColorType = neutral[9], + mc: ColorType = neutral[1], + showmedians: bool = True, + showmeans: bool = False, + showquartiles: bool = True, + **kwargs: Any, ) -> Axes: """Violin plot with customizable elements.""" _ = kwargs.pop("fig") @@ -94,7 +96,9 @@ def violinplot( @plotwrapper -def hist(*args, histtype="stepfilled", alpha=0.85, density=True, **kwargs): +def hist( + *args: Any, histtype="stepfilled", alpha=0.85, density=True, **kwargs: Any +) -> Any: """Wrapper for matplotlib.hist function.""" ax = kwargs.pop("ax") kwargs.pop("fig") @@ -104,13 +108,13 @@ def hist(*args, histtype="stepfilled", alpha=0.85, density=True, **kwargs): @plotwrapper def hist2d( - x: np.ndarray, - y: np.ndarray, - bins: int | None = None, - limits: np.ndarray | None = None, + x: NDArray[np.floating], + y: NDArray[np.floating], + bins: int | Sequence[float] | None = None, + limits: NDArray[np.floating] | Sequence[Sequence[float]] | None = None, cmap: str = "hot", - **kwargs, -): + **kwargs: Any, +) -> None: """ Visualizes a 2D histogram by binning data. @@ -142,26 +146,28 @@ def hist2d( @plotwrapper def errorplot( - x, - y, - yerr, - method="patch", + x: NDArray[np.floating], + y: NDArray[np.floating], + yerr: NDArray[np.floating] + | float + | tuple[NDArray[np.floating], NDArray[np.floating]], + method: str = "patch", color: ColorType = "#222222", - xscale="linear", - fmt="-", + xscale: str = "linear", + fmt: str = "-", err_color: ColorType = "#cccccc", - alpha_fill=1.0, - clip_on=True, - **kwargs, -): + alpha_fill: float = 1.0, + clip_on: bool = True, + **kwargs: Any, +) -> None: """Plot a line with error bars.""" ax = kwargs["ax"] - if np.isscalar(yerr) or len(yerr) == len(y): - ymin = y - yerr - ymax = y + yerr + if np.isscalar(yerr) or len(yerr) == len(y): # pyrefly: ignore + ymin = y - yerr # pyrefly: ignore + ymax = y + yerr # pyrefly: ignore elif len(yerr) == 2: - ymin, ymax = yerr + ymin, ymax = yerr # pyrefly: ignore else: raise ValueError("Invalid yerr value: ", yerr) @@ -169,7 +175,9 @@ def errorplot( ax.plot(x, y, fmt, color=color, linewidth=4, clip_on=clip_on) ax.plot(x, ymax, "_", ms=20, color=err_color, clip_on=clip_on) ax.plot(x, ymin, "_", ms=20, color=err_color, clip_on=clip_on) - for i, xi in enumerate(x): + + # plot error bars + for i, xi in enumerate(x): # pyrefly: ignore ax.plot( np.array([xi, xi]), np.array([ymin[i], ymax[i]]), @@ -200,16 +208,16 @@ def errorplot( @plotwrapper def bar( - labels, - data, - color="#888888", - width=0.7, - offset=0.0, - err=None, - capsize=5, - capthick=2, - **kwargs, -): + labels: Sequence[str], + data: Sequence[float], + color: ColorType = "#888888", + width: float = 0.7, + offset: float = 0.0, + err: Sequence[float] | None = None, + capsize: float = 5, + capthick: float = 2, + **kwargs: Any, +) -> Axes: """Bar chart. Args: @@ -224,7 +232,7 @@ def bar( n = len(data) x = np.arange(n) + width if err is not None: - err = np.vstack((np.zeros_like(err), err)) + err = np.vstack((np.zeros_like(err), err)) # pyrefly: ignore ax.bar(x, data, width, color=color) @@ -277,10 +285,19 @@ def lines( @plotwrapper -def waterfall(x, ys, dy=1.0, pad=0.1, color="#444444", ec="#cccccc", ew=2.0, **kwargs): +def waterfall( + x: NDArray[np.floating], + ys: Iterable[NDArray[np.floating]], + dy: float = 1.0, + pad: float = 0.1, + color: ColorType = "#444444", + ec: ColorType = "#cccccc", + ew: float = 2.0, + **kwargs: Any, +) -> None: """Waterfall plot.""" ax = kwargs["ax"] - total = len(ys) + total = cast(int, len(ys)) for index, y in enumerate(ys): zorder = total - index @@ -295,18 +312,18 @@ def waterfall(x, ys, dy=1.0, pad=0.1, color="#444444", ec="#cccccc", ew=2.0, **k @figwrapper def ridgeline( t: NDArray[np.floating], - xs: Sequence[NDArray[np.floating]], - colors: Sequence[ColorType], + xs: Iterable[NDArray[np.floating]], + colors: Iterable[ColorType], edgecolor: ColorType = "#ffffff", ymax: float = 0.6, - **kwargs, + **kwargs: Any, ) -> tuple[Figure, list[Axes]]: """Stacked density plots reminiscent of a ridgeline plot.""" fig = kwargs["fig"] axs = [] for k, (x, c) in enumerate(zip(xs, colors, strict=False)): - ax = fig.add_subplot(len(xs), 1, k + 1) + ax = fig.add_subplot(cast(int, len(xs)), 1, k + 1) y = gaussian_kde(x).evaluate(t) ax.fill_between(t, y, color=c, clip_on=False) ax.plot(t, y, color=edgecolor, clip_on=False) @@ -327,7 +344,7 @@ def ridgeline( @plotwrapper -def circle(radius=1.0, **kwargs): +def circle(radius: float = 1.0, **kwargs: Any) -> None: """Plots a unit circle.""" ax = kwargs["ax"] theta = np.linspace(0, 2 * np.pi, 1001) @@ -335,7 +352,14 @@ def circle(radius=1.0, **kwargs): @plotwrapper -def ellipse(x, y, n_std=3.0, facecolor="none", estimator="empirical", **kwargs): +def ellipse( + x: NDArray[np.floating], + y: NDArray[np.floating], + n_std: float = 3.0, + facecolor: str = "none", + estimator: str = "empirical", + **kwargs: Any, +) -> Ellipse: """ Create a plot of the covariance confidence ellipse of *x* and *y*. @@ -354,7 +378,7 @@ def ellipse(x, y, n_std=3.0, facecolor="none", estimator="empirical", **kwargs): ------- matplotlib.patches.Ellipse """ - ax = kwargs.get("ax") + ax = cast(Axes, kwargs.get("ax")) if x.size != y.size: raise ValueError("x and y must be the same size") @@ -381,11 +405,11 @@ def ellipse(x, y, n_std=3.0, facecolor="none", estimator="empirical", **kwargs): # the square root of the variance and multiplying # with the given number of standard deviations. scale_x = np.sqrt(cov[0, 0]) * n_std - mean_x = np.mean(x) + mean_x = np.mean(x) # pyrefly: ignore # calculating the standard deviation of y ... scale_y = np.sqrt(cov[1, 1]) * n_std - mean_y = np.mean(y) + mean_y = np.mean(y) # pyrefly: ignore transform = ( Affine2D() @@ -394,5 +418,5 @@ def ellipse(x, y, n_std=3.0, facecolor="none", estimator="empirical", **kwargs): .translate(float(mean_x), float(mean_y)) ) - ellipse.set_transform(transform + ax.transData) + ellipse.set_transform(transform + ax.transData) # pyrefly: ignore return ax.add_patch(ellipse) diff --git a/src/jetplot/signals.py b/src/jetplot/signals.py index e0b9d49..ab9dcf4 100644 --- a/src/jetplot/signals.py +++ b/src/jetplot/signals.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Protocol, SupportsIndex +from typing import Any, Protocol, SupportsIndex import numpy as np from numpy.typing import ArrayLike, NDArray @@ -13,7 +13,7 @@ FloatArray = NDArray[np.floating] -def smooth(x, sigma=1.0, axis=0): +def smooth(x: ArrayLike, sigma: float = 1.0, axis: int = 0) -> NDArray[np.floating]: """Smooths a 1D signal with a gaussian filter. Args: @@ -26,10 +26,13 @@ def smooth(x, sigma=1.0, axis=0): return gaussian_filter1d(x, sigma, axis=axis) -def stable_rank(X): +def stable_rank(X: NDArray[np.floating[Any]]) -> float: """Computes the stable rank of a matrix""" assert X.ndim == 2, "X must be a matrix" + + # pyrefly: ignore svals_sq = np.linalg.svd(X, compute_uv=False, full_matrices=False) ** 2 + return svals_sq.sum() / svals_sq.max() diff --git a/src/jetplot/style.py b/src/jetplot/style.py index 930535a..bcbb1b9 100644 --- a/src/jetplot/style.py +++ b/src/jetplot/style.py @@ -76,7 +76,7 @@ } -def set_colors(bg, fg, text): +def set_colors(bg: ColorType, fg: ColorType, text: ColorType) -> None: """Set background/foreground colorscheme.""" rcParams.update( { @@ -96,7 +96,7 @@ def set_colors(bg, fg, text): ) -def set_font(fontname: str): +def set_font(fontname: str) -> None: """Specifies the matplotlib default font.""" if fontname not in available_fonts(): @@ -105,7 +105,7 @@ def set_font(fontname: str): rcParams["font.family"] = fontname -def set_dpi(dpi: int): +def set_dpi(dpi: int) -> None: """Sets the figure DPI.""" rcParams["figure.dpi"] = dpi @@ -118,7 +118,7 @@ def set_defaults( cycler_colors: c.Palette, defaults: Mapping[str, Any] = STYLE_DEFAULTS, font: str = "Helvetica", -): +) -> None: """Sets matplotlib defaults.""" rcParams.update(defaults) set_colors(bg, fg, text) @@ -143,7 +143,7 @@ def available_fonts() -> list[str]: return sorted(set([f.name for f in fm.fontManager.ttflist])) # pyrefly: ignore -def install_fonts(filepath: str): +def install_fonts(filepath: str) -> None: """Installs .ttf fonts in the given folder.""" original_fonts = set(available_fonts()) diff --git a/src/jetplot/timepiece.py b/src/jetplot/timepiece.py index 210d536..9d7b6c4 100644 --- a/src/jetplot/timepiece.py +++ b/src/jetplot/timepiece.py @@ -1,7 +1,9 @@ """Utilities for dealing with time.""" import time +from collections.abc import Callable from functools import wraps +from typing import Any import numpy as np @@ -37,7 +39,7 @@ def __exit__(self, *_: object) -> None: print(f"{self.name} Finished! \u2714\nTotal elapsed time: {total}") -def hrtime(t: float): +def hrtime(t: float) -> str: """Converts a time in seconds to a reasonable human readable time. Args: @@ -86,7 +88,7 @@ def hrtime(t: float): return timestr -def profile(func): +def profile(func: Callable[..., Any]) -> Callable[..., Any]: """Timing (profile) decorator for a function.""" calls = list() diff --git a/tests/test_images.py b/tests/test_images.py index 8c6c05d..36ff869 100644 --- a/tests/test_images.py +++ b/tests/test_images.py @@ -17,6 +17,7 @@ def test_img_corr_mode(): assert len(fig.axes) == 2 plt.close(fig) + def test_cmat_labels_and_colorbar(): data = np.array([[0.0, 1.0], [1.0, 0.0]]) fig, ax = plt.subplots() diff --git a/tests/test_utils.py b/tests/test_utils.py index 5d3be0e..f54e6d9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -52,6 +52,7 @@ def test_noticks(): plt.close(fig) + def test_get_bounds_spines(): fig, ax = plt.subplots() ax.plot([0, 1], [0, 1])