Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions AGENTS.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
# Developer Instructions
# Repository Guidelines

## Code Style
- Run `uv run ruff check .` before committing to ensure all Python code passes linting.
- Write clear docstrings for all public functions and classes.
- Use relative imports within the `jetplot` package.
## Project Structure & Module Organization
Source lives under `src/jetplot/`, split into focused modules such as `colors.py`, `plots.py`, and `style.py`. Add new utilities in cohesive files and keep imports relative (e.g. `from .colors import Palette`). Tests reside in `tests/` and should mirror the module layout; prefer descriptive folders like `tests/test_plots.py` to match the feature under test. Documentation content is maintained in `docs/` and rendered via MkDocs, while build artefacts land in `build/` and `dist/`.

## Testing
- Run `uv run pytest --cov=jetplot --cov-report=term` before committing to ensure all tests pass before submitting a PR.
- Run `uv run pyrefly check` before committing to ensure all pyrefly type checking passes.
## Build, Test, and Development Commands
Run `uv run ruff check .` to ensure linting passes and style expectations are met. Use `uv run pytest --cov=jetplot --cov-report=term` for the full suite with coverage feedback, and `uv run pyrefly check` to validate typing across the package. When iterating on documentation, launch `uv run mkdocs serve` for a live preview. Regenerate distributions with `uv build` once changes are ready to publish.

## PR Guidelines
- Your pull request description must contain a **Summary** section explaining the changes.
- Include a **Testing** section describing the commands used to run lint and tests along with their results.
## Coding Style & Naming Conventions
Follow PEP 8 defaults: four-space indentation, snake_case for functions and module-level variables, and CapWords for classes. Exported constants stay upper-case with underscores. Provide clear docstrings for every public function or class that describes inputs, return values, and side effects. Keep modules small, favor pure functions, and rely on the shared helpers already defined in `style.py` and `chart_utils.py` instead of duplicating logic.

## Testing Guidelines
Write pytest-based tests alongside new functionality, naming files `test_<feature>.py` and individual tests `test_<behavior>`. Prefer parametrization to cover edge cases concisely. Aim to maintain or raise the coverage reported by the standard coverage command; unexpected drops should block merges. Include regression tests when fixing bugs so future refactors stay guarded.

## Commit & Pull Request Guidelines
Commits use short, imperative summaries (e.g. `Add palette cycler helper`). Break large efforts into logical commits that pass linting and tests independently. Pull requests must include **Summary** and **Testing** sections outlining what changed and the exact commands run (`uv run ruff check .`, `uv run pytest --cov=jetplot --cov-report=term`, `uv run pyrefly check`). Link relevant issues and add screenshots when UI-facing artifacts such as documentation pages change.

## Documentation Tips
Reference existing examples in `docs/` when adding guides, and keep code snippets synced with the APIs under `src/jetplot/`. Rebuild the site locally with `uv run mkdocs serve` after edits to verify navigation, formatting, and cross-links before submitting your changes.
5 changes: 4 additions & 1 deletion justfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
default: test
default: format lint test typecheck

build:
uv build
Expand All @@ -12,6 +12,9 @@ docs:
format:
uv run ruff format

lint:
uv run ruff check

typecheck:
uv run pyrefly check

Expand Down
2 changes: 1 addition & 1 deletion src/jetplot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Jetplot is a set of useful utility functions for scientific python."""

__version__ = "0.6.5"
__version__ = "0.6.6"

from . import colors as c # noqa: F401
from .chart_utils import *
Expand Down
12 changes: 9 additions & 3 deletions src/jetplot/colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,19 @@ def cmap(self) -> LinearSegmentedColormap:

def plot(self, figsize: tuple[int, int] = (5, 1)) -> tuple[Figure, list[Axes]]:
"""Visualize the colors in the palette."""
if not self:
raise ValueError("Palette has no colors to plot.")

fig, axs = plt.subplots(1, len(self), figsize=figsize)
for c, ax in zip(self, axs, strict=True): # pyrefly: ignore
ax.set_facecolor(c)
axs_array = np.atleast_1d(axs)
axes_list = [cast(Axes, ax) for ax in axs_array.flat]

for c, ax in zip(self, axes_list, strict=True):
ax.set_facecolor(c) # pyrefly: ignore
ax.set_aspect("equal")
noticks(ax=ax)

return fig, cast(list[Axes], axs)
return fig, axes_list


def cubehelix(
Expand Down
83 changes: 64 additions & 19 deletions src/jetplot/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Any, cast

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.image import AxesImage
from matplotlib.ticker import FixedLocator
Expand All @@ -31,12 +30,21 @@ def img(
"""Visualize a matrix as an image.

Args:
img: array_like, The array to visualize.
mode: string, One of 'div' for a diverging image, 'seq' for
sequential, 'cov' for covariance matrices, or 'corr' for
correlation matrices (default: 'div').
cmap: string, Colormap to use.
aspect: string, Either 'equal' or 'auto'
data: Array to visualize.
mode: One of ``"div"``, ``"seq"``, ``"cov"``, or ``"corr"``.
cmap: Matplotlib colormap name. Mode defaults are used when ``None``.
aspect: Either ``"equal"`` or ``"auto"``.
vmin: Lower bound for normalization.
vmax: Upper bound for normalization.
cbar: Whether to draw a colorbar attached to the provided axes.
interpolation: Interpolation strategy passed to ``imshow``.

Raises:
ValueError: If ``mode`` is not recognized.

Notes:
When ``cbar`` is ``True``, the colorbar is added to the supplied axes/figure
so multi-axes layouts keep their layout intact.
"""
# work with a copy of the original image data
img = np.squeeze(data.copy())
Expand Down Expand Up @@ -68,16 +76,17 @@ def img(
raise ValueError("Unrecognized mode: '" + mode + "'")

# make the image
im = kwargs["ax"].imshow(
fig, ax = kwargs["fig"], kwargs["ax"]
im = ax.imshow(
img, cmap=cmap, interpolation=interpolation, vmin=vmin, vmax=vmax, aspect=aspect
)

# colorbar
if cbar:
plt.colorbar(im)
fig.colorbar(im, ax=ax)

# clear ticks
noticks(ax=kwargs["ax"])
noticks(ax=ax)

return im

Expand Down Expand Up @@ -131,24 +140,60 @@ def cmat(
vmax: float = 1.0,
**kwargs: Any,
) -> tuple[AxesImage, Axes]:
"""Plot confusion matrix."""
"""Plot a confusion matrix with optional annotations.

Args:
arr: Square matrix of scores in [0, 1].
labels: Optional axis labels. Must match matrix dimensions.
annot: Whether to draw text annotations for each cell.
cmap: Colormap used for the heatmap.
cbar: Whether to include a colorbar.
fmt: Format string applied to annotation labels.
dark_color: Text color used when ``value <= theta``.
light_color: Text color used when ``value > theta``.
grid_color: Grid line color.
theta: Threshold for choosing between ``dark_color`` and ``light_color``.
label_fontsize: Tick label font size.
fontsize: Annotation font size.
vmin: Lower bound for normalization.
vmax: Upper bound for normalization.

Raises:
ValueError: If labels are provided but do not match the matrix dimensions.
"""
num_rows, num_cols = arr.shape

label_list: list[str] | None = None
if labels is not None:
label_list = list(labels)
if len(label_list) != num_cols or num_rows != num_cols:
raise ValueError(
"Labels must match confusion matrix dimensions and matrix must be square."
)

ax = kwargs.pop("ax")
cb = imv(arr, ax=ax, vmin=vmin, vmax=vmax, cmap=cmap, cbar=cbar)

xs, ys = np.meshgrid(np.arange(num_cols), np.arange(num_rows), indexing="xy")

for x, y, value in zip(xs.flat, ys.flat, arr.flat, strict=True): # pyrefly: ignore
color = dark_color if (value <= theta) else light_color
label = f"{{:{fmt}}}".format(value)
ax.text(x, y, label, ha="center", va="center", color=color, fontsize=fontsize)

if labels is not None:
if annot:
for x, y, value in zip( # pyrefly: ignore
xs.flat, # pyrefly: ignore
ys.flat,
arr.flat,
strict=True, # pyrefly: ignore
):
color = dark_color if (value <= theta) else light_color
label = f"{{:{fmt}}}".format(value)
ax.text(
x, y, label, ha="center", va="center", color=color, fontsize=fontsize
)

if label_list is not None:
ax.set_xticks(np.arange(num_cols))
ax.set_xticklabels(labels, rotation=90, fontsize=label_fontsize)
ax.set_xticklabels(label_list, rotation=90, fontsize=label_fontsize)
ax.set_yticks(np.arange(num_rows))
ax.set_yticklabels(labels, fontsize=label_fontsize)
ax.set_yticklabels(label_list, fontsize=label_fontsize)

ax.xaxis.set_minor_locator(FixedLocator((np.arange(num_cols) - 0.5).tolist()))

Expand Down
73 changes: 60 additions & 13 deletions src/jetplot/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,19 @@ def errorplot(
"""Plot a line with error bars."""
ax = kwargs["ax"]

if np.isscalar(yerr) or len(yerr) == len(y): # pyrefly: ignore
if np.isscalar(yerr):
ymin = y - yerr # pyrefly: ignore
ymax = y + yerr # pyrefly: ignore
elif len(yerr) == 2:
elif isinstance(yerr, tuple):
if len(yerr) != 2:
raise ValueError("Invalid yerr tuple length: ", yerr)
ymin, ymax = yerr # pyrefly: ignore
else:
raise ValueError("Invalid yerr value: ", yerr)
yerr_array = np.asarray(yerr)
if yerr_array.shape != y.shape:
raise ValueError("Invalid yerr value: ", yerr)
ymin = y - yerr_array
ymax = y + yerr_array

if method == "line":
ax.plot(x, y, fmt, color=color, linewidth=4, clip_on=clip_on)
Expand Down Expand Up @@ -295,11 +301,27 @@ def waterfall(
ew: float = 2.0,
**kwargs: Any,
) -> None:
"""Waterfall plot."""
"""Waterfall plot for stacked sequences.

Args:
x: Common x-axis samples shared by every series.
ys: Iterable of y-series. Generators are supported and are consumed once.
dy: Vertical scaling applied to each successive series.
pad: Offset applied so the outline sits slightly above the fill.
color: Fill color for each series.
ec: Edge color for the outline.
ew: Edge line width.

Raises:
ValueError: If ``ys`` yields no series.
"""
ax = kwargs["ax"]
total = cast(int, len(ys))
ys_list = list(ys)
if not ys_list:
raise ValueError("ys must contain at least one series.")
total = len(ys_list)

for index, y in enumerate(ys):
for index, y in enumerate(ys_list):
zorder = total - index
y = y * dy + index
ax.plot(x, y + pad, color=ec, clip_on=False, lw=ew, zorder=zorder)
Expand All @@ -318,16 +340,40 @@ def ridgeline(
ymax: float = 0.6,
**kwargs: Any,
) -> tuple[Figure, list[Axes]]:
"""Stacked density plots reminiscent of a ridgeline plot."""
"""Stacked density plots reminiscent of a ridgeline plot.

Args:
t: Grid used when evaluating the kernel density estimate.
xs: Iterable of 1-D samples. Accepts generators and consumes them once.
colors: Iterable of colors. Must provide at least as many entries as ``xs``.
edgecolor: Line color used for the outline.
ymax: Upper y-limit for each subplot.

Raises:
ValueError: If ``xs`` is empty or ``colors`` provides too few values.
"""
fig = kwargs["fig"]
xs_list = list(xs)
colors_iter = iter(colors)

if not xs_list:
raise ValueError("xs must contain at least one series.")

axs = []

for k, (x, c) in enumerate(zip(xs, colors, strict=False)):
ax = fig.add_subplot(cast(int, len(xs)), 1, k + 1)
for k, x in enumerate(xs_list):
try:
palette_color = next(colors_iter)
except StopIteration as exc:
raise ValueError(
"colors must provide at least as many items as xs."
) from exc

ax = fig.add_subplot(len(xs_list), 1, k + 1)
y = gaussian_kde(x).evaluate(t)
ax.fill_between(t, y, color=c, clip_on=False)
ax.fill_between(t, y, color=palette_color, clip_on=False)
ax.plot(t, y, color=edgecolor, clip_on=False)
ax.axhline(0.0, lw=2, color=c, clip_on=False)
ax.axhline(0.0, lw=2, color=palette_color, clip_on=False)

ax.set_xlim(t[0], t[-1])
ax.set_xticks([])
Expand Down Expand Up @@ -378,7 +424,8 @@ def ellipse(
-------
matplotlib.patches.Ellipse
"""
ax = cast(Axes, kwargs.get("ax"))
ax = cast(Axes, kwargs.pop("ax", None))
kwargs.pop("fig", None)

if x.size != y.size:
raise ValueError("x and y must be the same size")
Expand Down Expand Up @@ -419,4 +466,4 @@ def ellipse(
)

ellipse.set_transform(transform + ax.transData) # pyrefly: ignore
return ax.add_patch(ellipse)
return cast(Ellipse, ax.add_patch(ellipse))
35 changes: 30 additions & 5 deletions src/jetplot/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,21 @@ def smooth(x: ArrayLike, sigma: float = 1.0, axis: int = 0) -> NDArray[np.floati
Returns:
xs: array_like, A smoothed version of the input signal
"""
return gaussian_filter1d(x, sigma, axis=axis)
arr = np.asarray(x)
return gaussian_filter1d(arr, sigma, axis=axis)


def stable_rank(X: NDArray[np.floating[Any]]) -> float:
"""Computes the stable rank of a matrix"""
assert X.ndim == 2, "X must be a matrix"
"""Compute the stable rank of a matrix.

Args:
X: Two-dimensional array representing a matrix.

Raises:
ValueError: If ``X`` is not two-dimensional.
"""
if X.ndim != 2:
raise ValueError("X must be a matrix")

# pyrefly: ignore
svals_sq = np.linalg.svd(X, compute_uv=False, full_matrices=False) ** 2
Expand Down Expand Up @@ -98,6 +107,22 @@ def normalize(
norm: Function that computes the norm (Default: np.linalg.norm).

Returns:
Xn: Arrays that have been normalized using to the given function.
Normalized array with the same shape as ``X``.

Notes:
Any vectors whose norm is zero remain zero after normalization instead of
producing NaNs or infinities.
"""
return np.asarray(X) / norm(X, axis=axis, keepdims=True)
arr = np.asarray(X, dtype=float)
denom = norm(arr, axis=axis, keepdims=True)
zero_mask = denom == 0

# Avoid divide-by-zero warnings and keep zeros in place by dividing only where safe.
safe_denom = np.where(zero_mask, 1.0, denom)
normalized = np.zeros_like(arr, dtype=float)
np.divide(arr, safe_denom, out=normalized, where=~zero_mask)

if np.any(zero_mask):
normalized = np.where(zero_mask, 0.0, normalized)

return normalized
4 changes: 2 additions & 2 deletions src/jetplot/timepiece.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@ def elapsed(self) -> float:
return elapsed

def checkpoint(self, name: str = "") -> None:
print(f"{self.name} {name} took {hrtime(self.elapsed)}".strip())
print(f"{self.name} {name} took {hrtime(self.elapsed)}".strip(), flush=True)

def __enter__(self) -> "Stopwatch":
return self

def __exit__(self, *_: object) -> None:
total = hrtime(time.perf_counter() - self.absolute_start)
print(f"{self.name} Finished! \u2714\nTotal elapsed time: {total}")
print(f"{self.name} Finished! \u2714\nTotal elapsed time: {total}", flush=True)


def hrtime(t: float) -> str:
Expand Down
Loading