From 10b2d740cd989f2d2b6a6223dbc611cbbc331d4d Mon Sep 17 00:00:00 2001 From: Niru Maheswaranathan Date: Tue, 20 May 2025 10:57:04 -0700 Subject: [PATCH 1/4] Add type hints --- src/jetplot/chart_utils.py | 34 +++++++++++++++++++++++++++++++--- src/jetplot/colors.py | 22 ++++++++++++++++++---- src/jetplot/images.py | 12 +++++++++++- src/jetplot/plots.py | 28 +++++++++++++++++++++++++--- src/jetplot/timepiece.py | 14 ++++++++------ 5 files changed, 93 insertions(+), 17 deletions(-) diff --git a/src/jetplot/chart_utils.py b/src/jetplot/chart_utils.py index 6fe157d..705b8eb 100644 --- a/src/jetplot/chart_utils.py +++ b/src/jetplot/chart_utils.py @@ -3,6 +3,8 @@ from collections.abc import Callable from functools import partial, wraps +from matplotlib.axes import Axes + import numpy as np from matplotlib import pyplot as plt @@ -114,7 +116,21 @@ def nospines(left=False, bottom=False, top=True, right=True, **kwargs): return ax -def get_bounds(axis, ax=None): +def get_bounds(axis: str, ax: Axes | None = None) -> tuple[float, float]: + """Return the axis spine bounds for the given axis. + + Parameters + ---------- + axis : str + Axis to inspect, either ``"x"`` or ``"y"``. + ax : matplotlib.axes.Axes | None, optional + Axes object to inspect. If ``None``, the current axes are used. + + Returns + ------- + tuple[float, float] + Lower and upper bounds of the axis spine. + """ if ax is None: ax = plt.gca() @@ -187,7 +203,13 @@ def identity(x): @axwrapper -def yclamp(y0=None, y1=None, dt=None, **kwargs): +def yclamp( + y0: float | None = None, + y1: float | None = None, + dt: float | None = None, + **kwargs, +) -> Axes: + """Clamp the y-axis to evenly spaced tick marks.""" ax = kwargs["ax"] lims = ax.get_ylim() @@ -206,7 +228,13 @@ def yclamp(y0=None, y1=None, dt=None, **kwargs): @axwrapper -def xclamp(x0=None, x1=None, dt=None, **kwargs): +def xclamp( + x0: float | None = None, + x1: float | None = None, + dt: float | None = None, + **kwargs, +) -> Axes: + """Clamp the x-axis to evenly spaced tick marks.""" ax = kwargs["ax"] lims = ax.get_xlim() diff --git a/src/jetplot/colors.py b/src/jetplot/colors.py index 040e2ad..9938a6c 100644 --- a/src/jetplot/colors.py +++ b/src/jetplot/colors.py @@ -4,7 +4,10 @@ from matplotlib import cm from matplotlib import pyplot as plt from matplotlib.colors import LinearSegmentedColormap, to_hex +from matplotlib.figure import Figure +from matplotlib.axes import Axes from matplotlib.typing import ColorType +from numpy.typing import NDArray from .chart_utils import noticks @@ -15,14 +18,17 @@ class Palette(list[ColorType]): """Color palette based on a list of values.""" @property - def hex(self): + def hex(self) -> "Palette": + """Return the palette colors as hexadecimal strings.""" return Palette([to_hex(rgb) for rgb in self]) @property - def cmap(self): + def cmap(self) -> LinearSegmentedColormap: + """Return the palette as a Matplotlib colormap.""" return LinearSegmentedColormap.from_list("", self) - def plot(self, figsize=(5, 1)): + def plot(self, figsize: tuple[float, float] = (5, 1)) -> tuple[Figure, NDArray[Axes]]: + """Visualize the colors in the palette.""" fig, axs = plt.subplots(1, len(self), figsize=figsize) for c, ax in zip(self, axs, strict=True): # pyrefly: ignore ax.set_facecolor(c) @@ -54,7 +60,13 @@ def cubehelix( return Palette(colors) -def cmap_colors(cmap: str, n: int, vmin: float = 0.0, vmax: float = 1.0): +def cmap_colors( + cmap: str, + n: int, + vmin: float = 0.0, + vmax: float = 1.0, +) -> Palette: + """Extract ``n`` colors from a Matplotlib colormap.""" return Palette(getattr(cm, cmap)(np.linspace(vmin, vmax, n))) @@ -371,6 +383,8 @@ def cmap_colors(cmap: str, n: int, vmin: float = 0.0, vmax: float = 1.0): def rainbow(k: int) -> Palette: + """Return a palette of distinct colors from several base palettes.""" + _colors = ( blue, orange, diff --git a/src/jetplot/images.py b/src/jetplot/images.py index c464786..eb028ba 100644 --- a/src/jetplot/images.py +++ b/src/jetplot/images.py @@ -1,10 +1,12 @@ """Image visualization tools.""" from functools import partial +from collections.abc import Callable import numpy as np from matplotlib import pyplot as plt from matplotlib.ticker import FixedLocator +from matplotlib.axes import Axes from . import colors as c from .chart_utils import noticks, plotwrapper @@ -79,7 +81,15 @@ def img( @plotwrapper -def fsurface(func, xrng=None, yrng=None, n=100, nargs=2, **kwargs): +def fsurface( + func: Callable[..., np.ndarray], + xrng: tuple[float, float] | None = None, + yrng: tuple[float, float] | None = None, + n: int = 100, + nargs: int = 2, + **kwargs, +) -> None: + """Plot a 2‑D function as a filled surface.""" xrng = (-1, 1) if xrng is None else xrng yrng = xrng if yrng is None else yrng diff --git a/src/jetplot/plots.py b/src/jetplot/plots.py index e57f18b..5e090c2 100644 --- a/src/jetplot/plots.py +++ b/src/jetplot/plots.py @@ -4,6 +4,9 @@ from matplotlib.patches import Ellipse from matplotlib.transforms import Affine2D from matplotlib.typing import ColorType +from matplotlib.figure import Figure +from matplotlib.axes import Axes +from collections.abc import Sequence from numpy.typing import NDArray from scipy.stats import gaussian_kde from sklearn.covariance import EmpiricalCovariance, MinCovDet @@ -35,7 +38,8 @@ def violinplot( showmeans=False, showquartiles=True, **kwargs, -): +) -> Axes: + """Violin plot with customizable elements.""" _ = kwargs.pop("fig") ax = kwargs.pop("ax") @@ -86,6 +90,8 @@ def violinplot( zorder=20, ) + return ax + @plotwrapper def hist(*args, **kwargs): @@ -249,7 +255,13 @@ def bar( @plotwrapper -def lines(x, lines=None, cmap="viridis", **kwargs): +def lines( + x: NDArray[np.floating] | NDArray[np.integer], + lines: list[NDArray[np.floating]] | None = None, + cmap: str = "viridis", + **kwargs, +) -> Axes: + """Plot multiple lines using a color map.""" ax = kwargs["ax"] if lines is None: @@ -263,6 +275,8 @@ def lines(x, lines=None, cmap="viridis", **kwargs): for line, color in zip(lines, colors, strict=False): ax.plot(x, line, color=color) + return ax + @plotwrapper def waterfall(x, ys, dy=1.0, pad=0.1, color="#444444", ec="#cccccc", ew=2.0, **kwargs): @@ -281,7 +295,15 @@ def waterfall(x, ys, dy=1.0, pad=0.1, color="#444444", ec="#cccccc", ew=2.0, **k @figwrapper -def ridgeline(t, xs, colors, edgecolor="#ffffff", ymax=0.6, **kwargs): +def ridgeline( + t: NDArray[np.floating], + xs: Sequence[NDArray[np.floating]], + colors: Sequence[ColorType], + edgecolor: ColorType = "#ffffff", + ymax: float = 0.6, + **kwargs, +) -> tuple[Figure, list[Axes]]: + """Stacked density plots reminiscent of a ridgeline plot.""" fig = kwargs["fig"] axs = [] diff --git a/src/jetplot/timepiece.py b/src/jetplot/timepiece.py index dd86946..210d536 100644 --- a/src/jetplot/timepiece.py +++ b/src/jetplot/timepiece.py @@ -9,28 +9,30 @@ class Stopwatch: - def __init__(self, name=""): + """Simple timer utility for measuring code execution time.""" + + def __init__(self, name: str = "") -> None: self.name = name self.start = time.perf_counter() self.absolute_start = time.perf_counter() - def __str__(self): + def __str__(self) -> str: return "\u231a Stopwatch for: " + self.name @property - def elapsed(self): + def elapsed(self) -> float: current = time.perf_counter() elapsed = current - self.start self.start = time.perf_counter() return elapsed - def checkpoint(self, name=""): + def checkpoint(self, name: str = "") -> None: print(f"{self.name} {name} took {hrtime(self.elapsed)}".strip()) - def __enter__(self): + def __enter__(self) -> "Stopwatch": return self - def __exit__(self, *_): + def __exit__(self, *_: object) -> None: total = hrtime(time.perf_counter() - self.absolute_start) print(f"{self.name} Finished! \u2714\nTotal elapsed time: {total}") From 2ea7c34a586803e24a761a6ebcf324f780107680 Mon Sep 17 00:00:00 2001 From: Niru Nahesh Date: Tue, 20 May 2025 12:36:19 -0700 Subject: [PATCH 2/4] Fixes type errors. --- src/jetplot/chart_utils.py | 8 +++----- src/jetplot/colors.py | 4 ++-- src/jetplot/plots.py | 2 +- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/jetplot/chart_utils.py b/src/jetplot/chart_utils.py index 705b8eb..7762a67 100644 --- a/src/jetplot/chart_utils.py +++ b/src/jetplot/chart_utils.py @@ -2,11 +2,11 @@ from collections.abc import Callable from functools import partial, wraps - -from matplotlib.axes import Axes +from typing import Any import numpy as np from matplotlib import pyplot as plt +from matplotlib.axes import Axes __all__ = [ "noticks", @@ -135,9 +135,7 @@ def get_bounds(axis: str, ax: Axes | None = None) -> tuple[float, float]: ax = plt.gca() - Result = tuple[Callable[[], list[float]], Callable[[], list[str]], Callable[[], tuple[float, float]], str] - - axis_map: dict[str, Result] = { + axis_map: dict[str, Any] = { "x": (ax.get_xticks, ax.get_xticklabels, ax.get_xlim, "bottom"), "y": (ax.get_yticks, ax.get_yticklabels, ax.get_ylim, "left"), } diff --git a/src/jetplot/colors.py b/src/jetplot/colors.py index 9938a6c..407c384 100644 --- a/src/jetplot/colors.py +++ b/src/jetplot/colors.py @@ -3,9 +3,9 @@ import numpy as np from matplotlib import cm from matplotlib import pyplot as plt +from matplotlib.axes import Axes from matplotlib.colors import LinearSegmentedColormap, to_hex from matplotlib.figure import Figure -from matplotlib.axes import Axes from matplotlib.typing import ColorType from numpy.typing import NDArray @@ -18,7 +18,7 @@ class Palette(list[ColorType]): """Color palette based on a list of values.""" @property - def hex(self) -> "Palette": + def hex(self): """Return the palette colors as hexadecimal strings.""" return Palette([to_hex(rgb) for rgb in self]) diff --git a/src/jetplot/plots.py b/src/jetplot/plots.py index 5e090c2..3182f01 100644 --- a/src/jetplot/plots.py +++ b/src/jetplot/plots.py @@ -265,7 +265,7 @@ def lines( ax = kwargs["ax"] if lines is None: - lines = list(x) + lines = list(x) # pyrefly: ignore x = np.arange(len(lines[0])) else: From fc7c888182d57b296cef321c1ac3548172d13bba Mon Sep 17 00:00:00 2001 From: Niru Nahesh Date: Tue, 20 May 2025 13:13:52 -0700 Subject: [PATCH 3/4] Adjusts function signatures. --- src/jetplot/chart_utils.py | 5 ++--- src/jetplot/colors.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/jetplot/chart_utils.py b/src/jetplot/chart_utils.py index 7762a67..2dca8e6 100644 --- a/src/jetplot/chart_utils.py +++ b/src/jetplot/chart_utils.py @@ -1,8 +1,7 @@ """Plotting utils.""" -from collections.abc import Callable from functools import partial, wraps -from typing import Any +from typing import Any, Literal import numpy as np from matplotlib import pyplot as plt @@ -116,7 +115,7 @@ def nospines(left=False, bottom=False, top=True, right=True, **kwargs): return ax -def get_bounds(axis: str, ax: Axes | None = None) -> tuple[float, float]: +def get_bounds(axis: Literal["x", "y"], ax: Axes | None = None) -> tuple[float, float]: """Return the axis spine bounds for the given axis. Parameters diff --git a/src/jetplot/colors.py b/src/jetplot/colors.py index 407c384..9f0e388 100644 --- a/src/jetplot/colors.py +++ b/src/jetplot/colors.py @@ -27,7 +27,7 @@ def cmap(self) -> LinearSegmentedColormap: """Return the palette as a Matplotlib colormap.""" return LinearSegmentedColormap.from_list("", self) - def plot(self, figsize: tuple[float, float] = (5, 1)) -> tuple[Figure, NDArray[Axes]]: + def plot(self, figsize: tuple[int, int] = (5, 1)) -> tuple[Figure, NDArray[Axes]]: """Visualize the colors in the palette.""" fig, axs = plt.subplots(1, len(self), figsize=figsize) for c, ax in zip(self, axs, strict=True): # pyrefly: ignore From f671e6d39da81431ab36d00579851e229f3e84b7 Mon Sep 17 00:00:00 2001 From: Niru Nahesh Date: Tue, 20 May 2025 13:15:14 -0700 Subject: [PATCH 4/4] Fixes formatting. --- src/jetplot/chart_utils.py | 5 ++--- src/jetplot/colors.py | 2 +- src/jetplot/images.py | 5 ++--- src/jetplot/plots.py | 2 +- src/jetplot/style.py | 4 ++-- 5 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/jetplot/chart_utils.py b/src/jetplot/chart_utils.py index 2dca8e6..2bf1da2 100644 --- a/src/jetplot/chart_utils.py +++ b/src/jetplot/chart_utils.py @@ -133,7 +133,6 @@ def get_bounds(axis: Literal["x", "y"], ax: Axes | None = None) -> tuple[float, if ax is None: ax = plt.gca() - axis_map: dict[str, Any] = { "x": (ax.get_xticks, ax.get_xticklabels, ax.get_xlim, "bottom"), "y": (ax.get_yticks, ax.get_yticklabels, ax.get_ylim, "left"), @@ -213,7 +212,7 @@ def yclamp( y0 = lims[0] if y0 is None else y0 y1 = lims[1] if y1 is None else y1 - ticks: list[float] = ax.get_yticks() # pyrefly: ignore + ticks: list[float] = ax.get_yticks() # pyrefly: ignore dt = float(np.mean(np.diff(ticks))) if dt is None else float(dt) new_ticks = np.arange(dt * np.floor(y0 / dt), dt * (np.ceil(y1 / dt) + 1), dt) @@ -238,7 +237,7 @@ def xclamp( x0 = lims[0] if x0 is None else x0 x1 = lims[1] if x1 is None else x1 - ticks: list[float] = ax.get_xticks() # pyrefly: ignore + ticks: list[float] = ax.get_xticks() # pyrefly: ignore dt = float(np.mean(np.diff(ticks))) if dt is None else float(dt) new_ticks = np.arange(dt * np.floor(x0 / dt), dt * (np.ceil(x1 / dt) + 1), dt) diff --git a/src/jetplot/colors.py b/src/jetplot/colors.py index 9f0e388..088d91f 100644 --- a/src/jetplot/colors.py +++ b/src/jetplot/colors.py @@ -30,7 +30,7 @@ def cmap(self) -> LinearSegmentedColormap: def plot(self, figsize: tuple[int, int] = (5, 1)) -> tuple[Figure, NDArray[Axes]]: """Visualize the colors in the palette.""" fig, axs = plt.subplots(1, len(self), figsize=figsize) - for c, ax in zip(self, axs, strict=True): # pyrefly: ignore + for c, ax in zip(self, axs, strict=True): # pyrefly: ignore ax.set_facecolor(c) ax.set_aspect("equal") noticks(ax=ax) diff --git a/src/jetplot/images.py b/src/jetplot/images.py index eb028ba..afb71a7 100644 --- a/src/jetplot/images.py +++ b/src/jetplot/images.py @@ -1,12 +1,11 @@ """Image visualization tools.""" -from functools import partial from collections.abc import Callable +from functools import partial import numpy as np from matplotlib import pyplot as plt from matplotlib.ticker import FixedLocator -from matplotlib.axes import Axes from . import colors as c from .chart_utils import noticks, plotwrapper @@ -137,7 +136,7 @@ def cmat( xs, ys = np.meshgrid(np.arange(num_cols), np.arange(num_rows), indexing="xy") - for x, y, value in zip(xs.flat, ys.flat, arr.flat, strict=True): # pyrefly: ignore + for x, y, value in zip(xs.flat, ys.flat, arr.flat, strict=True): # pyrefly: ignore color = dark_color if (value <= theta) else light_color annot = f"{{:{fmt}}}".format(value) ax.text(x, y, annot, ha="center", va="center", color=color, fontsize=fontsize) diff --git a/src/jetplot/plots.py b/src/jetplot/plots.py index 3182f01..6b17e77 100644 --- a/src/jetplot/plots.py +++ b/src/jetplot/plots.py @@ -265,7 +265,7 @@ def lines( ax = kwargs["ax"] if lines is None: - lines = list(x) # pyrefly: ignore + lines = list(x) # pyrefly: ignore x = np.arange(len(lines[0])) else: diff --git a/src/jetplot/style.py b/src/jetplot/style.py index 4cd5eba..930535a 100644 --- a/src/jetplot/style.py +++ b/src/jetplot/style.py @@ -140,7 +140,7 @@ def set_defaults( def available_fonts() -> list[str]: """Returns a list of available fonts.""" - return sorted(set([f.name for f in fm.fontManager.ttflist])) # pyrefly: ignore + return sorted(set([f.name for f in fm.fontManager.ttflist])) # pyrefly: ignore def install_fonts(filepath: str): @@ -150,7 +150,7 @@ def install_fonts(filepath: str): font_files = fm.findSystemFonts(fontpaths=[filepath]) for font_file in font_files: - fm.fontManager.addfont(font_file) # pyrefly: ignore + fm.fontManager.addfont(font_file) # pyrefly: ignore new_fonts = set(available_fonts()) - original_fonts if new_fonts: