From b2517c5392f44500836ca5b932d3a0f16f09189e Mon Sep 17 00:00:00 2001 From: Niru Nahesh Date: Tue, 14 Oct 2025 21:03:50 -0700 Subject: [PATCH 01/17] Add repository guidelines --- AGENTS.md | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 60ef688..829a3da 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,14 +1,19 @@ -# Developer Instructions +# Repository Guidelines -## Code Style -- Run `uv run ruff check .` before committing to ensure all Python code passes linting. -- Write clear docstrings for all public functions and classes. -- Use relative imports within the `jetplot` package. +## Project Structure & Module Organization +Source lives under `src/jetplot/`, split into focused modules such as `colors.py`, `plots.py`, and `style.py`. Add new utilities in cohesive files and keep imports relative (e.g. `from .colors import Palette`). Tests reside in `tests/` and should mirror the module layout; prefer descriptive folders like `tests/test_plots.py` to match the feature under test. Documentation content is maintained in `docs/` and rendered via MkDocs, while build artefacts land in `build/` and `dist/`. -## Testing -- Run `uv run pytest --cov=jetplot --cov-report=term` before committing to ensure all tests pass before submitting a PR. -- Run `uv run pyrefly check` before committing to ensure all pyrefly type checking passes. +## Build, Test, and Development Commands +Run `uv run ruff check .` to ensure linting passes and style expectations are met. Use `uv run pytest --cov=jetplot --cov-report=term` for the full suite with coverage feedback, and `uv run pyrefly check` to validate typing across the package. When iterating on documentation, launch `uv run mkdocs serve` for a live preview. Regenerate distributions with `uv build` once changes are ready to publish. -## PR Guidelines -- Your pull request description must contain a **Summary** section explaining the changes. -- Include a **Testing** section describing the commands used to run lint and tests along with their results. +## Coding Style & Naming Conventions +Follow PEP 8 defaults: four-space indentation, snake_case for functions and module-level variables, and CapWords for classes. Exported constants stay upper-case with underscores. Provide clear docstrings for every public function or class that describes inputs, return values, and side effects. Keep modules small, favor pure functions, and rely on the shared helpers already defined in `style.py` and `chart_utils.py` instead of duplicating logic. + +## Testing Guidelines +Write pytest-based tests alongside new functionality, naming files `test_.py` and individual tests `test_`. Prefer parametrization to cover edge cases concisely. Aim to maintain or raise the coverage reported by the standard coverage command; unexpected drops should block merges. Include regression tests when fixing bugs so future refactors stay guarded. + +## Commit & Pull Request Guidelines +Commits use short, imperative summaries (e.g. `Add palette cycler helper`). Break large efforts into logical commits that pass linting and tests independently. Pull requests must include **Summary** and **Testing** sections outlining what changed and the exact commands run (`uv run ruff check .`, `uv run pytest --cov=jetplot --cov-report=term`, `uv run pyrefly check`). Link relevant issues and add screenshots when UI-facing artifacts such as documentation pages change. + +## Documentation Tips +Reference existing examples in `docs/` when adding guides, and keep code snippets synced with the APIs under `src/jetplot/`. Rebuild the site locally with `uv run mkdocs serve` after edits to verify navigation, formatting, and cross-links before submitting your changes. From b77abba8f3cb5979931ee577895857ad19fe643b Mon Sep 17 00:00:00 2001 From: Niru Nahesh Date: Tue, 14 Oct 2025 21:13:24 -0700 Subject: [PATCH 02/17] Fix palette plotting for single color palettes --- src/jetplot/colors.py | 10 ++++++++-- tests/test_colors.py | 12 ++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/src/jetplot/colors.py b/src/jetplot/colors.py index ac91109..2a18426 100644 --- a/src/jetplot/colors.py +++ b/src/jetplot/colors.py @@ -30,13 +30,19 @@ def cmap(self) -> LinearSegmentedColormap: def plot(self, figsize: tuple[int, int] = (5, 1)) -> tuple[Figure, list[Axes]]: """Visualize the colors in the palette.""" + if not self: + raise ValueError("Palette has no colors to plot.") + fig, axs = plt.subplots(1, len(self), figsize=figsize) - for c, ax in zip(self, axs, strict=True): # pyrefly: ignore + axs_array = np.atleast_1d(axs) + axis_list = [cast(Axes, ax) for ax in axs_array.flat] + + for c, ax in zip(self, axis_list, strict=True): ax.set_facecolor(c) ax.set_aspect("equal") noticks(ax=ax) - return fig, cast(list[Axes], axs) + return fig, axis_list def cubehelix( diff --git a/tests/test_colors.py b/tests/test_colors.py index 76da5cf..2200aa4 100644 --- a/tests/test_colors.py +++ b/tests/test_colors.py @@ -1,5 +1,6 @@ """Tests the colors module.""" +from matplotlib import pyplot as plt from matplotlib.axes import Axes from matplotlib.colors import to_rgb from matplotlib.figure import Figure @@ -27,6 +28,17 @@ def test_palette(): assert isinstance(fig, Figure) for ax in axs: assert isinstance(ax, Axes) + plt.close(fig) + + +def test_palette_single_color_plot(): + pal = colors.Palette(["#123456"]) + fig, axs = pal.plot() + + assert len(axs) == 1 + assert axs[0].get_facecolor()[:3] == to_rgb("#123456") + + plt.close(fig) def test_rainbow(): From 05873f3611398634be9c9e47bc7a4102f7314259 Mon Sep 17 00:00:00 2001 From: Niru Nahesh Date: Tue, 14 Oct 2025 21:14:27 -0700 Subject: [PATCH 03/17] Respect annot flag in confusion matrix helper --- src/jetplot/images.py | 23 ++++++++++++++++------- tests/test_images.py | 21 +++++++++++++++++++++ 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/src/jetplot/images.py b/src/jetplot/images.py index 101c27b..2a8d847 100644 --- a/src/jetplot/images.py +++ b/src/jetplot/images.py @@ -134,21 +134,30 @@ def cmat( """Plot confusion matrix.""" num_rows, num_cols = arr.shape + label_list: list[str] | None = None + if labels is not None: + label_list = list(labels) + if len(label_list) != num_cols or num_rows != num_cols: + raise ValueError( + "Labels must match confusion matrix dimensions and matrix must be square." + ) + ax = kwargs.pop("ax") cb = imv(arr, ax=ax, vmin=vmin, vmax=vmax, cmap=cmap, cbar=cbar) xs, ys = np.meshgrid(np.arange(num_cols), np.arange(num_rows), indexing="xy") - for x, y, value in zip(xs.flat, ys.flat, arr.flat, strict=True): # pyrefly: ignore - color = dark_color if (value <= theta) else light_color - label = f"{{:{fmt}}}".format(value) - ax.text(x, y, label, ha="center", va="center", color=color, fontsize=fontsize) + if annot: + for x, y, value in zip(xs.flat, ys.flat, arr.flat, strict=True): # pyrefly: ignore + color = dark_color if (value <= theta) else light_color + label = f"{{:{fmt}}}".format(value) + ax.text(x, y, label, ha="center", va="center", color=color, fontsize=fontsize) - if labels is not None: + if label_list is not None: ax.set_xticks(np.arange(num_cols)) - ax.set_xticklabels(labels, rotation=90, fontsize=label_fontsize) + ax.set_xticklabels(label_list, rotation=90, fontsize=label_fontsize) ax.set_yticks(np.arange(num_rows)) - ax.set_yticklabels(labels, fontsize=label_fontsize) + ax.set_yticklabels(label_list, fontsize=label_fontsize) ax.xaxis.set_minor_locator(FixedLocator((np.arange(num_cols) - 0.5).tolist())) diff --git a/tests/test_images.py b/tests/test_images.py index 36ff869..d728588 100644 --- a/tests/test_images.py +++ b/tests/test_images.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from matplotlib import pyplot as plt from jetplot import images @@ -28,3 +29,23 @@ def test_cmat_labels_and_colorbar(): assert [tick.get_text() for tick in ax.get_yticklabels()] == ["a", "b"] assert len(fig.axes) == 2 plt.close(fig) + + +def test_cmat_without_annotations(): + data = np.array([[0.2, 0.8], [0.1, 0.9]]) + fig, ax = plt.subplots() + + images.cmat(data, annot=False, fig=fig, ax=ax) + + assert len(ax.texts) == 0 + plt.close(fig) + + +def test_cmat_label_mismatch_raises(): + data = np.eye(2) + fig, ax = plt.subplots() + + with pytest.raises(ValueError): + images.cmat(data, labels=["short"], fig=fig, ax=ax) + + plt.close(fig) From 2260d0421f613e6eedcbaa7eeea6966e6c712e23 Mon Sep 17 00:00:00 2001 From: Niru Nahesh Date: Tue, 14 Oct 2025 21:15:41 -0700 Subject: [PATCH 04/17] Handle generator inputs in waterfall and ridgeline --- src/jetplot/plots.py | 19 +++++++++++++++---- tests/test_plots.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/src/jetplot/plots.py b/src/jetplot/plots.py index 61b1b4d..0d64903 100644 --- a/src/jetplot/plots.py +++ b/src/jetplot/plots.py @@ -297,9 +297,12 @@ def waterfall( ) -> None: """Waterfall plot.""" ax = kwargs["ax"] - total = cast(int, len(ys)) + ys_list = list(ys) + if not ys_list: + raise ValueError("ys must contain at least one series.") + total = len(ys_list) - for index, y in enumerate(ys): + for index, y in enumerate(ys_list): zorder = total - index y = y * dy + index ax.plot(x, y + pad, color=ec, clip_on=False, lw=ew, zorder=zorder) @@ -320,10 +323,18 @@ def ridgeline( ) -> tuple[Figure, list[Axes]]: """Stacked density plots reminiscent of a ridgeline plot.""" fig = kwargs["fig"] + xs_list = list(xs) + color_list = list(colors) + + if not xs_list: + raise ValueError("xs must contain at least one series.") + if len(xs_list) != len(color_list): + raise ValueError("xs and colors must have the same length.") + axs = [] - for k, (x, c) in enumerate(zip(xs, colors, strict=False)): - ax = fig.add_subplot(cast(int, len(xs)), 1, k + 1) + for k, (x, c) in enumerate(zip(xs_list, color_list, strict=True)): + ax = fig.add_subplot(len(xs_list), 1, k + 1) y = gaussian_kde(x).evaluate(t) ax.fill_between(t, y, color=c, clip_on=False) ax.plot(t, y, color=edgecolor, clip_on=False) diff --git a/tests/test_plots.py b/tests/test_plots.py index 5b0ed21..27a3c29 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from matplotlib import pyplot as plt from jetplot import plots @@ -75,6 +76,17 @@ def test_waterfall(): plt.close(fig) +def test_waterfall_accepts_generators(): + x = np.arange(5) + ys = (np.linspace(0, 1, 5) for _ in range(3)) + + fig, ax = plt.subplots() + plots.waterfall(x, ys, fig=fig, ax=ax) + + assert len(ax.collections) >= 3 + plt.close(fig) + + def test_violinplot(): data = np.random.randn(100) fig, ax = plt.subplots() @@ -82,3 +94,25 @@ def test_violinplot(): # Expect at least one polygon from violin body assert len(ax.collections) > 0 plt.close(fig) + + +def test_ridgeline_accepts_generators(): + rng = np.random.default_rng(0) + t = np.linspace(-3, 3, 25) + xs = (rng.standard_normal(100) for _ in range(3)) + colors = (color for color in plots.neutral[:3]) + + fig, axs = plots.ridgeline(t, xs=xs, colors=colors) + assert len(axs) == 3 + plt.close(fig) + + +def test_ridgeline_mismatched_lengths_raise(): + t = np.linspace(-3, 3, 10) + xs = [np.linspace(0, 1, 5)] + colors = (color for color in plots.neutral[:2]) + + with pytest.raises(ValueError): + plots.ridgeline(t, xs=xs, colors=colors) + + plt.close("all") From 23f64f16f28868d5b497766785c2e789d591230d Mon Sep 17 00:00:00 2001 From: Niru Nahesh Date: Tue, 14 Oct 2025 21:16:47 -0700 Subject: [PATCH 05/17] Scope image colorbars to target axes --- src/jetplot/images.py | 8 +++++--- tests/test_images.py | 15 +++++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/jetplot/images.py b/src/jetplot/images.py index 2a8d847..ba82db2 100644 --- a/src/jetplot/images.py +++ b/src/jetplot/images.py @@ -68,16 +68,18 @@ def img( raise ValueError("Unrecognized mode: '" + mode + "'") # make the image - im = kwargs["ax"].imshow( + ax = kwargs["ax"] + fig = kwargs.get("fig", ax.get_figure()) + im = ax.imshow( img, cmap=cmap, interpolation=interpolation, vmin=vmin, vmax=vmax, aspect=aspect ) # colorbar if cbar: - plt.colorbar(im) + fig.colorbar(im, ax=ax) # clear ticks - noticks(ax=kwargs["ax"]) + noticks(ax=ax) return im diff --git a/tests/test_images.py b/tests/test_images.py index d728588..85be406 100644 --- a/tests/test_images.py +++ b/tests/test_images.py @@ -19,6 +19,21 @@ def test_img_corr_mode(): plt.close(fig) +def test_img_colorbar_attached_to_given_axes(): + data = np.eye(3) + fig, (ax_left, ax_right) = plt.subplots(1, 2) + im_left = images.img(data, fig=fig, ax=ax_left) + images.img(data, cbar=False, fig=fig, ax=ax_right) + + assert im_left in ax_left.images + # Expect one additional axes (colorbar) attached to the same figure + colorbar_axes = [ax for ax in fig.axes if ax not in {ax_left, ax_right}] + assert len(colorbar_axes) == 1 + assert colorbar_axes[0].figure is fig + + plt.close(fig) + + def test_cmat_labels_and_colorbar(): data = np.array([[0.0, 1.0], [1.0, 0.0]]) fig, ax = plt.subplots() From da089a806c4e747f983338b0870b43809b6900e5 Mon Sep 17 00:00:00 2001 From: Niru Nahesh Date: Tue, 14 Oct 2025 21:17:48 -0700 Subject: [PATCH 06/17] Safeguard normalization edge cases --- src/jetplot/signals.py | 17 +++++++++++++++-- tests/test_signals.py | 19 +++++++++++++++++++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/src/jetplot/signals.py b/src/jetplot/signals.py index ab9dcf4..309f71d 100644 --- a/src/jetplot/signals.py +++ b/src/jetplot/signals.py @@ -28,7 +28,8 @@ def smooth(x: ArrayLike, sigma: float = 1.0, axis: int = 0) -> NDArray[np.floati def stable_rank(X: NDArray[np.floating[Any]]) -> float: """Computes the stable rank of a matrix""" - assert X.ndim == 2, "X must be a matrix" + if X.ndim != 2: + raise ValueError("X must be a matrix") # pyrefly: ignore svals_sq = np.linalg.svd(X, compute_uv=False, full_matrices=False) ** 2 @@ -100,4 +101,16 @@ def normalize( Returns: Xn: Arrays that have been normalized using to the given function. """ - return np.asarray(X) / norm(X, axis=axis, keepdims=True) + arr = np.asarray(X, dtype=float) + denom = norm(arr, axis=axis, keepdims=True) + zero_mask = denom == 0 + + # Avoid divide-by-zero warnings and keep zeros in place by dividing only where safe. + safe_denom = np.where(zero_mask, 1.0, denom) + normalized = np.zeros_like(arr, dtype=float) + np.divide(arr, safe_denom, out=normalized, where=~zero_mask) + + if np.any(zero_mask): + normalized = np.where(zero_mask, 0.0, normalized) + + return normalized diff --git a/tests/test_signals.py b/tests/test_signals.py index 0a4f763..6d4c921 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -1,6 +1,7 @@ """Tests for the signals module.""" import numpy as np +import pytest from jetplot import signals @@ -96,3 +97,21 @@ def test_normalize(): expected = np.stack([x / np.linalg.norm(x) for x in X.T]).T computed = signals.normalize(X, axis=0) assert np.allclose(expected, computed) + + +def test_stable_rank_invalid_shape(): + with pytest.raises(ValueError): + signals.stable_rank(np.ones(3)) + + +def test_normalize_handles_zero_vectors(): + X = np.array([[0.0, 0.0, 0.0], [3.0, 0.0, 4.0]]) + normalized = signals.normalize(X) + + assert np.allclose(normalized[0], 0.0) + assert np.allclose(normalized[1], np.array([0.6, 0.0, 0.8])) + + X = np.array([[0.0, 1.0], [0.0, 0.0]]) + normalized_cols = signals.normalize(X, axis=0) + assert np.allclose(normalized_cols[:, 0], 0.0) + assert np.allclose(normalized_cols[:, 1], np.array([1.0, 0.0])) From 097a5299d29c8d52780e82f2dc7cffaf77470b89 Mon Sep 17 00:00:00 2001 From: Niru Nahesh Date: Tue, 14 Oct 2025 21:20:12 -0700 Subject: [PATCH 07/17] Clarify plotting APIs and add coverage --- src/jetplot/images.py | 43 +++++++++++++++++++++++++++++++++++------- src/jetplot/plots.py | 31 +++++++++++++++++++++++++++--- src/jetplot/signals.py | 15 +++++++++++++-- tests/test_plots.py | 24 +++++++++++++++++++++++ 4 files changed, 101 insertions(+), 12 deletions(-) diff --git a/src/jetplot/images.py b/src/jetplot/images.py index ba82db2..23ecb81 100644 --- a/src/jetplot/images.py +++ b/src/jetplot/images.py @@ -31,12 +31,21 @@ def img( """Visualize a matrix as an image. Args: - img: array_like, The array to visualize. - mode: string, One of 'div' for a diverging image, 'seq' for - sequential, 'cov' for covariance matrices, or 'corr' for - correlation matrices (default: 'div'). - cmap: string, Colormap to use. - aspect: string, Either 'equal' or 'auto' + data: Array to visualize. + mode: One of ``"div"``, ``"seq"``, ``"cov"``, or ``"corr"``. + cmap: Matplotlib colormap name. Mode defaults are used when ``None``. + aspect: Either ``"equal"`` or ``"auto"``. + vmin: Lower bound for normalization. + vmax: Upper bound for normalization. + cbar: Whether to draw a colorbar attached to the provided axes. + interpolation: Interpolation strategy passed to ``imshow``. + + Raises: + ValueError: If ``mode`` is not recognized. + + Notes: + When ``cbar`` is ``True``, the colorbar is added to the supplied axes/figure + so multi-axes layouts keep their layout intact. """ # work with a copy of the original image data img = np.squeeze(data.copy()) @@ -133,7 +142,27 @@ def cmat( vmax: float = 1.0, **kwargs: Any, ) -> tuple[AxesImage, Axes]: - """Plot confusion matrix.""" + """Plot a confusion matrix with optional annotations. + + Args: + arr: Square matrix of scores in [0, 1]. + labels: Optional axis labels. Must match matrix dimensions. + annot: Whether to draw text annotations for each cell. + cmap: Colormap used for the heatmap. + cbar: Whether to include a colorbar. + fmt: Format string applied to annotation labels. + dark_color: Text color used when ``value <= theta``. + light_color: Text color used when ``value > theta``. + grid_color: Grid line color. + theta: Threshold for choosing between ``dark_color`` and ``light_color``. + label_fontsize: Tick label font size. + fontsize: Annotation font size. + vmin: Lower bound for normalization. + vmax: Upper bound for normalization. + + Raises: + ValueError: If labels are provided but do not match the matrix dimensions. + """ num_rows, num_cols = arr.shape label_list: list[str] | None = None diff --git a/src/jetplot/plots.py b/src/jetplot/plots.py index 0d64903..d72848b 100644 --- a/src/jetplot/plots.py +++ b/src/jetplot/plots.py @@ -295,7 +295,20 @@ def waterfall( ew: float = 2.0, **kwargs: Any, ) -> None: - """Waterfall plot.""" + """Waterfall plot for stacked sequences. + + Args: + x: Common x-axis samples shared by every series. + ys: Iterable of y-series. Generators are supported and are consumed once. + dy: Vertical scaling applied to each successive series. + pad: Offset applied so the outline sits slightly above the fill. + color: Fill color for each series. + ec: Edge color for the outline. + ew: Edge line width. + + Raises: + ValueError: If ``ys`` yields no series. + """ ax = kwargs["ax"] ys_list = list(ys) if not ys_list: @@ -321,7 +334,18 @@ def ridgeline( ymax: float = 0.6, **kwargs: Any, ) -> tuple[Figure, list[Axes]]: - """Stacked density plots reminiscent of a ridgeline plot.""" + """Stacked density plots reminiscent of a ridgeline plot. + + Args: + t: Grid used when evaluating the kernel density estimate. + xs: Iterable of 1-D samples. Accepts generators but consumes them eagerly. + colors: Iterable of colors, one for each series in ``xs``. + edgecolor: Line color used for the outline. + ymax: Upper y-limit for each subplot. + + Raises: + ValueError: If ``xs`` is empty or the number of colors does not match. + """ fig = kwargs["fig"] xs_list = list(xs) color_list = list(colors) @@ -389,7 +413,8 @@ def ellipse( ------- matplotlib.patches.Ellipse """ - ax = cast(Axes, kwargs.get("ax")) + ax = cast(Axes, kwargs.pop("ax", None)) + kwargs.pop("fig", None) if x.size != y.size: raise ValueError("x and y must be the same size") diff --git a/src/jetplot/signals.py b/src/jetplot/signals.py index 309f71d..0aaa4ef 100644 --- a/src/jetplot/signals.py +++ b/src/jetplot/signals.py @@ -27,7 +27,14 @@ def smooth(x: ArrayLike, sigma: float = 1.0, axis: int = 0) -> NDArray[np.floati def stable_rank(X: NDArray[np.floating[Any]]) -> float: - """Computes the stable rank of a matrix""" + """Compute the stable rank of a matrix. + + Args: + X: Two-dimensional array representing a matrix. + + Raises: + ValueError: If ``X`` is not two-dimensional. + """ if X.ndim != 2: raise ValueError("X must be a matrix") @@ -99,7 +106,11 @@ def normalize( norm: Function that computes the norm (Default: np.linalg.norm). Returns: - Xn: Arrays that have been normalized using to the given function. + Normalized array with the same shape as ``X``. + + Notes: + Any vectors whose norm is zero remain zero after normalization instead of + producing NaNs or infinities. """ arr = np.asarray(X, dtype=float) denom = norm(arr, axis=axis, keepdims=True) diff --git a/tests/test_plots.py b/tests/test_plots.py index 27a3c29..4296f82 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -1,6 +1,7 @@ import numpy as np import pytest from matplotlib import pyplot as plt +from matplotlib.patches import Ellipse from jetplot import plots @@ -116,3 +117,26 @@ def test_ridgeline_mismatched_lengths_raise(): plots.ridgeline(t, xs=xs, colors=colors) plt.close("all") + + +def test_ellipse_returns_patch(): + rng = np.random.default_rng(1) + x = rng.standard_normal(200) + y = x + 0.1 * rng.standard_normal(200) + + fig, ax = plt.subplots() + patch = plots.ellipse(x, y, fig=fig, ax=ax) + + assert isinstance(patch, Ellipse) + assert patch in ax.patches + plt.close(fig) + + +def test_ellipse_length_mismatch_raises(): + x = np.arange(5) + y = np.arange(4) + + fig, ax = plt.subplots() + with pytest.raises(ValueError): + plots.ellipse(x, y, fig=fig, ax=ax) + plt.close(fig) From d9c36791ccf198ebac59cff0efcc8bd115ee875f Mon Sep 17 00:00:00 2001 From: Niru Nahesh Date: Tue, 14 Oct 2025 21:21:29 -0700 Subject: [PATCH 08/17] Remove unused pyplot import --- src/jetplot/images.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/jetplot/images.py b/src/jetplot/images.py index 23ecb81..e103e24 100644 --- a/src/jetplot/images.py +++ b/src/jetplot/images.py @@ -5,7 +5,6 @@ from typing import Any, cast import numpy as np -from matplotlib import pyplot as plt from matplotlib.axes import Axes from matplotlib.image import AxesImage from matplotlib.ticker import FixedLocator From 1c830485e98d6e88aa4322c1d6499575e5942ca0 Mon Sep 17 00:00:00 2001 From: Niru Nahesh Date: Tue, 14 Oct 2025 21:22:01 -0700 Subject: [PATCH 09/17] Flush stopwatch prints --- src/jetplot/timepiece.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jetplot/timepiece.py b/src/jetplot/timepiece.py index 9d7b6c4..796676b 100644 --- a/src/jetplot/timepiece.py +++ b/src/jetplot/timepiece.py @@ -29,14 +29,14 @@ def elapsed(self) -> float: return elapsed def checkpoint(self, name: str = "") -> None: - print(f"{self.name} {name} took {hrtime(self.elapsed)}".strip()) + print(f"{self.name} {name} took {hrtime(self.elapsed)}".strip(), flush=True) def __enter__(self) -> "Stopwatch": return self def __exit__(self, *_: object) -> None: total = hrtime(time.perf_counter() - self.absolute_start) - print(f"{self.name} Finished! \u2714\nTotal elapsed time: {total}") + print(f"{self.name} Finished! \u2714\nTotal elapsed time: {total}", flush=True) def hrtime(t: float) -> str: From fa8ba84a031f978c194f9573327f32e74a9deb90 Mon Sep 17 00:00:00 2001 From: Niru Nahesh Date: Tue, 14 Oct 2025 21:25:06 -0700 Subject: [PATCH 10/17] Satisfy pyrefly typing for figure colorbar --- src/jetplot/images.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/jetplot/images.py b/src/jetplot/images.py index e103e24..18f795f 100644 --- a/src/jetplot/images.py +++ b/src/jetplot/images.py @@ -8,6 +8,7 @@ from matplotlib.axes import Axes from matplotlib.image import AxesImage from matplotlib.ticker import FixedLocator +from matplotlib.figure import Figure from . import colors as c from .chart_utils import noticks, plotwrapper @@ -77,7 +78,8 @@ def img( # make the image ax = kwargs["ax"] - fig = kwargs.get("fig", ax.get_figure()) + fig_candidate = kwargs.get("fig") + fig = cast(Figure, fig_candidate if fig_candidate is not None else ax.get_figure()) im = ax.imshow( img, cmap=cmap, interpolation=interpolation, vmin=vmin, vmax=vmax, aspect=aspect ) From 558fee902edf7c82ab25a4cdab6dcc02f8a3bc17 Mon Sep 17 00:00:00 2001 From: Niru Nahesh Date: Tue, 14 Oct 2025 21:29:33 -0700 Subject: [PATCH 11/17] Preserve lazy color iteration in ridgeline --- src/jetplot/plots.py | 21 ++++++++++++--------- tests/test_plots.py | 14 ++++++++++++-- 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/src/jetplot/plots.py b/src/jetplot/plots.py index d72848b..9e080c6 100644 --- a/src/jetplot/plots.py +++ b/src/jetplot/plots.py @@ -338,31 +338,34 @@ def ridgeline( Args: t: Grid used when evaluating the kernel density estimate. - xs: Iterable of 1-D samples. Accepts generators but consumes them eagerly. - colors: Iterable of colors, one for each series in ``xs``. + xs: Iterable of 1-D samples. Accepts generators and consumes them once. + colors: Iterable of colors. Must provide at least as many entries as ``xs``. edgecolor: Line color used for the outline. ymax: Upper y-limit for each subplot. Raises: - ValueError: If ``xs`` is empty or the number of colors does not match. + ValueError: If ``xs`` is empty or ``colors`` provides too few values. """ fig = kwargs["fig"] xs_list = list(xs) - color_list = list(colors) + colors_iter = iter(colors) if not xs_list: raise ValueError("xs must contain at least one series.") - if len(xs_list) != len(color_list): - raise ValueError("xs and colors must have the same length.") axs = [] - for k, (x, c) in enumerate(zip(xs_list, color_list, strict=True)): + for k, x in enumerate(xs_list): + try: + palette_color = next(colors_iter) + except StopIteration as exc: + raise ValueError("colors must provide at least as many items as xs.") from exc + ax = fig.add_subplot(len(xs_list), 1, k + 1) y = gaussian_kde(x).evaluate(t) - ax.fill_between(t, y, color=c, clip_on=False) + ax.fill_between(t, y, color=palette_color, clip_on=False) ax.plot(t, y, color=edgecolor, clip_on=False) - ax.axhline(0.0, lw=2, color=c, clip_on=False) + ax.axhline(0.0, lw=2, color=palette_color, clip_on=False) ax.set_xlim(t[0], t[-1]) ax.set_xticks([]) diff --git a/tests/test_plots.py b/tests/test_plots.py index 4296f82..f4c1bc2 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -110,8 +110,8 @@ def test_ridgeline_accepts_generators(): def test_ridgeline_mismatched_lengths_raise(): t = np.linspace(-3, 3, 10) - xs = [np.linspace(0, 1, 5)] - colors = (color for color in plots.neutral[:2]) + xs = [np.linspace(0, 1, 5), np.linspace(0, 2, 5)] + colors = (color for color in plots.neutral[:1]) with pytest.raises(ValueError): plots.ridgeline(t, xs=xs, colors=colors) @@ -119,6 +119,16 @@ def test_ridgeline_mismatched_lengths_raise(): plt.close("all") +def test_ridgeline_allows_extra_colors(): + rng = np.random.default_rng(2) + t = np.linspace(-3, 3, 25) + xs = [rng.standard_normal(100) for _ in range(3)] + + fig, axs = plots.ridgeline(t, xs=xs, colors=plots.neutral) + assert len(axs) == 3 + plt.close(fig) + + def test_ellipse_returns_patch(): rng = np.random.default_rng(1) x = rng.standard_normal(200) From e7f6e93bc8f2b904ab8fdd58800ff72997919dc2 Mon Sep 17 00:00:00 2001 From: Niru Nahesh Date: Tue, 14 Oct 2025 21:40:52 -0700 Subject: [PATCH 12/17] Fix pyrefly warnings for errorplot and smoothing --- src/jetplot/plots.py | 14 ++++++++++---- src/jetplot/signals.py | 3 ++- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/jetplot/plots.py b/src/jetplot/plots.py index 9e080c6..d60f326 100644 --- a/src/jetplot/plots.py +++ b/src/jetplot/plots.py @@ -163,13 +163,19 @@ def errorplot( """Plot a line with error bars.""" ax = kwargs["ax"] - if np.isscalar(yerr) or len(yerr) == len(y): # pyrefly: ignore + if np.isscalar(yerr): ymin = y - yerr # pyrefly: ignore ymax = y + yerr # pyrefly: ignore - elif len(yerr) == 2: + elif isinstance(yerr, tuple): + if len(yerr) != 2: + raise ValueError("Invalid yerr tuple length: ", yerr) ymin, ymax = yerr # pyrefly: ignore else: - raise ValueError("Invalid yerr value: ", yerr) + yerr_array = np.asarray(yerr) + if yerr_array.shape != y.shape: + raise ValueError("Invalid yerr value: ", yerr) + ymin = y - yerr_array + ymax = y + yerr_array if method == "line": ax.plot(x, y, fmt, color=color, linewidth=4, clip_on=clip_on) @@ -458,4 +464,4 @@ def ellipse( ) ellipse.set_transform(transform + ax.transData) # pyrefly: ignore - return ax.add_patch(ellipse) + return cast(Ellipse, ax.add_patch(ellipse)) diff --git a/src/jetplot/signals.py b/src/jetplot/signals.py index 0aaa4ef..23d6780 100644 --- a/src/jetplot/signals.py +++ b/src/jetplot/signals.py @@ -23,7 +23,8 @@ def smooth(x: ArrayLike, sigma: float = 1.0, axis: int = 0) -> NDArray[np.floati Returns: xs: array_like, A smoothed version of the input signal """ - return gaussian_filter1d(x, sigma, axis=axis) + arr = np.asarray(x) + return gaussian_filter1d(arr, sigma, axis=axis) def stable_rank(X: NDArray[np.floating[Any]]) -> float: From 6ee69b2f304f8ab826b7c3011401930f12d61eef Mon Sep 17 00:00:00 2001 From: Niru Nahesh Date: Tue, 14 Oct 2025 21:43:06 -0700 Subject: [PATCH 13/17] Bump version to 0.6.6 --- src/jetplot/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jetplot/__init__.py b/src/jetplot/__init__.py index 8dee349..c68b48f 100644 --- a/src/jetplot/__init__.py +++ b/src/jetplot/__init__.py @@ -1,6 +1,6 @@ """Jetplot is a set of useful utility functions for scientific python.""" -__version__ = "0.6.5" +__version__ = "0.6.6" from . import colors as c # noqa: F401 from .chart_utils import * From a0df4a25d703cc0bf9613dfdd14b2dc1aaa194bb Mon Sep 17 00:00:00 2001 From: Niru Nahesh Date: Wed, 15 Oct 2025 18:50:32 -0700 Subject: [PATCH 14/17] Fixes a couple of overeager edits. --- src/jetplot/colors.py | 9 ++------- src/jetplot/images.py | 4 +--- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/src/jetplot/colors.py b/src/jetplot/colors.py index 2a18426..c4ccede 100644 --- a/src/jetplot/colors.py +++ b/src/jetplot/colors.py @@ -30,19 +30,14 @@ def cmap(self) -> LinearSegmentedColormap: def plot(self, figsize: tuple[int, int] = (5, 1)) -> tuple[Figure, list[Axes]]: """Visualize the colors in the palette.""" - if not self: - raise ValueError("Palette has no colors to plot.") - fig, axs = plt.subplots(1, len(self), figsize=figsize) - axs_array = np.atleast_1d(axs) - axis_list = [cast(Axes, ax) for ax in axs_array.flat] - for c, ax in zip(self, axis_list, strict=True): + for c, ax in zip(self, axs, strict=True): ax.set_facecolor(c) ax.set_aspect("equal") noticks(ax=ax) - return fig, axis_list + return fig, axs def cubehelix( diff --git a/src/jetplot/images.py b/src/jetplot/images.py index 18f795f..30edf05 100644 --- a/src/jetplot/images.py +++ b/src/jetplot/images.py @@ -77,9 +77,7 @@ def img( raise ValueError("Unrecognized mode: '" + mode + "'") # make the image - ax = kwargs["ax"] - fig_candidate = kwargs.get("fig") - fig = cast(Figure, fig_candidate if fig_candidate is not None else ax.get_figure()) + fig, ax = kwargs["fig"], kwargs["ax"] im = ax.imshow( img, cmap=cmap, interpolation=interpolation, vmin=vmin, vmax=vmax, aspect=aspect ) From 5910eb044d656551b6e9e468b72056764f52102c Mon Sep 17 00:00:00 2001 From: Niru Nahesh Date: Wed, 15 Oct 2025 18:56:35 -0700 Subject: [PATCH 15/17] Fixes lint and format errors. --- src/jetplot/colors.py | 11 ++++++++--- src/jetplot/images.py | 10 ++++++++-- src/jetplot/plots.py | 4 +++- tests/test_colors.py | 6 +++++- 4 files changed, 24 insertions(+), 7 deletions(-) diff --git a/src/jetplot/colors.py b/src/jetplot/colors.py index c4ccede..967cfd8 100644 --- a/src/jetplot/colors.py +++ b/src/jetplot/colors.py @@ -30,14 +30,19 @@ def cmap(self) -> LinearSegmentedColormap: def plot(self, figsize: tuple[int, int] = (5, 1)) -> tuple[Figure, list[Axes]]: """Visualize the colors in the palette.""" + if not self: + raise ValueError("Palette has no colors to plot.") + fig, axs = plt.subplots(1, len(self), figsize=figsize) + axs_array = np.atleast_1d(axs) + axes_list = [cast(Axes, ax) for ax in axs_array.flat] - for c, ax in zip(self, axs, strict=True): - ax.set_facecolor(c) + for c, ax in zip(self, axes_list, strict=True): + ax.set_facecolor(c) # pyrefly: ignore ax.set_aspect("equal") noticks(ax=ax) - return fig, axs + return fig, axes_list def cubehelix( diff --git a/src/jetplot/images.py b/src/jetplot/images.py index 30edf05..b9437a8 100644 --- a/src/jetplot/images.py +++ b/src/jetplot/images.py @@ -178,10 +178,16 @@ def cmat( xs, ys = np.meshgrid(np.arange(num_cols), np.arange(num_rows), indexing="xy") if annot: - for x, y, value in zip(xs.flat, ys.flat, arr.flat, strict=True): # pyrefly: ignore + # pyrefly: ignore + for x, y, value in zip( + xs.flat, ys.flat, arr.flat, strict=True # pyrefly: ignore + + ): color = dark_color if (value <= theta) else light_color label = f"{{:{fmt}}}".format(value) - ax.text(x, y, label, ha="center", va="center", color=color, fontsize=fontsize) + ax.text( + x, y, label, ha="center", va="center", color=color, fontsize=fontsize + ) if label_list is not None: ax.set_xticks(np.arange(num_cols)) diff --git a/src/jetplot/plots.py b/src/jetplot/plots.py index d60f326..0bf8435 100644 --- a/src/jetplot/plots.py +++ b/src/jetplot/plots.py @@ -365,7 +365,9 @@ def ridgeline( try: palette_color = next(colors_iter) except StopIteration as exc: - raise ValueError("colors must provide at least as many items as xs.") from exc + raise ValueError( + "colors must provide at least as many items as xs." + ) from exc ax = fig.add_subplot(len(xs_list), 1, k + 1) y = gaussian_kde(x).evaluate(t) diff --git a/tests/test_colors.py b/tests/test_colors.py index 2200aa4..640a715 100644 --- a/tests/test_colors.py +++ b/tests/test_colors.py @@ -32,7 +32,11 @@ def test_palette(): def test_palette_single_color_plot(): - pal = colors.Palette(["#123456"]) + pal = colors.Palette( + [ + "#123456", + ] + ) fig, axs = pal.plot() assert len(axs) == 1 From fe419cd4f53e47f257d23585bbb4fc378a85c714 Mon Sep 17 00:00:00 2001 From: Niru Nahesh Date: Wed, 15 Oct 2025 18:59:08 -0700 Subject: [PATCH 16/17] Fixes lint error. --- src/jetplot/images.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/jetplot/images.py b/src/jetplot/images.py index b9437a8..387f394 100644 --- a/src/jetplot/images.py +++ b/src/jetplot/images.py @@ -8,7 +8,6 @@ from matplotlib.axes import Axes from matplotlib.image import AxesImage from matplotlib.ticker import FixedLocator -from matplotlib.figure import Figure from . import colors as c from .chart_utils import noticks, plotwrapper @@ -178,10 +177,11 @@ def cmat( xs, ys = np.meshgrid(np.arange(num_cols), np.arange(num_rows), indexing="xy") if annot: - # pyrefly: ignore - for x, y, value in zip( - xs.flat, ys.flat, arr.flat, strict=True # pyrefly: ignore - + for x, y, value in zip( # pyrefly: ignore + xs.flat, # pyrefly: ignore + ys.flat, + arr.flat, + strict=True, # pyrefly: ignore ): color = dark_color if (value <= theta) else light_color label = f"{{:{fmt}}}".format(value) From 5fa8edf623954452f283b2f393311dd78b9e9501 Mon Sep 17 00:00:00 2001 From: Niru Nahesh Date: Wed, 15 Oct 2025 19:00:03 -0700 Subject: [PATCH 17/17] Updates default just command. --- justfile | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/justfile b/justfile index f573752..8414fe6 100644 --- a/justfile +++ b/justfile @@ -1,4 +1,4 @@ -default: test +default: format lint test typecheck build: uv build @@ -12,6 +12,9 @@ docs: format: uv run ruff format +lint: + uv run ruff check + typecheck: uv run pyrefly check