diff --git a/README.md b/README.md
index b32153c..a3d0fa3 100644
--- a/README.md
+++ b/README.md
@@ -46,6 +46,8 @@ To dig deeper, see the [documentation](https://tdhook.readthedocs.io).
- [Integrated Gradients](https://tdhook.readthedocs.io/en/latest/notebooks/methods/integrated-gradients.html): [](https://colab.research.google.com/github/Xmaster6y/tdhook/blob/main/docs/source/notebooks/methods/integrated-gradients.ipynb)
- [Steering Vectors](https://tdhook.readthedocs.io/en/latest/notebooks/methods/steering-vectors.html): [](https://colab.research.google.com/github/Xmaster6y/tdhook/blob/main/docs/source/notebooks/methods/steering-vectors.ipynb)
+- [Linear Probing](https://tdhook.readthedocs.io/en/latest/notebooks/methods/linear-probing.html): [](https://colab.research.google.com/github/Xmaster6y/tdhook/blob/main/docs/source/notebooks/methods/linear-probing.ipynb)
+- [Dimension Estimation](https://tdhook.readthedocs.io/en/latest/notebooks/methods/dimension-estimation.html): [](https://colab.research.google.com/github/Xmaster6y/tdhook/blob/main/docs/source/notebooks/methods/dimension-estimation.ipynb)
## Config
diff --git a/docs/source/methods.rst b/docs/source/methods.rst
index 3a813ad..75e6a7f 100644
--- a/docs/source/methods.rst
+++ b/docs/source/methods.rst
@@ -71,6 +71,23 @@ Methods
+ .. grid-item-card::
+ :link: notebooks/methods/dimension-estimation.ipynb
+ :class-card: surface
+ :class-body: surface
+
+ .. raw:: html
+
+
+
+
+
+
+
Dimension Estimation
+
Estimate intrinsic dimension of data manifolds using TwoNN, Local PCA, and related methods.
+
+
+
.. toctree::
:hidden:
:maxdepth: 2
@@ -78,3 +95,4 @@ Methods
notebooks/methods/integrated-gradients.ipynb
notebooks/methods/steering-vectors.ipynb
notebooks/methods/linear-probing.ipynb
+ notebooks/methods/dimension-estimation.ipynb
diff --git a/docs/source/notebooks/methods/dimension-estimation.ipynb b/docs/source/notebooks/methods/dimension-estimation.ipynb
new file mode 100644
index 0000000..419ce88
--- /dev/null
+++ b/docs/source/notebooks/methods/dimension-estimation.ipynb
@@ -0,0 +1,583 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Intrinsic Dimension Estimation\n",
+ "\n",
+ "This notebook demonstrates how to use the intrinsic dimension estimators from tdhook on synthetic data and MNIST."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Setup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import importlib.util\n",
+ "\n",
+ "DEV = True\n",
+ "\n",
+ "if importlib.util.find_spec(\"google.colab\") is not None:\n",
+ " MODE = \"colab-dev\" if DEV else \"colab\"\n",
+ "else:\n",
+ " MODE = \"local\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if MODE == \"colab\":\n",
+ " %pip install -q tdhook scikit-learn\n",
+ "elif MODE == \"colab-dev\":\n",
+ " !rm -rf tdhook\n",
+ " !git clone https://github.com/Xmaster6y/tdhook -b main\n",
+ " %pip install -q ./tdhook"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import time\n",
+ "import torch\n",
+ "from tensordict import TensorDict\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "\n",
+ "from tdhook.latent.dimension_estimation import (\n",
+ " TwoNnDimensionEstimator,\n",
+ " LocalKnnDimensionEstimator,\n",
+ " LocalPcaDimensionEstimator,\n",
+ " CaPcaDimensionEstimator,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 1: Synthetic Data (Simple Case)\n",
+ "\n",
+ "We start with simple synthetic manifolds where we know the true intrinsic dimension."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.1 Generate and visualize synthetic data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "torch.manual_seed(42)\n",
+ "\n",
+ "# 2D plane embedded in 10D (last 8 dims zero) - intrinsic dim = 2\n",
+ "plane_data = torch.randn(200, 10)\n",
+ "plane_data[:, 2:] = 0\n",
+ "\n",
+ "# 1D circle embedded in 2D - intrinsic dim = 1\n",
+ "theta = torch.rand(200) * 2 * torch.pi\n",
+ "circle_data = torch.stack([torch.cos(theta), torch.sin(theta)], dim=1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig, axes = plt.subplots(1, 2, figsize=(10, 4))\n",
+ "\n",
+ "axes[0].scatter(plane_data[:, 0].numpy(), plane_data[:, 1].numpy(), alpha=0.6)\n",
+ "axes[0].set_title(\"2D plane in 10D (first 2 dims)\")\n",
+ "axes[0].set_xlabel(\"$x_1$\")\n",
+ "axes[0].set_ylabel(\"$x_2$\")\n",
+ "axes[0].set_aspect(\"equal\")\n",
+ "\n",
+ "axes[1].scatter(circle_data[:, 0].numpy(), circle_data[:, 1].numpy(), alpha=0.6)\n",
+ "axes[1].set_title(\"1D circle in 2D\")\n",
+ "axes[1].set_xlabel(\"$x_1$\")\n",
+ "axes[1].set_ylabel(\"$x_2$\")\n",
+ "axes[1].set_aspect(\"equal\")\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.2 Run all estimators"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def run_estimators(data, k=\"auto\"):\n",
+ " td = TensorDict({\"data\": data}, batch_size=[])\n",
+ " results = {}\n",
+ " timings = {}\n",
+ "\n",
+ " # TwoNN: single scalar per dataset\n",
+ " t0 = time.perf_counter()\n",
+ " twonn = TwoNnDimensionEstimator(return_xy=True)\n",
+ " td_twonn = twonn(td.clone())\n",
+ " timings[\"TwoNN\"] = time.perf_counter() - t0\n",
+ " results[\"TwoNN\"] = td_twonn[\"dimension\"].item()\n",
+ " results[\"TwoNN_xy\"] = (td_twonn[\"dimension_x\"], td_twonn[\"dimension_y\"])\n",
+ "\n",
+ " # Per-point estimators\n",
+ " for name, est in [\n",
+ " (\"LocalKnn\", LocalKnnDimensionEstimator(k=k)),\n",
+ " (\"LocalPCA\", LocalPcaDimensionEstimator(k=k)),\n",
+ " (\"CaPca\", CaPcaDimensionEstimator(k=k)),\n",
+ " ]:\n",
+ " t0 = time.perf_counter()\n",
+ " td_est = est(td.clone())\n",
+ " timings[name] = time.perf_counter() - t0\n",
+ " d = td_est[\"dimension\"]\n",
+ " valid = torch.isfinite(d)\n",
+ " results[name] = d[valid].mean().item() if valid.any() else float(\"nan\")\n",
+ " results[f\"{name}_per_point\"] = d\n",
+ "\n",
+ " results[\"timings\"] = timings\n",
+ " print(\"Timings:\", \" | \".join(f\"{k}: {v:.3f}s\" for k, v in timings.items()))\n",
+ " return results"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "`k=5`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "plane_results = run_estimators(plane_data, k=5)\n",
+ "circle_results = run_estimators(circle_data, k=5)\n",
+ "\n",
+ "print(\"2D plane (expected ~2):\")\n",
+ "for name in [\"TwoNN\", \"LocalKnn\", \"LocalPCA\", \"CaPca\"]:\n",
+ " print(f\" {name}: {plane_results[name]:.2f}\")\n",
+ "\n",
+ "print(\"\\n1D circle (expected ~1):\")\n",
+ "for name in [\"TwoNN\", \"LocalKnn\", \"LocalPCA\", \"CaPca\"]:\n",
+ " print(f\" {name}: {circle_results[name]:.2f}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "`k=10`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "plane_results = run_estimators(plane_data, k=10)\n",
+ "circle_results = run_estimators(circle_data, k=10)\n",
+ "\n",
+ "print(\"2D plane (expected ~2):\")\n",
+ "for name in [\"TwoNN\", \"LocalKnn\", \"LocalPCA\", \"CaPca\"]:\n",
+ " print(f\" {name}: {plane_results[name]:.2f}\")\n",
+ "\n",
+ "print(\"\\n1D circle (expected ~1):\")\n",
+ "for name in [\"TwoNN\", \"LocalKnn\", \"LocalPCA\", \"CaPca\"]:\n",
+ " print(f\" {name}: {circle_results[name]:.2f}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "`k = \"auto\"`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "plane_results = run_estimators(plane_data)\n",
+ "circle_results = run_estimators(circle_data)\n",
+ "\n",
+ "print(\"2D plane (expected ~2):\")\n",
+ "for name in [\"TwoNN\", \"LocalKnn\", \"LocalPCA\", \"CaPca\"]:\n",
+ " print(f\" {name}: {plane_results[name]:.2f}\")\n",
+ "\n",
+ "print(\"\\n1D circle (expected ~1):\")\n",
+ "for name in [\"TwoNN\", \"LocalKnn\", \"LocalPCA\", \"CaPca\"]:\n",
+ " print(f\" {name}: {circle_results[name]:.2f}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.3 TwoNN visualization\n",
+ "\n",
+ "TwoNN estimates dimension from the linear relationship $y = d \\cdot x$ where $x = \\log(\\mu)$ and $y = -\\log(1-F)$."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x_plane = plane_results[\"TwoNN_xy\"][0].numpy()\n",
+ "y_plane = plane_results[\"TwoNN_xy\"][1].numpy()\n",
+ "d_plane = plane_results[\"TwoNN\"]\n",
+ "valid = np.isfinite(x_plane) & np.isfinite(y_plane)\n",
+ "\n",
+ "fig, ax = plt.subplots(figsize=(6, 4))\n",
+ "ax.scatter(x_plane[valid], y_plane[valid], alpha=0.6, label=\"data\")\n",
+ "x_line = np.linspace(x_plane[valid].min(), x_plane[valid].max(), 50)\n",
+ "ax.plot(x_line, d_plane * x_line, \"r-\", lw=2, label=f\"y = {d_plane:.2f} * x\")\n",
+ "ax.set_xlabel(\"$x = \\\\log(\\\\mu)$\")\n",
+ "ax.set_ylabel(\"$y = -\\\\log(1-F)$\")\n",
+ "ax.set_title(\"TwoNN: 2D plane (slope = estimated dimension)\")\n",
+ "ax.legend()\n",
+ "plt.tight_layout()\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.4 Per-point estimator histograms\n",
+ "\n",
+ "Per-point estimators (LocalKnn, LocalPCA, CaPca) give a dimension estimate at each data point. The box plot shows the distribution of these local estimates, revealing variation across the manifold."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig, axes = plt.subplots(1, 3, figsize=(12, 4))\n",
+ "\n",
+ "for ax, name in zip(axes, [\"LocalKnn\", \"LocalPCA\", \"CaPca\"]):\n",
+ " d = plane_results[f\"{name}_per_point\"].numpy()\n",
+ " valid = np.isfinite(d)\n",
+ " ax.hist(d[valid], bins=20, edgecolor=\"black\", alpha=0.7)\n",
+ " ax.axvline(plane_results[name], color=\"red\", linestyle=\"--\", label=f\"mean = {plane_results[name]:.2f}\")\n",
+ " ax.set_xlabel(\"Estimated dimension\")\n",
+ " ax.set_title(name)\n",
+ " ax.legend()\n",
+ "\n",
+ "plt.suptitle(\"Per-point dimension estimates on 2D plane\")\n",
+ "plt.tight_layout()\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 2: MNIST\n",
+ "\n",
+ "We now run the estimators on MNIST digits. The intrinsic dimension of the MNIST manifold is typically around 10-15."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.1 Load MNIST"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from datasets import load_dataset\n",
+ "\n",
+ "ds = load_dataset(\"mnist\", split=\"train\")\n",
+ "arr = np.array(ds[\"image\"][:1000])\n",
+ "mnist_data = torch.tensor(arr, dtype=torch.float32).reshape(1000, -1) / 255.0\n",
+ "\n",
+ "\n",
+ "print(f\"MNIST shape: {mnist_data.shape}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.2 Sample digits visualization"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig, axes = plt.subplots(3, 3, figsize=(6, 6))\n",
+ "for i, ax in enumerate(axes.flat):\n",
+ " img = mnist_data[i].reshape(28, 28).numpy()\n",
+ " ax.imshow(img, cmap=\"gray\")\n",
+ " ax.axis(\"off\")\n",
+ "plt.suptitle(\"Sample MNIST digits\")\n",
+ "plt.tight_layout()\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.3 Run estimators on MNIST"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mnist_results = run_estimators(mnist_data, k=\"auto\")\n",
+ "\n",
+ "print(\"MNIST (expected ~10-15):\")\n",
+ "for name in [\"TwoNN\", \"LocalKnn\", \"LocalPCA\", \"CaPca\"]:\n",
+ " print(f\" {name}: {mnist_results[name]:.2f}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.4 TwoNN curve fitting on MNIST"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x_mnist = mnist_results[\"TwoNN_xy\"][0].numpy()\n",
+ "y_mnist = mnist_results[\"TwoNN_xy\"][1].numpy()\n",
+ "d_mnist = mnist_results[\"TwoNN\"]\n",
+ "valid = np.isfinite(x_mnist) & np.isfinite(y_mnist)\n",
+ "\n",
+ "fig, ax = plt.subplots(figsize=(6, 4))\n",
+ "ax.scatter(x_mnist[valid], y_mnist[valid], alpha=0.6, label=\"data\")\n",
+ "x_line = np.linspace(x_mnist[valid].min(), x_mnist[valid].max(), 50)\n",
+ "ax.plot(x_line, d_mnist * x_line, \"r-\", lw=2, label=f\"y = {d_mnist:.2f} * x\")\n",
+ "ax.set_xlabel(\"$x = \\\\log(\\\\mu)$\")\n",
+ "ax.set_ylabel(\"$y = -\\\\log(1-F)$\")\n",
+ "ax.set_title(\"TwoNN: MNIST (slope = estimated dimension)\")\n",
+ "ax.legend()\n",
+ "plt.tight_layout()\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.5 Box plot: per-point estimators (local variations)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "per_point_methods = [\"LocalKnn\", \"LocalPCA\", \"CaPca\"]\n",
+ "\n",
+ "fig, axes = plt.subplots(1, 3, figsize=(12, 4))\n",
+ "for ax, name in zip(axes, per_point_methods):\n",
+ " d = mnist_results[f\"{name}_per_point\"].numpy()\n",
+ " d = d[np.isfinite(d)]\n",
+ " bp = ax.boxplot([d], tick_labels=[name], patch_artist=True)\n",
+ " ax.axhline(\n",
+ " mnist_results[\"TwoNN\"], color=\"gray\", linestyle=\"--\", alpha=0.7, label=f\"TwoNN = {mnist_results['TwoNN']:.1f}\"\n",
+ " )\n",
+ " ax.set_ylabel(\"Estimated dimension (per point)\")\n",
+ " ax.legend()\n",
+ "\n",
+ "plt.suptitle(\"Per-point dimension estimates on MNIST (1000 images)\")\n",
+ "plt.tight_layout()\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.6 Using 5000 images"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "arr = np.array(ds[\"image\"][:5000])\n",
+ "mnist_data_5k = torch.tensor(arr, dtype=torch.float32).reshape(5000, -1) / 255.0\n",
+ "\n",
+ "mnist_results_5k = run_estimators(mnist_data_5k, k=\"auto\")\n",
+ "\n",
+ "print(\"MNIST (5000 images, expected ~10-15):\")\n",
+ "for name in [\"TwoNN\", \"LocalKnn\", \"LocalPCA\", \"CaPca\"]:\n",
+ " print(f\" {name}: {mnist_results_5k[name]:.2f}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(\"Comparison (1000 vs 5000 images):\")\n",
+ "for name in [\"TwoNN\", \"LocalKnn\", \"LocalPCA\", \"CaPca\"]:\n",
+ " print(f\" {name}: {mnist_results[name]:.2f} (1k) → {mnist_results_5k[name]:.2f} (5k)\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x_mnist_5k = mnist_results_5k[\"TwoNN_xy\"][0].numpy()\n",
+ "y_mnist_5k = mnist_results_5k[\"TwoNN_xy\"][1].numpy()\n",
+ "d_mnist_5k = mnist_results_5k[\"TwoNN\"]\n",
+ "valid = np.isfinite(x_mnist_5k) & np.isfinite(y_mnist_5k)\n",
+ "\n",
+ "fig, ax = plt.subplots(figsize=(6, 4))\n",
+ "ax.scatter(x_mnist_5k[valid], y_mnist_5k[valid], alpha=0.4, label=\"data\")\n",
+ "x_line = np.linspace(x_mnist_5k[valid].min(), x_mnist_5k[valid].max(), 50)\n",
+ "ax.plot(x_line, d_mnist_5k * x_line, \"r-\", lw=2, label=f\"y = {d_mnist_5k:.2f} * x\")\n",
+ "ax.set_xlabel(\"$x = \\\\log(\\\\mu)$\")\n",
+ "ax.set_ylabel(\"$y = -\\\\log(1-F)$\")\n",
+ "ax.set_title(\"TwoNN: MNIST 5k images (slope = estimated dimension)\")\n",
+ "ax.legend()\n",
+ "plt.tight_layout()\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig, axes = plt.subplots(1, 3, figsize=(12, 4))\n",
+ "for ax, name in zip(axes, per_point_methods):\n",
+ " d = mnist_results_5k[f\"{name}_per_point\"].numpy()\n",
+ " d = d[np.isfinite(d)]\n",
+ " bp = ax.boxplot([d], tick_labels=[name], patch_artist=True)\n",
+ " ax.axhline(\n",
+ " mnist_results_5k[\"TwoNN\"],\n",
+ " color=\"gray\",\n",
+ " linestyle=\"--\",\n",
+ " alpha=0.7,\n",
+ " label=f\"TwoNN = {mnist_results_5k['TwoNN']:.1f}\",\n",
+ " )\n",
+ " ax.set_ylabel(\"Estimated dimension (per point)\")\n",
+ " ax.legend()\n",
+ "\n",
+ "plt.suptitle(\"Per-point dimension estimates on MNIST (5000 images)\")\n",
+ "plt.tight_layout()\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.7 Per-label dimension estimation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "labels = np.array(ds[\"label\"][:5000])\n",
+ "per_label = {}\n",
+ "for d in range(10):\n",
+ " subset = mnist_data_5k[labels == d]\n",
+ " per_label[d] = run_estimators(subset, k=\"auto\")\n",
+ "\n",
+ "print(\"Per-label dimension estimates (5000 images):\")\n",
+ "for d in range(10):\n",
+ " r = per_label[d]\n",
+ " print(\n",
+ " f\" {d}: TwoNN={r['TwoNN']:.2f} LocalKnn={r['LocalKnn']:.2f} LocalPCA={r['LocalPCA']:.2f} CaPca={r['CaPca']:.2f} (n={(labels == d).sum()})\"\n",
+ " )"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/docs/source/references.bib b/docs/source/references.bib
index 344630a..445905f 100644
--- a/docs/source/references.bib
+++ b/docs/source/references.bib
@@ -235,7 +235,7 @@ @inproceedings{farahmand2007manifold
pages = {265--272},
url = {https://doi.org/10.1145/1273496.1273530},
doi = {10.1145/1273496.1273530},
- location = {Corvalis, Oregon, USA},
+ location = {Corvallis, Oregon, USA},
series = {ICML '07},
}
@@ -252,3 +252,32 @@ @article{facco_estimating_2017
year = {2017},
pages = {12140},
}
+
+@article{fukunaga1971algorithm,
+ title = {An algorithm for finding intrinsic dimensionality of data},
+ author = {Fukunaga, Keinosuke and Olsen, D. R.},
+ journal = {IEEE Transactions on Computers},
+ volume = {C-20},
+ number = {2},
+ pages = {176--183},
+ year = {1971},
+ doi = {10.1109/T-C.1971.223208},
+}
+
+@article{bruske1998intrinsic,
+ title = {Intrinsic dimensionality estimation with optimally topology preserving maps},
+ author = {Bruske, Jochen and Sommer, Gerald},
+ journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence},
+ volume = {20},
+ number = {5},
+ pages = {572--575},
+ year = {1998},
+ doi = {10.1109/34.682189},
+}
+
+@article{gilbert2023capca,
+ title = {CA-PCA: Manifold Dimension Estimation, Adapted for Curvature},
+ author = {Gilbert, Anna C. and O'Neill, Kevin},
+ journal = {arXiv preprint arXiv:2309.13478},
+ year = {2023},
+}
diff --git a/pyproject.toml b/pyproject.toml
index 688283a..898a6e5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,6 +40,7 @@ dev = [
"scikit-learn>=1.7.1",
]
docs = [
+ "ipywidgets>=8.0.0",
"nbsphinx>=0.9.6",
"pandoc>=2.4",
"plotly>=5.24.1",
@@ -70,6 +71,7 @@ scripts = [
]
notebooks = [
"ipykernel>=6.29.5",
+ "ipywidgets>=8.0.0",
]
diff --git a/src/tdhook/_optional_deps.py b/src/tdhook/_optional_deps.py
new file mode 100644
index 0000000..f7ed2b0
--- /dev/null
+++ b/src/tdhook/_optional_deps.py
@@ -0,0 +1,19 @@
+import functools
+import importlib.util
+
+
+def _ensure_sklearn() -> None:
+    """Raise ImportError if scikit-learn is not installed."""
+    if importlib.util.find_spec("sklearn") is None:  # find_spec probes availability without importing sklearn
+        raise ImportError("scikit-learn is required for this feature. Install with: pip install scikit-learn")
+
+
+def requires_sklearn(func):
+    """Decorator: raise ImportError if sklearn is missing when the decorated function is called."""
+
+    @functools.wraps(func)  # preserve the wrapped function's name/docstring/signature metadata
+    def wrapper(*args, **kwargs):
+        _ensure_sklearn()  # checked lazily at call time, so importing the module never requires sklearn
+        return func(*args, **kwargs)
+
+    return wrapper
diff --git a/src/tdhook/latent/dimension_estimation/__init__.py b/src/tdhook/latent/dimension_estimation/__init__.py
index eecb55c..29e58fc 100644
--- a/src/tdhook/latent/dimension_estimation/__init__.py
+++ b/src/tdhook/latent/dimension_estimation/__init__.py
@@ -2,10 +2,14 @@
Intrinsic dimension estimation methods.
"""
+from .ca_pca import CaPcaDimensionEstimator
from .local_knn import LocalKnnDimensionEstimator
+from .local_pca import LocalPcaDimensionEstimator
from .twonn import TwoNnDimensionEstimator
__all__ = [
+ "CaPcaDimensionEstimator",
"LocalKnnDimensionEstimator",
+ "LocalPcaDimensionEstimator",
"TwoNnDimensionEstimator",
]
diff --git a/src/tdhook/latent/dimension_estimation/_utils.py b/src/tdhook/latent/dimension_estimation/_utils.py
index a24d9dc..b8ecf59 100644
--- a/src/tdhook/latent/dimension_estimation/_utils.py
+++ b/src/tdhook/latent/dimension_estimation/_utils.py
@@ -3,14 +3,15 @@
import torch
-def sorted_neighbor_distances(data: torch.Tensor, eps: float) -> torch.Tensor:
- """Compute sorted distances to neighbors for each point.
+def sorted_neighbors(data: torch.Tensor, eps: float) -> tuple[torch.Tensor, torch.Tensor]:
+ """Compute sorted distances and indices of neighbors for each point.
- Returns (N, N) where each row is sorted ascending. Self and distances <= eps are inf.
+ Returns (sorted_dist, indices) each of shape (N, N). Each row is sorted ascending.
+ Self and distances <= eps are inf, with their indices appearing last in each row.
"""
dist = torch.cdist(data, data, p=2)
dist = dist.clone()
dist.fill_diagonal_(float("inf"))
dist = torch.where(dist > eps, dist, float("inf"))
- sorted_dist, _ = torch.sort(dist, dim=1)
- return sorted_dist
+ sorted_dist, indices = torch.sort(dist, dim=1)
+ return sorted_dist, indices
diff --git a/src/tdhook/latent/dimension_estimation/ca_pca.py b/src/tdhook/latent/dimension_estimation/ca_pca.py
new file mode 100644
index 0000000..c96c095
--- /dev/null
+++ b/src/tdhook/latent/dimension_estimation/ca_pca.py
@@ -0,0 +1,125 @@
+from textwrap import indent
+from typing import Literal, Union
+
+import numpy as np
+import torch
+from tensordict import TensorDict
+from tensordict.nn import TensorDictModuleBase
+
+from tdhook._optional_deps import requires_sklearn
+
+from ._utils import sorted_neighbors
+from .local_knn import _resolve_k
+
+
+class CaPcaDimensionEstimator(TensorDictModuleBase):
+    """
+    Curvature-adjusted intrinsic dimension estimation via local PCA :cite:`gilbert2023capca`.
+
+    Extends local PCA by calibrating to a quadratic embedding instead of a flat unit ball,
+    accounting for manifold curvature. For each point, uses its k+1 nearest neighbors,
+    forms the local covariance, and selects dimension by comparing curvature-corrected
+    eigenvalues to the expected spectrum of a d-dimensional ball.
+
+    Reads a data tensor from the input TensorDict. Expects (N, D) or (..., N, D).
+    Outputs per-point dimension estimates of shape (..., N).
+    """
+
+    def __init__(
+        self,
+        k: Union[int, Literal["auto"]] = "auto",
+        in_key: str = "data",
+        out_key: str = "dimension",
+        eps: float = 1e-5,
+    ):
+        super().__init__()
+        if k != "auto":  # validate only an explicit integer k; "auto" is resolved per call from N
+            if not isinstance(k, int):
+                raise TypeError("k must be an int or 'auto'")
+            if k < 1:
+                raise ValueError("k must be at least 1")
+        self.k = k
+        self.in_key = in_key
+        self.out_key = out_key
+        self.eps = eps
+        self.in_keys = [in_key]
+        self.out_keys = [out_key]
+
+    @requires_sklearn
+    def forward(self, td: TensorDict) -> TensorDict:
+        from sklearn.decomposition import PCA  # local import keeps sklearn an optional dependency at module import time
+
+        data = td[self.in_key]
+        N = data.shape[-2]
+        k = _resolve_k(self.k, N)
+        if N < k + 2:
+            raise ValueError(f"At least k+2 points required for CA-PCA (k={k}), got {N}")
+        batch_shape = data.shape[:-2]
+        flat = data.reshape(-1, data.shape[-2], data.shape[-1])  # collapse leading batch dims to loop over datasets
+        device = data.device
+        dtype = data.dtype
+        dims = []
+        for i in range(flat.shape[0]):
+            d_i = _ca_pca(flat[i], k=k, eps=self.eps, pca_cls=PCA)
+            dims.append(d_i)
+        td[self.out_key] = torch.stack(dims).reshape(*batch_shape, N).to(device=device, dtype=dtype)
+        return td
+
+    def __repr__(self):
+        fields = indent(
+            f"in_keys={self.in_keys},\nout_keys={self.out_keys},\nk={self.k},\neps={self.eps}",
+            4 * " ",
+        )
+        return f"{type(self).__name__}(\n{fields})"
+
+
+def _ca_pca(data: torch.Tensor, k: int, eps: float, pca_cls: type) -> torch.Tensor:
+ """Compute per-point dimension via CA-PCA. data: (N, D). Returns (N,) dimension estimates."""
+ sorted_dist, indices = sorted_neighbors(data, eps)
+ N, D = data.shape
+ dims = []
+ for i in range(N):
+ dist_k = sorted_dist[i, k - 1]
+ dist_kp1 = sorted_dist[i, k]
+ r = (dist_k + dist_kp1) / 2.0
+ if r <= 0 or not np.isfinite(float(r)):
+ dims.append(float("nan"))
+ continue
+ valid_mask = torch.isfinite(sorted_dist[i])
+ valid_indices = indices[i][valid_mask]
+ neighbor_idx = valid_indices[: k + 1]
+ if len(neighbor_idx) < k + 1:
+ dims.append(float("nan"))
+ continue
+ neighborhood = data[neighbor_idx].detach().cpu().double().numpy()
+ if neighborhood.shape[0] < 2:
+ dims.append(float("nan"))
+ continue
+ pca = pca_cls(n_components=None).fit(neighborhood)
+ eigvals = pca.explained_variance_
+ lambda_hat = np.zeros(D, dtype=np.float64)
+ n_eig = min(len(eigvals), D)
+ lambda_hat[:n_eig] = eigvals[:n_eig] / (r**2)
+ d_est = _dim_from_ca_pca(lambda_hat, D)
+ dims.append(float(d_est))
+ return torch.tensor(dims, device=data.device, dtype=torch.float32)
+
+
+def _dim_from_ca_pca(lambda_hat: np.ndarray, D: int) -> int:
+    """Select dimension via curvature-corrected eigenvalue matching."""
+    best_d = 1
+    best_score = np.inf
+    for d in range(1, D + 1):  # try every candidate dimension up to the ambient dimension
+        tail_sum = lambda_hat[d:].sum()  # variance not explained by the leading d components
+        coef = (3 * d + 4) / (d * (d + 4)) if d > 0 else 0.0  # curvature-correction weight (d >= 1 here, so the else branch is unreachable)
+        lambda_d = np.zeros(D)
+        lambda_d[:d] = lambda_hat[:d] + coef * tail_sum  # fold the tail variance back onto the leading d eigenvalues
+        lambda_d[d:] = 0.0  # already zero from np.zeros; kept for explicitness
+        target = np.zeros(D)
+        target[:d] = 1.0 / (d + 2)  # expected spectrum of a d-dimensional ball (see class docstring)
+        target[d:] = 0.0  # already zero from np.zeros; kept for explicitness
+        score = float(np.linalg.norm(target - lambda_d)) + 2.0 * tail_sum  # spectrum mismatch plus penalty on unexplained tail
+        if score < best_score:
+            best_score = score
+            best_d = d
+    return best_d
diff --git a/src/tdhook/latent/dimension_estimation/local_knn.py b/src/tdhook/latent/dimension_estimation/local_knn.py
index 8352c17..1bdf3f6 100644
--- a/src/tdhook/latent/dimension_estimation/local_knn.py
+++ b/src/tdhook/latent/dimension_estimation/local_knn.py
@@ -5,7 +5,7 @@
from tensordict import TensorDict
from tensordict.nn import TensorDictModuleBase
-from ._utils import sorted_neighbor_distances
+from ._utils import sorted_neighbors
def _resolve_k(k: Union[int, Literal["auto"]], n: int) -> int:
@@ -71,7 +71,7 @@ def __repr__(self):
def _local_knn(data: torch.Tensor, k: int, eps: float) -> torch.Tensor:
"""Compute per-point local dimension. data: (N, D). Returns (N,) dimension estimates."""
- sorted_dist = sorted_neighbor_distances(data, eps)
+ sorted_dist, _ = sorted_neighbors(data, eps)
rk = sorted_dist[:, k - 1]
r2k = sorted_dist[:, 2 * k - 1]
diff --git a/src/tdhook/latent/dimension_estimation/local_pca.py b/src/tdhook/latent/dimension_estimation/local_pca.py
new file mode 100644
index 0000000..6e9e201
--- /dev/null
+++ b/src/tdhook/latent/dimension_estimation/local_pca.py
@@ -0,0 +1,147 @@
+"""Local PCA dimension estimation via eigenvalues of local covariance."""
+
+from textwrap import indent
+from typing import Literal, Union
+
+import numpy as np
+import torch
+from tensordict import TensorDict
+from tensordict.nn import TensorDictModuleBase
+
+from tdhook._optional_deps import requires_sklearn
+
+from ._utils import sorted_neighbors
+from .local_knn import _resolve_k
+
+
+class LocalPcaDimensionEstimator(TensorDictModuleBase):
+    """
+    Local intrinsic dimension estimation via PCA on k-NN neighborhoods :cite:`fukunaga1971algorithm`.
+
+    For each point, extracts its k+1 nearest neighbors (self + k neighbors), fits PCA,
+    and estimates dimension from eigenvalues using a configurable criterion (maxgap or ratio).
+
+    Reads a data tensor from the input TensorDict. Expects (N, D) or (..., N, D).
+    Outputs per-point dimension estimates of shape (..., N).
+    """
+
+    def __init__(
+        self,
+        k: Union[int, Literal["auto"]] = "auto",
+        criterion: Literal["maxgap", "ratio"] = "maxgap",
+        alpha: float = 0.05,
+        in_key: str = "data",
+        out_key: str = "dimension",
+        eps: float = 1e-5,
+    ):
+        super().__init__()
+        if k != "auto":  # validate only an explicit integer k; "auto" is resolved per call from N
+            if not isinstance(k, int):
+                raise TypeError("k must be an int or 'auto'")
+            if k < 1:
+                raise ValueError("k must be at least 1")
+        self.k = k
+        self.criterion = criterion
+        self.alpha = alpha  # only consulted by the "ratio" criterion
+        self.in_key = in_key
+        self.out_key = out_key
+        self.eps = eps
+        self.in_keys = [in_key]
+        self.out_keys = [out_key]
+
+    @requires_sklearn
+    def forward(self, td: TensorDict) -> TensorDict:
+        from sklearn.decomposition import PCA  # local import keeps sklearn an optional dependency at module import time
+
+        data = td[self.in_key]
+        N = data.shape[-2]
+        k = _resolve_k(self.k, N)
+        if N < k + 1:
+            raise ValueError(f"At least k+1 points required for local PCA (k={k}), got {N}")
+        batch_shape = data.shape[:-2]
+        flat = data.reshape(-1, data.shape[-2], data.shape[-1])  # collapse leading batch dims to loop over datasets
+        device = data.device
+        dtype = data.dtype
+        dims = []
+        for i in range(flat.shape[0]):
+            d_i = _local_pca(
+                flat[i],
+                k=k,
+                eps=self.eps,
+                criterion=self.criterion,
+                alpha=self.alpha,
+                pca_cls=PCA,
+            )
+            dims.append(d_i)
+        td[self.out_key] = torch.stack(dims).reshape(*batch_shape, N).to(device=device, dtype=dtype)
+        return td
+
+    def __repr__(self):
+        fields = indent(
+            f"in_keys={self.in_keys},\nout_keys={self.out_keys},\nk={self.k},\n"
+            f"criterion={self.criterion!r},\nalpha={self.alpha},\neps={self.eps}",
+            4 * " ",
+        )
+        return f"{type(self).__name__}(\n{fields})"
+
+
+def _local_pca(
+    data: torch.Tensor,
+    k: int,
+    eps: float,
+    criterion: Literal["maxgap", "ratio"],
+    alpha: float,
+    pca_cls: type,
+) -> torch.Tensor:
+    """Compute per-point local dimension via PCA. data: (N, D). Returns (N,) dimension estimates."""
+    sorted_dist, indices = sorted_neighbors(data, eps)
+    N, D = data.shape
+    dims = []
+    for i in range(N):
+        valid_mask = torch.isfinite(sorted_dist[i])  # inf marks self and near-duplicate points
+        valid_indices = indices[i][valid_mask]
+        neighbor_idx = valid_indices[:k]  # k nearest distinct neighbors
+        if len(neighbor_idx) < k:  # not enough distinct neighbors -> no estimate for this point
+            dims.append(float("nan"))
+            continue
+        neighborhood = torch.cat([data[i : i + 1], data[neighbor_idx]], dim=0)  # self + k neighbors = k+1 points
+        X = neighborhood.detach().cpu().double().numpy()
+        if X.shape[0] < 2:  # defensive; k >= 1 makes this unreachable in practice
+            dims.append(float("nan"))
+            continue
+        pca = pca_cls(n_components=None).fit(X)
+        lambda_ = pca.explained_variance_
+        if len(lambda_) == 0:
+            dims.append(1.0)
+            continue
+        if criterion == "maxgap":
+            d = float(_dim_from_eigenvalues_maxgap(lambda_))
+        elif criterion == "ratio":
+            d = float(_dim_from_eigenvalues_ratio(lambda_, alpha))
+        else:
+            raise ValueError(f"Unknown criterion: {criterion!r}")
+        dims.append(d)
+    return torch.tensor(dims, device=data.device, dtype=torch.float32)
+
+
+def _dim_from_eigenvalues_maxgap(lambda_: np.ndarray) -> int:
+    """Estimate dimension from eigenvalues using the maximum gap criterion :cite:`bruske1998intrinsic`.
+
+    de = argmax(lambda[i]/lambda[i+1]) + 1 (1-based dimension).
+    """
+    if len(lambda_) < 2:
+        return 1
+    gaps = lambda_[:-1] / (lambda_[1:] + 1e-15)  # epsilon guards against division by a zero eigenvalue
+    return int(np.argmax(gaps) + 1)
+
+
+def _dim_from_eigenvalues_ratio(lambda_: np.ndarray, alpha: float) -> int:
+    """Estimate dimension using ratio criterion :cite:`fukunaga1971algorithm`.
+
+    Count eigenvalues above alpha * lambda[0]. Clamped to at least 1.
+    """
+    if len(lambda_) == 0:
+        return 1
+    threshold = alpha * lambda_[0]  # lambda_[0] is the leading (largest) PCA eigenvalue
+    de = int(np.sum(lambda_ > threshold))
+    return max(1, de)
diff --git a/src/tdhook/latent/dimension_estimation/twonn.py b/src/tdhook/latent/dimension_estimation/twonn.py
index 78c0385..290b71d 100644
--- a/src/tdhook/latent/dimension_estimation/twonn.py
+++ b/src/tdhook/latent/dimension_estimation/twonn.py
@@ -4,7 +4,7 @@
from tensordict import TensorDict
from tensordict.nn import TensorDictModuleBase
-from ._utils import sorted_neighbor_distances
+from ._utils import sorted_neighbors
class TwoNnDimensionEstimator(TensorDictModuleBase):
@@ -71,7 +71,7 @@ def _twonn(data: torch.Tensor, eps: float) -> tuple[torch.Tensor, torch.Tensor,
Distances <= eps are treated as duplicates (excluded from nearest-neighbor selection).
"""
- sorted_dist = sorted_neighbor_distances(data, eps)
+ sorted_dist, _ = sorted_neighbors(data, eps)
r1 = sorted_dist[:, 0]
r2 = sorted_dist[:, 1]
diff --git a/tests/latent/test_dimension_estimation.py b/tests/latent/test_dimension_estimation.py
index f62d890..cf1bf9a 100644
--- a/tests/latent/test_dimension_estimation.py
+++ b/tests/latent/test_dimension_estimation.py
@@ -2,12 +2,26 @@
Tests for intrinsic dimension estimation.
"""
+import importlib.util
+from unittest.mock import patch
+
+import numpy as np
import pytest
import torch
from sklearn.metrics import r2_score
from tensordict import TensorDict
-from tdhook.latent.dimension_estimation import LocalKnnDimensionEstimator, TwoNnDimensionEstimator
+from tdhook.latent.dimension_estimation import (
+ CaPcaDimensionEstimator,
+ LocalKnnDimensionEstimator,
+ LocalPcaDimensionEstimator,
+ TwoNnDimensionEstimator,
+)
+from tdhook.latent.dimension_estimation.local_pca import (
+ _dim_from_eigenvalues_maxgap,
+ _dim_from_eigenvalues_ratio,
+ _local_pca,
+)
@pytest.fixture
@@ -240,3 +254,313 @@ def test_repr(self):
assert "out_keys=['dimension']" in r
assert "k=3" in r
assert "eps=" in r
+
+
@pytest.fixture
def run_local_pca_estimator():
    # Seed once per fixture instantiation so tests drawing random data are deterministic.
    torch.manual_seed(42)

    def _run(data, k=5, in_key="data", batch_size=None, **estimator_kwargs):
        """Wrap `data` in a TensorDict and run a LocalPcaDimensionEstimator over it."""
        effective_batch = batch_size
        if effective_batch is None:
            # Infer leading batch dims for (..., N, D); a plain (N, D) tensor has none.
            effective_batch = data.shape[:-2] if data.ndim != 2 else []
        estimator = LocalPcaDimensionEstimator(k=k, in_key=in_key, **estimator_kwargs)
        return estimator(TensorDict({in_key: data}, batch_size=effective_batch))

    return _run
+
+
class TestLocalPcaDimensionEstimator:
    """Test the LocalPcaDimensionEstimator class."""

    def test_default_keys(self, run_local_pca_estimator):
        """Test with default in_key and out_key."""
        data = torch.randn(50, 10)
        result = run_local_pca_estimator(data, k=5)
        assert "dimension" in result
        assert result["dimension"].shape == (50,)
        assert result["dimension"].dtype in (torch.float32, torch.float64)
        # Per-point estimates may be NaN; only finite entries are meaningful.
        valid = torch.isfinite(result["dimension"])
        assert valid.sum() > 0
        assert (result["dimension"][valid] >= 1).all()

    def test_custom_keys(self, run_local_pca_estimator):
        """Test with custom in_key and out_key."""
        data = torch.randn(50, 8)
        result = run_local_pca_estimator(data, k=5, in_key="linear2", out_key="intrinsic_dim")
        assert "intrinsic_dim" in result
        assert "linear2" in result
        assert result["intrinsic_dim"].shape == (50,)

    def test_output_shape(self, run_local_pca_estimator):
        """Test output shape (N,) for (N, D) input."""
        data = torch.randn(100, 5)
        result = run_local_pca_estimator(data, k=5)
        assert result["dimension"].shape == (100,)

    def test_known_dimension_2d_maxgap(self, run_local_pca_estimator, plane_data):
        """Test on 2D manifold with maxgap criterion."""
        result = run_local_pca_estimator(plane_data, k=5, criterion="maxgap")
        d = result["dimension"]
        valid = torch.isfinite(d)
        mean_d = d[valid].mean().item()
        # Loose bracket around the true dimension 2: local estimates are noisy.
        assert 1.0 < mean_d < 5.0

    def test_known_dimension_2d_ratio(self, run_local_pca_estimator, plane_data):
        """Test on 2D manifold with ratio criterion."""
        result = run_local_pca_estimator(plane_data, k=5, criterion="ratio")
        d = result["dimension"]
        valid = torch.isfinite(d)
        mean_d = d[valid].mean().item()
        assert 1.0 < mean_d < 5.0

    def test_known_dimension_circle(self, run_local_pca_estimator, circle_data):
        """Test on 1D manifold (circle) embedded in 2D."""
        result = run_local_pca_estimator(circle_data, k=5)
        d = result["dimension"]
        valid = torch.isfinite(d)
        mean_d = d[valid].mean().item()
        assert 0.5 < mean_d < 3.0

    @pytest.mark.parametrize(
        "shape",
        [(1, 10, 8), (5, 10, 8), (2, 3, 10, 4)],
        ids=["1x10x8", "5x10x8", "2x3x10x4"],
    )
    def test_batch_shape_preservation(self, run_local_pca_estimator, shape):
        """Test that (..., N, D) preserves batch shape, output is (..., N)."""
        data = torch.randn(*shape)
        batch_size = shape[:-2]
        N = shape[-2]
        result = run_local_pca_estimator(data, k=5, batch_size=batch_size)
        assert result["dimension"].shape == (*batch_size, N)

    def test_too_few_points_raises(self, run_local_pca_estimator):
        """Test that N < k+1 raises."""
        with pytest.raises(ValueError, match="At least k\\+1 points"):
            run_local_pca_estimator(torch.randn(5, 5), k=5)  # 5 < 5+1

    def test_k_validation(self):
        """Test that invalid k raises (k < 1 or wrong type)."""
        with pytest.raises(ValueError, match="k must be at least 1"):
            LocalPcaDimensionEstimator(k=0)
        with pytest.raises(TypeError, match="k must be an int or 'auto'"):
            LocalPcaDimensionEstimator(k=2.5)

    def test_determinism(self, run_local_pca_estimator):
        """Test that same input yields same output."""
        torch.manual_seed(42)
        data = torch.randn(80, 6)
        r1 = run_local_pca_estimator(data.clone(), k=5)["dimension"]
        r2 = run_local_pca_estimator(data.clone(), k=5)["dimension"]
        # equal_nan=True: NaN entries (points with too few neighbors) must match too.
        assert torch.allclose(r1, r2, equal_nan=True)

    def test_k_auto(self, run_local_pca_estimator, plane_data):
        """Test k='auto' uses n**0.5."""
        result = run_local_pca_estimator(plane_data, k="auto")
        assert "dimension" in result
        assert result["dimension"].shape == (100,)

    def test_repr(self):
        """Test __repr__ includes class name, in_keys, out_keys, k, criterion, and eps."""
        est = LocalPcaDimensionEstimator(k=3, criterion="maxgap")
        r = repr(est)
        assert "LocalPcaDimensionEstimator" in r
        assert "in_keys=['data']" in r
        assert "out_keys=['dimension']" in r
        assert "k=3" in r
        assert "criterion='maxgap'" in r
        assert "eps=" in r

    def test_sklearn_missing_raises(self):
        """Test that ImportError is raised when sklearn is not installed."""
        with patch.object(importlib.util, "find_spec", return_value=None):
            est = LocalPcaDimensionEstimator(k=5)
            td = TensorDict({"data": torch.randn(20, 5)}, batch_size=[])
            with pytest.raises(ImportError, match="scikit-learn"):
                est(td)

    def test_unknown_criterion_raises(self):
        """Test that unknown criterion raises ValueError."""
        # Note: builds the estimator directly (fixture not needed here) so that
        # criterion validation is exercised without any fixture side effects.
        data = torch.randn(20, 5)
        est = LocalPcaDimensionEstimator(k=5, criterion="invalid")
        td = TensorDict({"data": data}, batch_size=[])
        with pytest.raises(ValueError, match="Unknown criterion"):
            est(td)

    def test_maxgap_single_eigenvalue(self, run_local_pca_estimator):
        """Test maxgap on 1D data (single eigenvalue): estimates stay finite and >= 1."""
        data = torch.randn(15, 1)
        result = run_local_pca_estimator(data, k=2, criterion="maxgap")
        assert (result["dimension"] >= 1).all()
        assert result["dimension"].shape == (15,)

    def test_ratio_empty_eigenvalues(self):
        """Test ratio criterion with empty eigenvalues returns 1."""
        assert _dim_from_eigenvalues_ratio(np.array([]), alpha=0.05) == 1

    def test_maxgap_few_eigenvalues(self):
        """Test maxgap with fewer than two eigenvalues (single or empty) returns 1."""
        assert _dim_from_eigenvalues_maxgap(np.array([1.0])) == 1
        assert _dim_from_eigenvalues_maxgap(np.array([])) == 1

    def test_constant_data_handled(self, run_local_pca_estimator):
        """Test constant data (zero variance) does not crash. Returns NaN (no valid neighbors)."""
        data = torch.ones(10, 5)
        result = run_local_pca_estimator(data, k=2)
        assert result["dimension"].shape == (10,)
        assert torch.isnan(result["dimension"]).all()

    def test_few_neighborhood_points_returns_nan(self):
        """Test that k=0 (single-point neighborhood) returns nan."""
        from sklearn.decomposition import PCA

        data = torch.randn(5, 3)
        result = _local_pca(data, k=0, eps=1e-5, criterion="maxgap", alpha=0.05, pca_cls=PCA)
        assert result.shape == (5,)
        assert torch.isnan(result).all()

    def test_empty_eigenvalues_returns_one(self):
        """Test that PCA returning empty eigenvalues yields 1.0."""

        class MockPCA:
            # Minimal stand-in matching the pca_cls protocol used by _local_pca.
            def __init__(self, n_components=None):
                pass

            def fit(self, X):
                self.explained_variance_ = np.array([])
                return self

        data = torch.randn(5, 3)
        result = _local_pca(data, k=2, eps=1e-5, criterion="maxgap", alpha=0.05, pca_cls=MockPCA)
        assert result.shape == (5,)
        assert (result == 1.0).all()
+
+
@pytest.fixture
def run_ca_pca_estimator():
    # Seed once per fixture instantiation so tests drawing random data are deterministic.
    torch.manual_seed(42)

    def _run(data, k=5, in_key="data", batch_size=None, **estimator_kwargs):
        """Wrap `data` in a TensorDict and run a CaPcaDimensionEstimator over it."""
        effective_batch = batch_size
        if effective_batch is None:
            # Infer leading batch dims for (..., N, D); a plain (N, D) tensor has none.
            effective_batch = data.shape[:-2] if data.ndim != 2 else []
        estimator = CaPcaDimensionEstimator(k=k, in_key=in_key, **estimator_kwargs)
        return estimator(TensorDict({in_key: data}, batch_size=effective_batch))

    return _run
+
+
class TestCaPcaDimensionEstimator:
    """Test the CaPcaDimensionEstimator class."""

    def test_default_keys(self, run_ca_pca_estimator):
        """Test with default in_key and out_key."""
        data = torch.randn(50, 10)
        out = run_ca_pca_estimator(data, k=5)
        assert "dimension" in out
        dims = out["dimension"]
        assert dims.shape == (50,)
        assert dims.dtype in (torch.float32, torch.float64)
        # Only finite entries carry a meaningful estimate; NaNs are allowed.
        finite = torch.isfinite(dims)
        assert finite.sum() > 0
        assert (dims[finite] >= 1).all()

    def test_custom_keys(self, run_ca_pca_estimator):
        """Test with custom in_key and out_key."""
        data = torch.randn(50, 8)
        out = run_ca_pca_estimator(data, k=5, in_key="linear2", out_key="intrinsic_dim")
        assert "intrinsic_dim" in out
        assert "linear2" in out
        assert out["intrinsic_dim"].shape == (50,)

    def test_output_shape(self, run_ca_pca_estimator):
        """Test output shape (N,) for (N, D) input."""
        data = torch.randn(100, 5)
        out = run_ca_pca_estimator(data, k=5)
        assert out["dimension"].shape == (100,)

    def test_known_dimension_2d(self, run_ca_pca_estimator, plane_data):
        """Test on 2D manifold embedded in higher space."""
        dims = run_ca_pca_estimator(plane_data, k=5)["dimension"]
        finite_mean = dims[torch.isfinite(dims)].mean().item()
        # Loose bracket around the true dimension 2: per-point estimates are noisy.
        assert 1.0 < finite_mean < 5.0

    def test_known_dimension_circle(self, run_ca_pca_estimator, circle_data):
        """Test on 1D manifold (circle) embedded in 2D."""
        dims = run_ca_pca_estimator(circle_data, k=5)["dimension"]
        finite_mean = dims[torch.isfinite(dims)].mean().item()
        assert 0.5 < finite_mean < 3.0

    @pytest.mark.parametrize(
        "shape",
        [(1, 10, 8), (5, 10, 8), (2, 3, 10, 4)],
        ids=["1x10x8", "5x10x8", "2x3x10x4"],
    )
    def test_batch_shape_preservation(self, run_ca_pca_estimator, shape):
        """Test that (..., N, D) preserves batch shape, output is (..., N)."""
        leading, n_points = shape[:-2], shape[-2]
        data = torch.randn(*shape)
        out = run_ca_pca_estimator(data, k=5, batch_size=leading)
        assert out["dimension"].shape == (*leading, n_points)

    def test_too_few_points_raises(self, run_ca_pca_estimator):
        """Test that N < k+2 raises."""
        with pytest.raises(ValueError, match="At least k\\+2 points"):
            run_ca_pca_estimator(torch.randn(6, 5), k=5)  # 6 < 5+2

    def test_k_validation(self):
        """Test that invalid k raises (k < 1 or wrong type)."""
        with pytest.raises(ValueError, match="k must be at least 1"):
            CaPcaDimensionEstimator(k=0)
        with pytest.raises(TypeError, match="k must be an int or 'auto'"):
            CaPcaDimensionEstimator(k=2.5)

    def test_determinism(self, run_ca_pca_estimator):
        """Test that same input yields same output."""
        torch.manual_seed(42)
        points = torch.randn(80, 6)
        first = run_ca_pca_estimator(points.clone(), k=5)["dimension"]
        second = run_ca_pca_estimator(points.clone(), k=5)["dimension"]
        # equal_nan=True: NaN entries must reproduce identically as well.
        assert torch.allclose(first, second, equal_nan=True)

    def test_k_auto(self, run_ca_pca_estimator, plane_data):
        """Test k='auto' uses n**0.5."""
        out = run_ca_pca_estimator(plane_data, k="auto")
        assert "dimension" in out
        assert out["dimension"].shape == (100,)

    def test_repr(self):
        """Test __repr__ includes class name, in_keys, out_keys, k, and eps."""
        rendered = repr(CaPcaDimensionEstimator(k=3))
        assert "CaPcaDimensionEstimator" in rendered
        assert "in_keys=['data']" in rendered
        assert "out_keys=['dimension']" in rendered
        assert "k=3" in rendered
        assert "eps=" in rendered

    def test_sklearn_missing_raises(self):
        """Test that ImportError is raised when sklearn is not installed."""
        with patch.object(importlib.util, "find_spec", return_value=None):
            estimator = CaPcaDimensionEstimator(k=5)
            td = TensorDict({"data": torch.randn(20, 5)}, batch_size=[])
            with pytest.raises(ImportError, match="scikit-learn"):
                estimator(td)

    def test_constant_data_handled(self, run_ca_pca_estimator):
        """Test constant data (all duplicates) returns nan without crashing."""
        out = run_ca_pca_estimator(torch.ones(10, 5), k=5)
        assert out["dimension"].shape == (10,)
        assert torch.isnan(out["dimension"]).all()

    def test_k1_uses_two_neighbors(self, run_ca_pca_estimator):
        """Test k=1 uses k+1=2 neighbors and returns valid dimension estimates."""
        out = run_ca_pca_estimator(torch.randn(10, 5), k=1)
        assert out["dimension"].shape == (10,)
        assert torch.isfinite(out["dimension"]).all()
        assert (out["dimension"] >= 1).all()
diff --git a/uv.lock b/uv.lock
index 154c8bc..7ee46df 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1361,6 +1361,23 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/d9/33/1f075bf72b0b747cb3288d011319aaf64083cf2efef8354174e3ed4540e2/ipython_pygments_lexers-1.1.1-py3-none-any.whl", hash = "sha256:a9462224a505ade19a605f71f8fa63c2048833ce50abc86768a0d81d876dc81c", size = 8074, upload-time = "2025-01-17T11:24:33.271Z" },
]
+[[package]]
+name = "ipywidgets"
+version = "8.1.8"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "comm" },
+ { name = "ipython", version = "8.37.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
+ { name = "ipython", version = "9.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
+ { name = "jupyterlab-widgets" },
+ { name = "traitlets" },
+ { name = "widgetsnbextension" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/4c/ae/c5ce1edc1afe042eadb445e95b0671b03cee61895264357956e61c0d2ac0/ipywidgets-8.1.8.tar.gz", hash = "sha256:61f969306b95f85fba6b6986b7fe45d73124d1d9e3023a8068710d47a22ea668", size = 116739, upload-time = "2025-11-01T21:18:12.393Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/56/6d/0d9848617b9f753b87f214f1c682592f7ca42de085f564352f10f0843026/ipywidgets-8.1.8-py3-none-any.whl", hash = "sha256:ecaca67aed704a338f88f67b1181b58f821ab5dc89c1f0f5ef99db43c1c2921e", size = 139808, upload-time = "2025-11-01T21:18:10.956Z" },
+]
+
[[package]]
name = "jaxtyping"
version = "0.3.2"
@@ -1472,6 +1489,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b1/dd/ead9d8ea85bf202d90cc513b533f9c363121c7792674f78e0d8a854b63b4/jupyterlab_pygments-0.3.0-py3-none-any.whl", hash = "sha256:841a89020971da1d8693f1a99997aefc5dc424bb1b251fd6322462a1b8842780", size = 15884, upload-time = "2023-11-23T09:26:34.325Z" },
]
+[[package]]
+name = "jupyterlab-widgets"
+version = "3.0.16"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/26/2d/ef58fed122b268c69c0aa099da20bc67657cdfb2e222688d5731bd5b971d/jupyterlab_widgets-3.0.16.tar.gz", hash = "sha256:423da05071d55cf27a9e602216d35a3a65a3e41cdf9c5d3b643b814ce38c19e0", size = 897423, upload-time = "2025-11-01T21:11:29.724Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/ab/b5/36c712098e6191d1b4e349304ef73a8d06aed77e56ceaac8c0a306c7bda1/jupyterlab_widgets-3.0.16-py3-none-any.whl", hash = "sha256:45fa36d9c6422cf2559198e4db481aa243c7a32d9926b500781c830c80f7ecf8", size = 914926, upload-time = "2025-11-01T21:11:28.008Z" },
+]
+
[[package]]
name = "kiwisolver"
version = "1.4.8"
@@ -4102,6 +4128,7 @@ dev = [
{ name = "zennit", version = "1.0.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11.11'" },
]
docs = [
+ { name = "ipywidgets" },
{ name = "nbsphinx" },
{ name = "pandoc" },
{ name = "plotly" },
@@ -4115,6 +4142,7 @@ docs = [
]
notebooks = [
{ name = "ipykernel" },
+ { name = "ipywidgets" },
]
scripts = [
{ name = "captum" },
@@ -4152,6 +4180,7 @@ dev = [
{ name = "zennit", specifier = ">=0.5.1" },
]
docs = [
+ { name = "ipywidgets", specifier = ">=8.0.0" },
{ name = "nbsphinx", specifier = ">=0.9.6" },
{ name = "pandoc", specifier = ">=2.4" },
{ name = "plotly", specifier = ">=5.24.1" },
@@ -4163,7 +4192,10 @@ docs = [
{ name = "sphinx-design", specifier = ">=0.6.1" },
{ name = "sphinxcontrib-bibtex", specifier = ">=2.6.5" },
]
-notebooks = [{ name = "ipykernel", specifier = ">=6.29.5" }]
+notebooks = [
+ { name = "ipykernel", specifier = ">=6.29.5" },
+ { name = "ipywidgets", specifier = ">=8.0.0" },
+]
scripts = [
{ name = "captum", specifier = ">=0.8.0" },
{ name = "datasets", specifier = ">=4.0.0" },
@@ -4692,6 +4724,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/5a/84/44687a29792a70e111c5c477230a72c4b957d88d16141199bf9acb7537a3/websocket_client-1.8.0-py3-none-any.whl", hash = "sha256:17b44cc997f5c498e809b22cdf2d9c7a9e71c02c8cc2b6c56e7c2d1239bfa526", size = 58826, upload-time = "2024-04-23T22:16:14.422Z" },
]
+[[package]]
+name = "widgetsnbextension"
+version = "4.0.15"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/bd/f4/c67440c7fb409a71b7404b7aefcd7569a9c0d6bd071299bf4198ae7a5d95/widgetsnbextension-4.0.15.tar.gz", hash = "sha256:de8610639996f1567952d763a5a41af8af37f2575a41f9852a38f947eb82a3b9", size = 1097402, upload-time = "2025-11-01T21:15:55.178Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/3f/0e/fa3b193432cfc60c93b42f3be03365f5f909d2b3ea410295cf36df739e31/widgetsnbextension-4.0.15-py3-none-any.whl", hash = "sha256:8156704e4346a571d9ce73b84bee86a29906c9abfd7223b7228a28899ccf3366", size = 2196503, upload-time = "2025-11-01T21:15:53.565Z" },
+]
+
[[package]]
name = "win32-setctime"
version = "1.2.0"