diff --git a/AGENTS.md b/AGENTS.md index 60ef688..829a3da 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,14 +1,19 @@ -# Developer Instructions +# Repository Guidelines -## Code Style -- Run `uv run ruff check .` before committing to ensure all Python code passes linting. -- Write clear docstrings for all public functions and classes. -- Use relative imports within the `jetplot` package. +## Project Structure & Module Organization +Source lives under `src/jetplot/`, split into focused modules such as `colors.py`, `plots.py`, and `style.py`. Add new utilities in cohesive files and keep imports relative (e.g. `from .colors import Palette`). Tests reside in `tests/` and should mirror the module layout; prefer descriptive folders like `tests/test_plots.py` to match the feature under test. Documentation content is maintained in `docs/` and rendered via MkDocs, while build artefacts land in `build/` and `dist/`. -## Testing -- Run `uv run pytest --cov=jetplot --cov-report=term` before committing to ensure all tests pass before submitting a PR. -- Run `uv run pyrefly check` before committing to ensure all pyrefly type checking passes. +## Build, Test, and Development Commands +Run `uv run ruff check .` to ensure linting passes and style expectations are met. Use `uv run pytest --cov=jetplot --cov-report=term` for the full suite with coverage feedback, and `uv run pyrefly check` to validate typing across the package. When iterating on documentation, launch `uv run mkdocs serve` for a live preview. Regenerate distributions with `uv build` once changes are ready to publish. -## PR Guidelines -- Your pull request description must contain a **Summary** section explaining the changes. -- Include a **Testing** section describing the commands used to run lint and tests along with their results. +## Coding Style & Naming Conventions +Follow PEP 8 defaults: four-space indentation, snake_case for functions and module-level variables, and CapWords for classes. Exported constants stay upper-case with underscores. Provide clear docstrings for every public function or class that describes inputs, return values, and side effects. Keep modules small, favor pure functions, and rely on the shared helpers already defined in `style.py` and `chart_utils.py` instead of duplicating logic. + +## Testing Guidelines +Write pytest-based tests alongside new functionality, naming files `test_.py` and individual tests `test_`. Prefer parametrization to cover edge cases concisely. Aim to maintain or raise the coverage reported by the standard coverage command; unexpected drops should block merges. Include regression tests when fixing bugs so future refactors stay guarded. + +## Commit & Pull Request Guidelines +Commits use short, imperative summaries (e.g. `Add palette cycler helper`). Break large efforts into logical commits that pass linting and tests independently. Pull requests must include **Summary** and **Testing** sections outlining what changed and the exact commands run (`uv run ruff check .`, `uv run pytest --cov=jetplot --cov-report=term`, `uv run pyrefly check`). Link relevant issues and add screenshots when UI-facing artifacts such as documentation pages change. + +## Documentation Tips +Reference existing examples in `docs/` when adding guides, and keep code snippets synced with the APIs under `src/jetplot/`. Rebuild the site locally with `uv run mkdocs serve` after edits to verify navigation, formatting, and cross-links before submitting your changes. diff --git a/justfile b/justfile index f573752..8414fe6 100644 --- a/justfile +++ b/justfile @@ -1,4 +1,4 @@ -default: test +default: format lint test typecheck build: uv build @@ -12,6 +12,9 @@ docs: format: uv run ruff format +lint: + uv run ruff check + typecheck: uv run pyrefly check diff --git a/src/jetplot/__init__.py b/src/jetplot/__init__.py index 8dee349..c68b48f 100644 --- a/src/jetplot/__init__.py +++ b/src/jetplot/__init__.py @@ -1,6 +1,6 @@ """Jetplot is a set of useful utility functions for scientific python.""" -__version__ = "0.6.5" +__version__ = "0.6.6" from . import colors as c # noqa: F401 from .chart_utils import * diff --git a/src/jetplot/colors.py b/src/jetplot/colors.py index ac91109..967cfd8 100644 --- a/src/jetplot/colors.py +++ b/src/jetplot/colors.py @@ -30,13 +30,19 @@ def cmap(self) -> LinearSegmentedColormap: def plot(self, figsize: tuple[int, int] = (5, 1)) -> tuple[Figure, list[Axes]]: """Visualize the colors in the palette.""" + if not self: + raise ValueError("Palette has no colors to plot.") + fig, axs = plt.subplots(1, len(self), figsize=figsize) - for c, ax in zip(self, axs, strict=True): # pyrefly: ignore - ax.set_facecolor(c) + axs_array = np.atleast_1d(axs) + axes_list = [cast(Axes, ax) for ax in axs_array.flat] + + for c, ax in zip(self, axes_list, strict=True): + ax.set_facecolor(c) # pyrefly: ignore ax.set_aspect("equal") noticks(ax=ax) - return fig, cast(list[Axes], axs) + return fig, axes_list def cubehelix( diff --git a/src/jetplot/images.py b/src/jetplot/images.py index 101c27b..387f394 100644 --- a/src/jetplot/images.py +++ b/src/jetplot/images.py @@ -5,7 +5,6 @@ from typing import Any, cast import numpy as np -from matplotlib import pyplot as plt from matplotlib.axes import Axes from matplotlib.image import AxesImage from matplotlib.ticker import FixedLocator @@ -31,12 +30,21 @@ def img( """Visualize a matrix as an image. Args: - img: array_like, The array to visualize. - mode: string, One of 'div' for a diverging image, 'seq' for - sequential, 'cov' for covariance matrices, or 'corr' for - correlation matrices (default: 'div'). - cmap: string, Colormap to use. - aspect: string, Either 'equal' or 'auto' + data: Array to visualize. + mode: One of ``"div"``, ``"seq"``, ``"cov"``, or ``"corr"``. + cmap: Matplotlib colormap name. Mode defaults are used when ``None``. + aspect: Either ``"equal"`` or ``"auto"``. + vmin: Lower bound for normalization. + vmax: Upper bound for normalization. + cbar: Whether to draw a colorbar attached to the provided axes. + interpolation: Interpolation strategy passed to ``imshow``. + + Raises: + ValueError: If ``mode`` is not recognized. + + Notes: + When ``cbar`` is ``True``, the colorbar is added to the supplied axes/figure + so multi-axes layouts keep their layout intact. """ # work with a copy of the original image data img = np.squeeze(data.copy()) @@ -68,16 +76,17 @@ def img( raise ValueError("Unrecognized mode: '" + mode + "'") # make the image - im = kwargs["ax"].imshow( + fig, ax = kwargs["fig"], kwargs["ax"] + im = ax.imshow( img, cmap=cmap, interpolation=interpolation, vmin=vmin, vmax=vmax, aspect=aspect ) # colorbar if cbar: - plt.colorbar(im) + fig.colorbar(im, ax=ax) # clear ticks - noticks(ax=kwargs["ax"]) + noticks(ax=ax) return im @@ -131,24 +140,60 @@ def cmat( vmax: float = 1.0, **kwargs: Any, ) -> tuple[AxesImage, Axes]: - """Plot confusion matrix.""" + """Plot a confusion matrix with optional annotations. + + Args: + arr: Square matrix of scores in [0, 1]. + labels: Optional axis labels. Must match matrix dimensions. + annot: Whether to draw text annotations for each cell. + cmap: Colormap used for the heatmap. + cbar: Whether to include a colorbar. + fmt: Format string applied to annotation labels. + dark_color: Text color used when ``value <= theta``. + light_color: Text color used when ``value > theta``. + grid_color: Grid line color. + theta: Threshold for choosing between ``dark_color`` and ``light_color``. + label_fontsize: Tick label font size. + fontsize: Annotation font size. + vmin: Lower bound for normalization. + vmax: Upper bound for normalization. + + Raises: + ValueError: If labels are provided but do not match the matrix dimensions. + """ num_rows, num_cols = arr.shape + label_list: list[str] | None = None + if labels is not None: + label_list = list(labels) + if len(label_list) != num_cols or num_rows != num_cols: + raise ValueError( + "Labels must match confusion matrix dimensions and matrix must be square." + ) + ax = kwargs.pop("ax") cb = imv(arr, ax=ax, vmin=vmin, vmax=vmax, cmap=cmap, cbar=cbar) xs, ys = np.meshgrid(np.arange(num_cols), np.arange(num_rows), indexing="xy") - for x, y, value in zip(xs.flat, ys.flat, arr.flat, strict=True): # pyrefly: ignore - color = dark_color if (value <= theta) else light_color - label = f"{{:{fmt}}}".format(value) - ax.text(x, y, label, ha="center", va="center", color=color, fontsize=fontsize) - - if labels is not None: + if annot: + for x, y, value in zip( # pyrefly: ignore + xs.flat, # pyrefly: ignore + ys.flat, + arr.flat, + strict=True, # pyrefly: ignore + ): + color = dark_color if (value <= theta) else light_color + label = f"{{:{fmt}}}".format(value) + ax.text( + x, y, label, ha="center", va="center", color=color, fontsize=fontsize + ) + + if label_list is not None: ax.set_xticks(np.arange(num_cols)) - ax.set_xticklabels(labels, rotation=90, fontsize=label_fontsize) + ax.set_xticklabels(label_list, rotation=90, fontsize=label_fontsize) ax.set_yticks(np.arange(num_rows)) - ax.set_yticklabels(labels, fontsize=label_fontsize) + ax.set_yticklabels(label_list, fontsize=label_fontsize) ax.xaxis.set_minor_locator(FixedLocator((np.arange(num_cols) - 0.5).tolist())) diff --git a/src/jetplot/plots.py b/src/jetplot/plots.py index 61b1b4d..0bf8435 100644 --- a/src/jetplot/plots.py +++ b/src/jetplot/plots.py @@ -163,13 +163,19 @@ def errorplot( """Plot a line with error bars.""" ax = kwargs["ax"] - if np.isscalar(yerr) or len(yerr) == len(y): # pyrefly: ignore + if np.isscalar(yerr): ymin = y - yerr # pyrefly: ignore ymax = y + yerr # pyrefly: ignore - elif len(yerr) == 2: + elif isinstance(yerr, tuple): + if len(yerr) != 2: + raise ValueError("Invalid yerr tuple length: ", yerr) ymin, ymax = yerr # pyrefly: ignore else: - raise ValueError("Invalid yerr value: ", yerr) + yerr_array = np.asarray(yerr) + if yerr_array.shape != y.shape: + raise ValueError("Invalid yerr value: ", yerr) + ymin = y - yerr_array + ymax = y + yerr_array if method == "line": ax.plot(x, y, fmt, color=color, linewidth=4, clip_on=clip_on) @@ -295,11 +301,27 @@ def waterfall( ew: float = 2.0, **kwargs: Any, ) -> None: - """Waterfall plot.""" + """Waterfall plot for stacked sequences. + + Args: + x: Common x-axis samples shared by every series. + ys: Iterable of y-series. Generators are supported and are consumed once. + dy: Vertical scaling applied to each successive series. + pad: Offset applied so the outline sits slightly above the fill. + color: Fill color for each series. + ec: Edge color for the outline. + ew: Edge line width. + + Raises: + ValueError: If ``ys`` yields no series. + """ ax = kwargs["ax"] - total = cast(int, len(ys)) + ys_list = list(ys) + if not ys_list: + raise ValueError("ys must contain at least one series.") + total = len(ys_list) - for index, y in enumerate(ys): + for index, y in enumerate(ys_list): zorder = total - index y = y * dy + index ax.plot(x, y + pad, color=ec, clip_on=False, lw=ew, zorder=zorder) @@ -318,16 +340,40 @@ def ridgeline( ymax: float = 0.6, **kwargs: Any, ) -> tuple[Figure, list[Axes]]: - """Stacked density plots reminiscent of a ridgeline plot.""" + """Stacked density plots reminiscent of a ridgeline plot. + + Args: + t: Grid used when evaluating the kernel density estimate. + xs: Iterable of 1-D samples. Accepts generators and consumes them once. + colors: Iterable of colors. Must provide at least as many entries as ``xs``. + edgecolor: Line color used for the outline. + ymax: Upper y-limit for each subplot. + + Raises: + ValueError: If ``xs`` is empty or ``colors`` provides too few values. + """ fig = kwargs["fig"] + xs_list = list(xs) + colors_iter = iter(colors) + + if not xs_list: + raise ValueError("xs must contain at least one series.") + axs = [] - for k, (x, c) in enumerate(zip(xs, colors, strict=False)): - ax = fig.add_subplot(cast(int, len(xs)), 1, k + 1) + for k, x in enumerate(xs_list): + try: + palette_color = next(colors_iter) + except StopIteration as exc: + raise ValueError( + "colors must provide at least as many items as xs." + ) from exc + + ax = fig.add_subplot(len(xs_list), 1, k + 1) y = gaussian_kde(x).evaluate(t) - ax.fill_between(t, y, color=c, clip_on=False) + ax.fill_between(t, y, color=palette_color, clip_on=False) ax.plot(t, y, color=edgecolor, clip_on=False) - ax.axhline(0.0, lw=2, color=c, clip_on=False) + ax.axhline(0.0, lw=2, color=palette_color, clip_on=False) ax.set_xlim(t[0], t[-1]) ax.set_xticks([]) @@ -378,7 +424,8 @@ def ellipse( ------- matplotlib.patches.Ellipse """ - ax = cast(Axes, kwargs.get("ax")) + ax = cast(Axes, kwargs.pop("ax", None)) + kwargs.pop("fig", None) if x.size != y.size: raise ValueError("x and y must be the same size") @@ -419,4 +466,4 @@ def ellipse( ) ellipse.set_transform(transform + ax.transData) # pyrefly: ignore - return ax.add_patch(ellipse) + return cast(Ellipse, ax.add_patch(ellipse)) diff --git a/src/jetplot/signals.py b/src/jetplot/signals.py index ab9dcf4..23d6780 100644 --- a/src/jetplot/signals.py +++ b/src/jetplot/signals.py @@ -23,12 +23,21 @@ def smooth(x: ArrayLike, sigma: float = 1.0, axis: int = 0) -> NDArray[np.floati Returns: xs: array_like, A smoothed version of the input signal """ - return gaussian_filter1d(x, sigma, axis=axis) + arr = np.asarray(x) + return gaussian_filter1d(arr, sigma, axis=axis) def stable_rank(X: NDArray[np.floating[Any]]) -> float: - """Computes the stable rank of a matrix""" - assert X.ndim == 2, "X must be a matrix" + """Compute the stable rank of a matrix. + + Args: + X: Two-dimensional array representing a matrix. + + Raises: + ValueError: If ``X`` is not two-dimensional. + """ + if X.ndim != 2: + raise ValueError("X must be a matrix") # pyrefly: ignore svals_sq = np.linalg.svd(X, compute_uv=False, full_matrices=False) ** 2 @@ -98,6 +107,22 @@ def normalize( norm: Function that computes the norm (Default: np.linalg.norm). Returns: - Xn: Arrays that have been normalized using to the given function. + Normalized array with the same shape as ``X``. + + Notes: + Any vectors whose norm is zero remain zero after normalization instead of + producing NaNs or infinities. """ - return np.asarray(X) / norm(X, axis=axis, keepdims=True) + arr = np.asarray(X, dtype=float) + denom = norm(arr, axis=axis, keepdims=True) + zero_mask = denom == 0 + + # Avoid divide-by-zero warnings and keep zeros in place by dividing only where safe. + safe_denom = np.where(zero_mask, 1.0, denom) + normalized = np.zeros_like(arr, dtype=float) + np.divide(arr, safe_denom, out=normalized, where=~zero_mask) + + if np.any(zero_mask): + normalized = np.where(zero_mask, 0.0, normalized) + + return normalized diff --git a/src/jetplot/timepiece.py b/src/jetplot/timepiece.py index 9d7b6c4..796676b 100644 --- a/src/jetplot/timepiece.py +++ b/src/jetplot/timepiece.py @@ -29,14 +29,14 @@ def elapsed(self) -> float: return elapsed def checkpoint(self, name: str = "") -> None: - print(f"{self.name} {name} took {hrtime(self.elapsed)}".strip()) + print(f"{self.name} {name} took {hrtime(self.elapsed)}".strip(), flush=True) def __enter__(self) -> "Stopwatch": return self def __exit__(self, *_: object) -> None: total = hrtime(time.perf_counter() - self.absolute_start) - print(f"{self.name} Finished! \u2714\nTotal elapsed time: {total}") + print(f"{self.name} Finished! \u2714\nTotal elapsed time: {total}", flush=True) def hrtime(t: float) -> str: diff --git a/tests/test_colors.py b/tests/test_colors.py index 76da5cf..640a715 100644 --- a/tests/test_colors.py +++ b/tests/test_colors.py @@ -1,5 +1,6 @@ """Tests the colors module.""" +from matplotlib import pyplot as plt from matplotlib.axes import Axes from matplotlib.colors import to_rgb from matplotlib.figure import Figure @@ -27,6 +28,21 @@ def test_palette(): assert isinstance(fig, Figure) for ax in axs: assert isinstance(ax, Axes) + plt.close(fig) + + +def test_palette_single_color_plot(): + pal = colors.Palette( + [ + "#123456", + ] + ) + fig, axs = pal.plot() + + assert len(axs) == 1 + assert axs[0].get_facecolor()[:3] == to_rgb("#123456") + + plt.close(fig) def test_rainbow(): diff --git a/tests/test_images.py b/tests/test_images.py index 36ff869..85be406 100644 --- a/tests/test_images.py +++ b/tests/test_images.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from matplotlib import pyplot as plt from jetplot import images @@ -18,6 +19,21 @@ def test_img_corr_mode(): plt.close(fig) +def test_img_colorbar_attached_to_given_axes(): + data = np.eye(3) + fig, (ax_left, ax_right) = plt.subplots(1, 2) + im_left = images.img(data, fig=fig, ax=ax_left) + images.img(data, cbar=False, fig=fig, ax=ax_right) + + assert im_left in ax_left.images + # Expect one additional axes (colorbar) attached to the same figure + colorbar_axes = [ax for ax in fig.axes if ax not in {ax_left, ax_right}] + assert len(colorbar_axes) == 1 + assert colorbar_axes[0].figure is fig + + plt.close(fig) + + def test_cmat_labels_and_colorbar(): data = np.array([[0.0, 1.0], [1.0, 0.0]]) fig, ax = plt.subplots() @@ -28,3 +44,23 @@ def test_cmat_labels_and_colorbar(): assert [tick.get_text() for tick in ax.get_yticklabels()] == ["a", "b"] assert len(fig.axes) == 2 plt.close(fig) + + +def test_cmat_without_annotations(): + data = np.array([[0.2, 0.8], [0.1, 0.9]]) + fig, ax = plt.subplots() + + images.cmat(data, annot=False, fig=fig, ax=ax) + + assert len(ax.texts) == 0 + plt.close(fig) + + +def test_cmat_label_mismatch_raises(): + data = np.eye(2) + fig, ax = plt.subplots() + + with pytest.raises(ValueError): + images.cmat(data, labels=["short"], fig=fig, ax=ax) + + plt.close(fig) diff --git a/tests/test_plots.py b/tests/test_plots.py index 5b0ed21..f4c1bc2 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -1,5 +1,7 @@ import numpy as np +import pytest from matplotlib import pyplot as plt +from matplotlib.patches import Ellipse from jetplot import plots @@ -75,6 +77,17 @@ def test_waterfall(): plt.close(fig) +def test_waterfall_accepts_generators(): + x = np.arange(5) + ys = (np.linspace(0, 1, 5) for _ in range(3)) + + fig, ax = plt.subplots() + plots.waterfall(x, ys, fig=fig, ax=ax) + + assert len(ax.collections) >= 3 + plt.close(fig) + + def test_violinplot(): data = np.random.randn(100) fig, ax = plt.subplots() @@ -82,3 +95,58 @@ def test_violinplot(): # Expect at least one polygon from violin body assert len(ax.collections) > 0 plt.close(fig) + + +def test_ridgeline_accepts_generators(): + rng = np.random.default_rng(0) + t = np.linspace(-3, 3, 25) + xs = (rng.standard_normal(100) for _ in range(3)) + colors = (color for color in plots.neutral[:3]) + + fig, axs = plots.ridgeline(t, xs=xs, colors=colors) + assert len(axs) == 3 + plt.close(fig) + + +def test_ridgeline_mismatched_lengths_raise(): + t = np.linspace(-3, 3, 10) + xs = [np.linspace(0, 1, 5), np.linspace(0, 2, 5)] + colors = (color for color in plots.neutral[:1]) + + with pytest.raises(ValueError): + plots.ridgeline(t, xs=xs, colors=colors) + + plt.close("all") + + +def test_ridgeline_allows_extra_colors(): + rng = np.random.default_rng(2) + t = np.linspace(-3, 3, 25) + xs = [rng.standard_normal(100) for _ in range(3)] + + fig, axs = plots.ridgeline(t, xs=xs, colors=plots.neutral) + assert len(axs) == 3 + plt.close(fig) + + +def test_ellipse_returns_patch(): + rng = np.random.default_rng(1) + x = rng.standard_normal(200) + y = x + 0.1 * rng.standard_normal(200) + + fig, ax = plt.subplots() + patch = plots.ellipse(x, y, fig=fig, ax=ax) + + assert isinstance(patch, Ellipse) + assert patch in ax.patches + plt.close(fig) + + +def test_ellipse_length_mismatch_raises(): + x = np.arange(5) + y = np.arange(4) + + fig, ax = plt.subplots() + with pytest.raises(ValueError): + plots.ellipse(x, y, fig=fig, ax=ax) + plt.close(fig) diff --git a/tests/test_signals.py b/tests/test_signals.py index 0a4f763..6d4c921 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -1,6 +1,7 @@ """Tests for the signals module.""" import numpy as np +import pytest from jetplot import signals @@ -96,3 +97,21 @@ def test_normalize(): expected = np.stack([x / np.linalg.norm(x) for x in X.T]).T computed = signals.normalize(X, axis=0) assert np.allclose(expected, computed) + + +def test_stable_rank_invalid_shape(): + with pytest.raises(ValueError): + signals.stable_rank(np.ones(3)) + + +def test_normalize_handles_zero_vectors(): + X = np.array([[0.0, 0.0, 0.0], [3.0, 0.0, 4.0]]) + normalized = signals.normalize(X) + + assert np.allclose(normalized[0], 0.0) + assert np.allclose(normalized[1], np.array([0.6, 0.0, 0.8])) + + X = np.array([[0.0, 1.0], [0.0, 0.0]]) + normalized_cols = signals.normalize(X, axis=0) + assert np.allclose(normalized_cols[:, 0], 0.0) + assert np.allclose(normalized_cols[:, 1], np.array([1.0, 0.0]))