Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions justfile
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@ typecheck:

test:
uv run pytest --cov=jetplot --cov-report=term

loop:
find {src,tests} -name "*.py" | entr -c just test
28 changes: 20 additions & 8 deletions src/jetplot/chart_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Plotting utils."""

from collections.abc import Callable
from functools import partial, wraps
from typing import Any, Literal

Expand All @@ -20,7 +21,7 @@
]


def figwrapper(fun):
def figwrapper(fun: Callable[..., Any]) -> Callable[..., Any]:
"""Decorator that adds figure handles to the kwargs of a function."""

@wraps(fun)
Expand All @@ -33,7 +34,7 @@ def wrapper(*args, **kwargs):
return wrapper


def plotwrapper(fun):
def plotwrapper(fun: Callable[..., Any]) -> Callable[..., Any]:
"""Decorator that adds figure and axes handles to the kwargs of a function."""

@wraps(fun)
Expand All @@ -52,7 +53,7 @@ def wrapper(*args, **kwargs):
return wrapper


def axwrapper(fun):
def axwrapper(fun: Callable[..., Any]) -> Callable[..., Any]:
"""Decorator that adds an axes handle to kwargs."""

@wraps(fun)
Expand All @@ -70,7 +71,7 @@ def wrapper(*args, **kwargs):


@axwrapper
def noticks(**kwargs):
def noticks(**kwargs: Any) -> None:
"""
Clears tick marks (useful for images)
"""
Expand All @@ -81,7 +82,13 @@ def noticks(**kwargs):


@axwrapper
def nospines(left=False, bottom=False, top=True, right=True, **kwargs):
def nospines(
left: bool = False,
bottom: bool = False,
top: bool = True,
right: bool = True,
**kwargs: Any,
) -> plt.Axes:
"""
Hides the specified axis spines (by default, right and top spines)
"""
Expand Down Expand Up @@ -160,7 +167,12 @@ def get_bounds(axis: Literal["x", "y"], ax: Axes | None = None) -> tuple[float,


@axwrapper
def breathe(xlims=None, ylims=None, padding_percent=0.05, **kwargs):
def breathe(
xlims: tuple[float, float] | None = None,
ylims: tuple[float, float] | None = None,
padding_percent: float = 0.05,
**kwargs: Any,
) -> plt.Axes:
"""Adds space between axes and plot."""
ax = kwargs["ax"]

Expand Down Expand Up @@ -203,7 +215,7 @@ def yclamp(
y0: float | None = None,
y1: float | None = None,
dt: float | None = None,
**kwargs,
**kwargs: Any,
) -> Axes:
"""Clamp the y-axis to evenly spaced tick marks."""
ax = kwargs["ax"]
Expand All @@ -228,7 +240,7 @@ def xclamp(
x0: float | None = None,
x1: float | None = None,
dt: float | None = None,
**kwargs,
**kwargs: Any,
) -> Axes:
"""Clamp the x-axis to evenly spaced tick marks."""
ax = kwargs["ax"]
Expand Down
13 changes: 7 additions & 6 deletions src/jetplot/colors.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""Colorschemes"""

from typing import cast

import numpy as np
from matplotlib import cm
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.colors import LinearSegmentedColormap, to_hex
from matplotlib.figure import Figure
from matplotlib.typing import ColorType
from numpy.typing import NDArray

from .chart_utils import noticks

Expand All @@ -18,24 +19,24 @@ class Palette(list[ColorType]):
"""Color palette based on a list of values."""

@property
def hex(self):
def hex(self) -> "Palette":
"""Return the palette colors as hexadecimal strings."""
return Palette([to_hex(rgb) for rgb in self])
return Palette([to_hex(rgb) for rgb in self]) # pyrefly: ignore

@property
def cmap(self) -> LinearSegmentedColormap:
"""Return the palette as a Matplotlib colormap."""
return LinearSegmentedColormap.from_list("", self)

def plot(self, figsize: tuple[int, int] = (5, 1)) -> tuple[Figure, NDArray[Axes]]:
def plot(self, figsize: tuple[int, int] = (5, 1)) -> tuple[Figure, list[Axes]]:
"""Visualize the colors in the palette."""
fig, axs = plt.subplots(1, len(self), figsize=figsize)
for c, ax in zip(self, axs, strict=True): # pyrefly: ignore
ax.set_facecolor(c)
ax.set_aspect("equal")
noticks(ax=ax)

return fig, axs
return fig, cast(list[Axes], axs)


def cubehelix(
Expand All @@ -46,7 +47,7 @@ def cubehelix(
start: float = 0.0,
rot: float = 0.4,
hue: float = 0.8,
):
) -> Palette:
"""Cubehelix parameterized colormap."""
lambda_ = np.linspace(vmin, vmax, n)
x = lambda_**gamma
Expand Down
63 changes: 33 additions & 30 deletions src/jetplot/images.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""Image visualization tools."""

from collections.abc import Callable
from collections.abc import Callable, Iterable
from functools import partial
from typing import Any, cast

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.image import AxesImage
from matplotlib.ticker import FixedLocator

from . import colors as c
Expand All @@ -15,16 +18,16 @@

@plotwrapper
def img(
data,
mode="div",
cmap=None,
aspect="equal",
vmin=None,
vmax=None,
cbar=True,
interpolation="none",
**kwargs,
):
data: np.ndarray,
mode: str = "div",
cmap: str | None = None,
aspect: str = "equal",
vmin: float | None = None,
vmax: float | None = None,
cbar: bool = True,
interpolation: str = "none",
**kwargs: Any,
) -> AxesImage:
"""Visualize a matrix as an image.

Args:
Expand Down Expand Up @@ -86,7 +89,7 @@ def fsurface(
yrng: tuple[float, float] | None = None,
n: int = 100,
nargs: int = 2,
**kwargs,
**kwargs: Any,
) -> None:
"""Plot a 2‑D function as a filled surface."""
xrng = (-1, 1) if xrng is None else xrng
Expand All @@ -112,22 +115,22 @@ def fsurface(

@plotwrapper
def cmat(
arr,
labels=None,
annot=True,
cmap="gist_heat_r",
cbar=False,
fmt="0.0%",
dark_color="#222222",
light_color="#dddddd",
grid_color=c.gray[9],
theta=0.5,
label_fontsize=10.0,
fontsize=10.0,
vmin=0.0,
vmax=1.0,
**kwargs,
):
arr: np.ndarray,
labels: Iterable[str] | None = None,
annot: bool = True,
cmap: str = "gist_heat_r",
cbar: bool = False,
fmt: str = "0.0%",
dark_color: str = "#222222",
light_color: str = "#dddddd",
grid_color: str = cast(str, c.gray[9]),
theta: float = 0.5,
label_fontsize: float = 10.0,
fontsize: float = 10.0,
vmin: float = 0.0,
vmax: float = 1.0,
**kwargs: Any,
) -> tuple[AxesImage, Axes]:
"""Plot confusion matrix."""
num_rows, num_cols = arr.shape

Expand All @@ -138,8 +141,8 @@ def cmat(

for x, y, value in zip(xs.flat, ys.flat, arr.flat, strict=True): # pyrefly: ignore
color = dark_color if (value <= theta) else light_color
annot = f"{{:{fmt}}}".format(value)
ax.text(x, y, annot, ha="center", va="center", color=color, fontsize=fontsize)
label = f"{{:{fmt}}}".format(value)
ax.text(x, y, label, ha="center", va="center", color=color, fontsize=fontsize)

if labels is not None:
ax.set_xticks(np.arange(num_cols))
Expand Down
Loading