Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 34 additions & 10 deletions src/jetplot/chart_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Plotting utils."""

from collections.abc import Callable
from functools import partial, wraps
from typing import Any, Literal

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.axes import Axes

__all__ = [
"noticks",
Expand Down Expand Up @@ -114,14 +115,25 @@ def nospines(left=False, bottom=False, top=True, right=True, **kwargs):
return ax


def get_bounds(axis, ax=None):
if ax is None:
ax = plt.gca()
def get_bounds(axis: Literal["x", "y"], ax: Axes | None = None) -> tuple[float, float]:
"""Return the axis spine bounds for the given axis.

Parameters
----------
axis : str
Axis to inspect, either ``"x"`` or ``"y"``.
ax : matplotlib.axes.Axes | None, optional
Axes object to inspect. If ``None``, the current axes are used.

Result = tuple[Callable[[], list[float]], Callable[[], list[str]], Callable[[], tuple[float, float]], str]
Returns
-------
tuple[float, float]
Lower and upper bounds of the axis spine.
"""
if ax is None:
ax = plt.gca()

axis_map: dict[str, Result] = {
axis_map: dict[str, Any] = {
"x": (ax.get_xticks, ax.get_xticklabels, ax.get_xlim, "bottom"),
"y": (ax.get_yticks, ax.get_yticklabels, ax.get_ylim, "left"),
}
Expand Down Expand Up @@ -187,14 +199,20 @@ def identity(x):


@axwrapper
def yclamp(y0=None, y1=None, dt=None, **kwargs):
def yclamp(
y0: float | None = None,
y1: float | None = None,
dt: float | None = None,
**kwargs,
) -> Axes:
"""Clamp the y-axis to evenly spaced tick marks."""
ax = kwargs["ax"]

lims = ax.get_ylim()
y0 = lims[0] if y0 is None else y0
y1 = lims[1] if y1 is None else y1

ticks: list[float] = ax.get_yticks() # pyrefly: ignore
ticks: list[float] = ax.get_yticks() # pyrefly: ignore
dt = float(np.mean(np.diff(ticks))) if dt is None else float(dt)

new_ticks = np.arange(dt * np.floor(y0 / dt), dt * (np.ceil(y1 / dt) + 1), dt)
Expand All @@ -206,14 +224,20 @@ def yclamp(y0=None, y1=None, dt=None, **kwargs):


@axwrapper
def xclamp(x0=None, x1=None, dt=None, **kwargs):
def xclamp(
x0: float | None = None,
x1: float | None = None,
dt: float | None = None,
**kwargs,
) -> Axes:
"""Clamp the x-axis to evenly spaced tick marks."""
ax = kwargs["ax"]

lims = ax.get_xlim()
x0 = lims[0] if x0 is None else x0
x1 = lims[1] if x1 is None else x1

ticks: list[float] = ax.get_xticks() # pyrefly: ignore
ticks: list[float] = ax.get_xticks() # pyrefly: ignore
dt = float(np.mean(np.diff(ticks))) if dt is None else float(dt)

new_ticks = np.arange(dt * np.floor(x0 / dt), dt * (np.ceil(x1 / dt) + 1), dt)
Expand Down
22 changes: 18 additions & 4 deletions src/jetplot/colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
import numpy as np
from matplotlib import cm
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.colors import LinearSegmentedColormap, to_hex
from matplotlib.figure import Figure
from matplotlib.typing import ColorType
from numpy.typing import NDArray

from .chart_utils import noticks

Expand All @@ -16,15 +19,18 @@ class Palette(list[ColorType]):

@property
def hex(self):
"""Return the palette colors as hexadecimal strings."""
return Palette([to_hex(rgb) for rgb in self])

@property
def cmap(self):
def cmap(self) -> LinearSegmentedColormap:
"""Return the palette as a Matplotlib colormap."""
return LinearSegmentedColormap.from_list("", self)

def plot(self, figsize=(5, 1)):
def plot(self, figsize: tuple[int, int] = (5, 1)) -> tuple[Figure, NDArray[Axes]]:
"""Visualize the colors in the palette."""
fig, axs = plt.subplots(1, len(self), figsize=figsize)
for c, ax in zip(self, axs, strict=True): # pyrefly: ignore
for c, ax in zip(self, axs, strict=True): # pyrefly: ignore
ax.set_facecolor(c)
ax.set_aspect("equal")
noticks(ax=ax)
Expand Down Expand Up @@ -54,7 +60,13 @@ def cubehelix(
return Palette(colors)


def cmap_colors(cmap: str, n: int, vmin: float = 0.0, vmax: float = 1.0):
def cmap_colors(
cmap: str,
n: int,
vmin: float = 0.0,
vmax: float = 1.0,
) -> Palette:
"""Extract ``n`` colors from a Matplotlib colormap."""
return Palette(getattr(cm, cmap)(np.linspace(vmin, vmax, n)))


Expand Down Expand Up @@ -371,6 +383,8 @@ def cmap_colors(cmap: str, n: int, vmin: float = 0.0, vmax: float = 1.0):


def rainbow(k: int) -> Palette:
"""Return a palette of distinct colors from several base palettes."""

_colors = (
blue,
orange,
Expand Down
13 changes: 11 additions & 2 deletions src/jetplot/images.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Image visualization tools."""

from collections.abc import Callable
from functools import partial

import numpy as np
Expand Down Expand Up @@ -79,7 +80,15 @@ def img(


@plotwrapper
def fsurface(func, xrng=None, yrng=None, n=100, nargs=2, **kwargs):
def fsurface(
func: Callable[..., np.ndarray],
xrng: tuple[float, float] | None = None,
yrng: tuple[float, float] | None = None,
n: int = 100,
nargs: int = 2,
**kwargs,
) -> None:
"""Plot a 2‑D function as a filled surface."""
xrng = (-1, 1) if xrng is None else xrng
yrng = xrng if yrng is None else yrng

Expand Down Expand Up @@ -127,7 +136,7 @@ def cmat(

xs, ys = np.meshgrid(np.arange(num_cols), np.arange(num_rows), indexing="xy")

for x, y, value in zip(xs.flat, ys.flat, arr.flat, strict=True): # pyrefly: ignore
for x, y, value in zip(xs.flat, ys.flat, arr.flat, strict=True): # pyrefly: ignore
color = dark_color if (value <= theta) else light_color
annot = f"{{:{fmt}}}".format(value)
ax.text(x, y, annot, ha="center", va="center", color=color, fontsize=fontsize)
Expand Down
30 changes: 26 additions & 4 deletions src/jetplot/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from matplotlib.patches import Ellipse
from matplotlib.transforms import Affine2D
from matplotlib.typing import ColorType
from matplotlib.figure import Figure
from matplotlib.axes import Axes
from collections.abc import Sequence
from numpy.typing import NDArray
from scipy.stats import gaussian_kde
from sklearn.covariance import EmpiricalCovariance, MinCovDet
Expand Down Expand Up @@ -35,7 +38,8 @@ def violinplot(
showmeans=False,
showquartiles=True,
**kwargs,
):
) -> Axes:
"""Violin plot with customizable elements."""
_ = kwargs.pop("fig")
ax = kwargs.pop("ax")

Expand Down Expand Up @@ -86,6 +90,8 @@ def violinplot(
zorder=20,
)

return ax


@plotwrapper
def hist(*args, **kwargs):
Expand Down Expand Up @@ -249,11 +255,17 @@ def bar(


@plotwrapper
def lines(x, lines=None, cmap="viridis", **kwargs):
def lines(
x: NDArray[np.floating] | NDArray[np.integer],
lines: list[NDArray[np.floating]] | None = None,
cmap: str = "viridis",
**kwargs,
) -> Axes:
"""Plot multiple lines using a color map."""
ax = kwargs["ax"]

if lines is None:
lines = list(x)
lines = list(x) # pyrefly: ignore
x = np.arange(len(lines[0]))

else:
Expand All @@ -263,6 +275,8 @@ def lines(x, lines=None, cmap="viridis", **kwargs):
for line, color in zip(lines, colors, strict=False):
ax.plot(x, line, color=color)

return ax


@plotwrapper
def waterfall(x, ys, dy=1.0, pad=0.1, color="#444444", ec="#cccccc", ew=2.0, **kwargs):
Expand All @@ -281,7 +295,15 @@ def waterfall(x, ys, dy=1.0, pad=0.1, color="#444444", ec="#cccccc", ew=2.0, **k


@figwrapper
def ridgeline(t, xs, colors, edgecolor="#ffffff", ymax=0.6, **kwargs):
def ridgeline(
t: NDArray[np.floating],
xs: Sequence[NDArray[np.floating]],
colors: Sequence[ColorType],
edgecolor: ColorType = "#ffffff",
ymax: float = 0.6,
**kwargs,
) -> tuple[Figure, list[Axes]]:
"""Stacked density plots reminiscent of a ridgeline plot."""
fig = kwargs["fig"]
axs = []

Expand Down
4 changes: 2 additions & 2 deletions src/jetplot/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def set_defaults(

def available_fonts() -> list[str]:
"""Returns a list of available fonts."""
return sorted(set([f.name for f in fm.fontManager.ttflist])) # pyrefly: ignore
return sorted(set([f.name for f in fm.fontManager.ttflist])) # pyrefly: ignore


def install_fonts(filepath: str):
Expand All @@ -150,7 +150,7 @@ def install_fonts(filepath: str):
font_files = fm.findSystemFonts(fontpaths=[filepath])

for font_file in font_files:
fm.fontManager.addfont(font_file) # pyrefly: ignore
fm.fontManager.addfont(font_file) # pyrefly: ignore

new_fonts = set(available_fonts()) - original_fonts
if new_fonts:
Expand Down
14 changes: 8 additions & 6 deletions src/jetplot/timepiece.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,30 @@


class Stopwatch:
def __init__(self, name=""):
"""Simple timer utility for measuring code execution time."""

def __init__(self, name: str = "") -> None:
self.name = name
self.start = time.perf_counter()
self.absolute_start = time.perf_counter()

def __str__(self):
def __str__(self) -> str:
return "\u231a Stopwatch for: " + self.name

@property
def elapsed(self):
def elapsed(self) -> float:
current = time.perf_counter()
elapsed = current - self.start
self.start = time.perf_counter()
return elapsed

def checkpoint(self, name=""):
def checkpoint(self, name: str = "") -> None:
print(f"{self.name} {name} took {hrtime(self.elapsed)}".strip())

def __enter__(self):
def __enter__(self) -> "Stopwatch":
return self

def __exit__(self, *_):
def __exit__(self, *_: object) -> None:
total = hrtime(time.perf_counter() - self.absolute_start)
print(f"{self.name} Finished! \u2714\nTotal elapsed time: {total}")

Expand Down