diff --git a/README.md b/README.md index b32153c..a3d0fa3 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,8 @@ To dig deeper, see the [documentation](https://tdhook.readthedocs.io). - [Integrated Gradients](https://tdhook.readthedocs.io/en/latest/notebooks/methods/integrated-gradients.html): [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Xmaster6y/tdhook/blob/main/docs/source/notebooks/methods/integrated-gradients.ipynb) - [Steering Vectors](https://tdhook.readthedocs.io/en/latest/notebooks/methods/steering-vectors.html): [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Xmaster6y/tdhook/blob/main/docs/source/notebooks/methods/steering-vectors.ipynb) +- [Linear Probing](https://tdhook.readthedocs.io/en/latest/notebooks/methods/linear-probing.html): [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Xmaster6y/tdhook/blob/main/docs/source/notebooks/methods/linear-probing.ipynb) +- [Dimension Estimation](https://tdhook.readthedocs.io/en/latest/notebooks/methods/dimension-estimation.html): [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Xmaster6y/tdhook/blob/main/docs/source/notebooks/methods/dimension-estimation.ipynb) ## Config diff --git a/docs/source/methods.rst b/docs/source/methods.rst index 3a813ad..75e6a7f 100644 --- a/docs/source/methods.rst +++ b/docs/source/methods.rst @@ -71,6 +71,23 @@ Methods + .. grid-item-card:: + :link: notebooks/methods/dimension-estimation.ipynb + :class-card: surface + :class-body: surface + + .. raw:: html + +
+
+ +
+
+
Dimension Estimation
+

Estimate intrinsic dimension of data manifolds using TwoNN, Local PCA, and related methods.

+
+
+ .. toctree:: :hidden: :maxdepth: 2 @@ -78,3 +95,4 @@ Methods notebooks/methods/integrated-gradients.ipynb notebooks/methods/steering-vectors.ipynb notebooks/methods/linear-probing.ipynb + notebooks/methods/dimension-estimation.ipynb diff --git a/docs/source/notebooks/methods/dimension-estimation.ipynb b/docs/source/notebooks/methods/dimension-estimation.ipynb new file mode 100644 index 0000000..419ce88 --- /dev/null +++ b/docs/source/notebooks/methods/dimension-estimation.ipynb @@ -0,0 +1,583 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Intrinsic Dimension Estimation\n", + "\n", + "This notebook demonstrates how to use the intrinsic dimension estimators from tdhook on synthetic data and MNIST." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import importlib.util\n", + "\n", + "DEV = True\n", + "\n", + "if importlib.util.find_spec(\"google.colab\") is not None:\n", + " MODE = \"colab-dev\" if DEV else \"colab\"\n", + "else:\n", + " MODE = \"local\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if MODE == \"colab\":\n", + " %pip install -q tdhook scikit-learn\n", + "elif MODE == \"colab-dev\":\n", + " !rm -rf tdhook\n", + " !git clone https://github.com/Xmaster6y/tdhook -b main\n", + " %pip install -q ./tdhook" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "import torch\n", + "from tensordict import TensorDict\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "from tdhook.latent.dimension_estimation import (\n", + " TwoNnDimensionEstimator,\n", + " LocalKnnDimensionEstimator,\n", + " 
LocalPcaDimensionEstimator,\n", + " CaPcaDimensionEstimator,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part 1: Synthetic Data (Simple Case)\n", + "\n", + "We start with simple synthetic manifolds where we know the true intrinsic dimension." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1.1 Generate and visualize synthetic data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "torch.manual_seed(42)\n", + "\n", + "# 2D plane embedded in 10D (last 8 dims zero) - intrinsic dim = 2\n", + "plane_data = torch.randn(200, 10)\n", + "plane_data[:, 2:] = 0\n", + "\n", + "# 1D circle embedded in 2D - intrinsic dim = 1\n", + "theta = torch.rand(200) * 2 * torch.pi\n", + "circle_data = torch.stack([torch.cos(theta), torch.sin(theta)], dim=1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 2, figsize=(10, 4))\n", + "\n", + "axes[0].scatter(plane_data[:, 0].numpy(), plane_data[:, 1].numpy(), alpha=0.6)\n", + "axes[0].set_title(\"2D plane in 10D (first 2 dims)\")\n", + "axes[0].set_xlabel(\"$x_1$\")\n", + "axes[0].set_ylabel(\"$x_2$\")\n", + "axes[0].set_aspect(\"equal\")\n", + "\n", + "axes[1].scatter(circle_data[:, 0].numpy(), circle_data[:, 1].numpy(), alpha=0.6)\n", + "axes[1].set_title(\"1D circle in 2D\")\n", + "axes[1].set_xlabel(\"$x_1$\")\n", + "axes[1].set_ylabel(\"$x_2$\")\n", + "axes[1].set_aspect(\"equal\")\n", + "\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1.2 Run all estimators" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def run_estimators(data, k=\"auto\"):\n", + " td = TensorDict({\"data\": data}, batch_size=[])\n", + " results = {}\n", + " timings = {}\n", + "\n", + " # TwoNN: 
single scalar per dataset\n", + " t0 = time.perf_counter()\n", + " twonn = TwoNnDimensionEstimator(return_xy=True)\n", + " td_twonn = twonn(td.clone())\n", + " timings[\"TwoNN\"] = time.perf_counter() - t0\n", + " results[\"TwoNN\"] = td_twonn[\"dimension\"].item()\n", + " results[\"TwoNN_xy\"] = (td_twonn[\"dimension_x\"], td_twonn[\"dimension_y\"])\n", + "\n", + " # Per-point estimators\n", + " for name, est in [\n", + " (\"LocalKnn\", LocalKnnDimensionEstimator(k=k)),\n", + " (\"LocalPCA\", LocalPcaDimensionEstimator(k=k)),\n", + " (\"CaPca\", CaPcaDimensionEstimator(k=k)),\n", + " ]:\n", + " t0 = time.perf_counter()\n", + " td_est = est(td.clone())\n", + " timings[name] = time.perf_counter() - t0\n", + " d = td_est[\"dimension\"]\n", + " valid = torch.isfinite(d)\n", + " results[name] = d[valid].mean().item() if valid.any() else float(\"nan\")\n", + " results[f\"{name}_per_point\"] = d\n", + "\n", + " results[\"timings\"] = timings\n", + " print(\"Timings:\", \" | \".join(f\"{k}: {v:.3f}s\" for k, v in timings.items()))\n", + " return results" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`k=5`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plane_results = run_estimators(plane_data, k=5)\n", + "circle_results = run_estimators(circle_data, k=5)\n", + "\n", + "print(\"2D plane (expected ~2):\")\n", + "for name in [\"TwoNN\", \"LocalKnn\", \"LocalPCA\", \"CaPca\"]:\n", + " print(f\" {name}: {plane_results[name]:.2f}\")\n", + "\n", + "print(\"\\n1D circle (expected ~1):\")\n", + "for name in [\"TwoNN\", \"LocalKnn\", \"LocalPCA\", \"CaPca\"]:\n", + " print(f\" {name}: {circle_results[name]:.2f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`k=10`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plane_results = run_estimators(plane_data, k=10)\n", + "circle_results = 
run_estimators(circle_data, k=10)\n", + "\n", + "print(\"2D plane (expected ~2):\")\n", + "for name in [\"TwoNN\", \"LocalKnn\", \"LocalPCA\", \"CaPca\"]:\n", + " print(f\" {name}: {plane_results[name]:.2f}\")\n", + "\n", + "print(\"\\n1D circle (expected ~1):\")\n", + "for name in [\"TwoNN\", \"LocalKnn\", \"LocalPCA\", \"CaPca\"]:\n", + " print(f\" {name}: {circle_results[name]:.2f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`k = \"auto\"`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plane_results = run_estimators(plane_data)\n", + "circle_results = run_estimators(circle_data)\n", + "\n", + "print(\"2D plane (expected ~2):\")\n", + "for name in [\"TwoNN\", \"LocalKnn\", \"LocalPCA\", \"CaPca\"]:\n", + " print(f\" {name}: {plane_results[name]:.2f}\")\n", + "\n", + "print(\"\\n1D circle (expected ~1):\")\n", + "for name in [\"TwoNN\", \"LocalKnn\", \"LocalPCA\", \"CaPca\"]:\n", + " print(f\" {name}: {circle_results[name]:.2f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1.3 TwoNN visualization\n", + "\n", + "TwoNN estimates dimension from the linear relationship $y = d \\cdot x$ where $x = \\log(\\mu)$ and $y = -\\log(1-F)$." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "x_plane = plane_results[\"TwoNN_xy\"][0].numpy()\n", + "y_plane = plane_results[\"TwoNN_xy\"][1].numpy()\n", + "d_plane = plane_results[\"TwoNN\"]\n", + "valid = np.isfinite(x_plane) & np.isfinite(y_plane)\n", + "\n", + "fig, ax = plt.subplots(figsize=(6, 4))\n", + "ax.scatter(x_plane[valid], y_plane[valid], alpha=0.6, label=\"data\")\n", + "x_line = np.linspace(x_plane[valid].min(), x_plane[valid].max(), 50)\n", + "ax.plot(x_line, d_plane * x_line, \"r-\", lw=2, label=f\"y = {d_plane:.2f} * x\")\n", + "ax.set_xlabel(\"$x = \\\\log(\\\\mu)$\")\n", + "ax.set_ylabel(\"$y = -\\\\log(1-F)$\")\n", + "ax.set_title(\"TwoNN: 2D plane (slope = estimated dimension)\")\n", + "ax.legend()\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1.4 Per-point estimator histograms\n", + "\n", + "Per-point estimators (LocalKnn, LocalPCA, CaPca) give a dimension estimate at each data point. The histograms show the distribution of these local estimates, revealing variation across the manifold."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 3, figsize=(12, 4))\n", + "\n", + "for ax, name in zip(axes, [\"LocalKnn\", \"LocalPCA\", \"CaPca\"]):\n", + " d = plane_results[f\"{name}_per_point\"].numpy()\n", + " valid = np.isfinite(d)\n", + " ax.hist(d[valid], bins=20, edgecolor=\"black\", alpha=0.7)\n", + " ax.axvline(plane_results[name], color=\"red\", linestyle=\"--\", label=f\"mean = {plane_results[name]:.2f}\")\n", + " ax.set_xlabel(\"Estimated dimension\")\n", + " ax.set_title(name)\n", + " ax.legend()\n", + "\n", + "plt.suptitle(\"Per-point dimension estimates on 2D plane\")\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part 2: MNIST\n", + "\n", + "We now run the estimators on MNIST digits. The intrinsic dimension of the MNIST manifold is typically around 10-15." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.1 Load MNIST" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "\n", + "ds = load_dataset(\"mnist\", split=\"train\")\n", + "arr = np.array(ds[\"image\"][:1000])\n", + "mnist_data = torch.tensor(arr, dtype=torch.float32).reshape(1000, -1) / 255.0\n", + "\n", + "\n", + "print(f\"MNIST shape: {mnist_data.shape}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.2 Sample digits visualization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(3, 3, figsize=(6, 6))\n", + "for i, ax in enumerate(axes.flat):\n", + " img = mnist_data[i].reshape(28, 28).numpy()\n", + " ax.imshow(img, cmap=\"gray\")\n", + " ax.axis(\"off\")\n", + "plt.suptitle(\"Sample MNIST digits\")\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + 
"cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.3 Run estimators on MNIST" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mnist_results = run_estimators(mnist_data, k=\"auto\")\n", + "\n", + "print(\"MNIST (expected ~10-15):\")\n", + "for name in [\"TwoNN\", \"LocalKnn\", \"LocalPCA\", \"CaPca\"]:\n", + " print(f\" {name}: {mnist_results[name]:.2f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.4 TwoNN curve fitting on MNIST" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "x_mnist = mnist_results[\"TwoNN_xy\"][0].numpy()\n", + "y_mnist = mnist_results[\"TwoNN_xy\"][1].numpy()\n", + "d_mnist = mnist_results[\"TwoNN\"]\n", + "valid = np.isfinite(x_mnist) & np.isfinite(y_mnist)\n", + "\n", + "fig, ax = plt.subplots(figsize=(6, 4))\n", + "ax.scatter(x_mnist[valid], y_mnist[valid], alpha=0.6, label=\"data\")\n", + "x_line = np.linspace(x_mnist[valid].min(), x_mnist[valid].max(), 50)\n", + "ax.plot(x_line, d_mnist * x_line, \"r-\", lw=2, label=f\"y = {d_mnist:.2f} * x\")\n", + "ax.set_xlabel(\"$x = \\\\log(\\\\mu)$\")\n", + "ax.set_ylabel(\"$y = -\\\\log(1-F)$\")\n", + "ax.set_title(\"TwoNN: MNIST (slope = estimated dimension)\")\n", + "ax.legend()\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.5 Box plot: per-point estimators (local variations)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "per_point_methods = [\"LocalKnn\", \"LocalPCA\", \"CaPca\"]\n", + "\n", + "fig, axes = plt.subplots(1, 3, figsize=(12, 4))\n", + "for ax, name in zip(axes, per_point_methods):\n", + " d = mnist_results[f\"{name}_per_point\"].numpy()\n", + " d = d[np.isfinite(d)]\n", + " bp = ax.boxplot([d], tick_labels=[name], patch_artist=True)\n", + " 
ax.axhline(\n", + " mnist_results[\"TwoNN\"], color=\"gray\", linestyle=\"--\", alpha=0.7, label=f\"TwoNN = {mnist_results['TwoNN']:.1f}\"\n", + " )\n", + " ax.set_ylabel(\"Estimated dimension (per point)\")\n", + " ax.legend()\n", + "\n", + "plt.suptitle(\"Per-point dimension estimates on MNIST (1000 images)\")\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.6 Using 5000 images" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "arr = np.array(ds[\"image\"][:5000])\n", + "mnist_data_5k = torch.tensor(arr, dtype=torch.float32).reshape(5000, -1) / 255.0\n", + "\n", + "mnist_results_5k = run_estimators(mnist_data_5k, k=\"auto\")\n", + "\n", + "print(\"MNIST (5000 images, expected ~10-15):\")\n", + "for name in [\"TwoNN\", \"LocalKnn\", \"LocalPCA\", \"CaPca\"]:\n", + " print(f\" {name}: {mnist_results_5k[name]:.2f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Comparison (1000 vs 5000 images):\")\n", + "for name in [\"TwoNN\", \"LocalKnn\", \"LocalPCA\", \"CaPca\"]:\n", + " print(f\" {name}: {mnist_results[name]:.2f} (1k) → {mnist_results_5k[name]:.2f} (5k)\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "x_mnist_5k = mnist_results_5k[\"TwoNN_xy\"][0].numpy()\n", + "y_mnist_5k = mnist_results_5k[\"TwoNN_xy\"][1].numpy()\n", + "d_mnist_5k = mnist_results_5k[\"TwoNN\"]\n", + "valid = np.isfinite(x_mnist_5k) & np.isfinite(y_mnist_5k)\n", + "\n", + "fig, ax = plt.subplots(figsize=(6, 4))\n", + "ax.scatter(x_mnist_5k[valid], y_mnist_5k[valid], alpha=0.4, label=\"data\")\n", + "x_line = np.linspace(x_mnist_5k[valid].min(), x_mnist_5k[valid].max(), 50)\n", + "ax.plot(x_line, d_mnist_5k * x_line, \"r-\", lw=2, label=f\"y = {d_mnist_5k:.2f} * x\")\n", + "ax.set_xlabel(\"$x = 
\\\\log(\\\\mu)$\")\n", + "ax.set_ylabel(\"$y = -\\\\log(1-F)$\")\n", + "ax.set_title(\"TwoNN: MNIST 5k images (slope = estimated dimension)\")\n", + "ax.legend()\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 3, figsize=(12, 4))\n", + "for ax, name in zip(axes, per_point_methods):\n", + " d = mnist_results_5k[f\"{name}_per_point\"].numpy()\n", + " d = d[np.isfinite(d)]\n", + " bp = ax.boxplot([d], tick_labels=[name], patch_artist=True)\n", + " ax.axhline(\n", + " mnist_results_5k[\"TwoNN\"],\n", + " color=\"gray\",\n", + " linestyle=\"--\",\n", + " alpha=0.7,\n", + " label=f\"TwoNN = {mnist_results_5k['TwoNN']:.1f}\",\n", + " )\n", + " ax.set_ylabel(\"Estimated dimension (per point)\")\n", + " ax.legend()\n", + "\n", + "plt.suptitle(\"Per-point dimension estimates on MNIST (5000 images)\")\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.7 Per-label dimension estimation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "labels = np.array(ds[\"label\"][:5000])\n", + "per_label = {}\n", + "for d in range(10):\n", + " subset = mnist_data_5k[labels == d]\n", + " per_label[d] = run_estimators(subset, k=\"auto\")\n", + "\n", + "print(\"Per-label dimension estimates (5000 images):\")\n", + "for d in range(10):\n", + " r = per_label[d]\n", + " print(\n", + " f\" {d}: TwoNN={r['TwoNN']:.2f} LocalKnn={r['LocalKnn']:.2f} LocalPCA={r['LocalPCA']:.2f} CaPca={r['CaPca']:.2f} (n={(labels == d).sum()})\"\n", + " )" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + 
"nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/source/references.bib b/docs/source/references.bib index 344630a..445905f 100644 --- a/docs/source/references.bib +++ b/docs/source/references.bib @@ -235,7 +235,7 @@ @inproceedings{farahmand2007manifold pages = {265--272}, url = {https://doi.org/10.1145/1273496.1273530}, doi = {10.1145/1273496.1273530}, - location = {Corvalis, Oregon, USA}, + location = {Corvallis, Oregon, USA}, series = {ICML '07}, } @@ -252,3 +252,32 @@ @article{facco_estimating_2017 year = {2017}, pages = {12140}, } + +@article{fukunaga1971algorithm, + title = {An algorithm for finding intrinsic dimensionality of data}, + author = {Fukunaga, Keinosuke and Olsen, D. R.}, + journal = {IEEE Transactions on Computers}, + volume = {C-20}, + number = {2}, + pages = {176--183}, + year = {1971}, + doi = {10.1109/T-C.1971.223208}, +} + +@article{bruske1998intrinsic, + title = {Intrinsic dimensionality estimation with optimally topology preserving maps}, + author = {Bruske, Jochen and Sommer, Gerald}, + journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence}, + volume = {20}, + number = {5}, + pages = {572--575}, + year = {1998}, + doi = {10.1109/34.682189}, +} + +@article{gilbert2023capca, + title = {CA-PCA: Manifold Dimension Estimation, Adapted for Curvature}, + author = {Gilbert, Anna C. 
and O'Neill, Kevin}, + journal = {arXiv preprint arXiv:2309.13478}, + year = {2023}, +} diff --git a/pyproject.toml b/pyproject.toml index 688283a..898a6e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dev = [ "scikit-learn>=1.7.1", ] docs = [ + "ipywidgets>=8.0.0", "nbsphinx>=0.9.6", "pandoc>=2.4", "plotly>=5.24.1", @@ -70,6 +71,7 @@ scripts = [ ] notebooks = [ "ipykernel>=6.29.5", + "ipywidgets>=8.0.0", ] diff --git a/src/tdhook/_optional_deps.py b/src/tdhook/_optional_deps.py new file mode 100644 index 0000000..f7ed2b0 --- /dev/null +++ b/src/tdhook/_optional_deps.py @@ -0,0 +1,19 @@ +import functools +import importlib.util + + +def _ensure_sklearn() -> None: + """Raise ImportError if scikit-learn is not installed.""" + if importlib.util.find_spec("sklearn") is None: + raise ImportError("scikit-learn is required for this feature. Install with: pip install scikit-learn") + + +def requires_sklearn(func): + """Decorator: raise ImportError if sklearn is missing when the decorated function is called.""" + + @functools.wraps(func) + def wrapper(*args, **kwargs): + _ensure_sklearn() + return func(*args, **kwargs) + + return wrapper diff --git a/src/tdhook/latent/dimension_estimation/__init__.py b/src/tdhook/latent/dimension_estimation/__init__.py index eecb55c..29e58fc 100644 --- a/src/tdhook/latent/dimension_estimation/__init__.py +++ b/src/tdhook/latent/dimension_estimation/__init__.py @@ -2,10 +2,14 @@ Intrinsic dimension estimation methods. 
""" +from .ca_pca import CaPcaDimensionEstimator from .local_knn import LocalKnnDimensionEstimator +from .local_pca import LocalPcaDimensionEstimator from .twonn import TwoNnDimensionEstimator __all__ = [ + "CaPcaDimensionEstimator", "LocalKnnDimensionEstimator", + "LocalPcaDimensionEstimator", "TwoNnDimensionEstimator", ] diff --git a/src/tdhook/latent/dimension_estimation/_utils.py b/src/tdhook/latent/dimension_estimation/_utils.py index a24d9dc..b8ecf59 100644 --- a/src/tdhook/latent/dimension_estimation/_utils.py +++ b/src/tdhook/latent/dimension_estimation/_utils.py @@ -3,14 +3,15 @@ import torch -def sorted_neighbor_distances(data: torch.Tensor, eps: float) -> torch.Tensor: - """Compute sorted distances to neighbors for each point. +def sorted_neighbors(data: torch.Tensor, eps: float) -> tuple[torch.Tensor, torch.Tensor]: + """Compute sorted distances and indices of neighbors for each point. - Returns (N, N) where each row is sorted ascending. Self and distances <= eps are inf. + Returns (sorted_dist, indices) each of shape (N, N). Each row is sorted ascending. + Self and distances <= eps are inf, with their indices appearing last in each row. 
""" dist = torch.cdist(data, data, p=2) dist = dist.clone() dist.fill_diagonal_(float("inf")) dist = torch.where(dist > eps, dist, float("inf")) - sorted_dist, _ = torch.sort(dist, dim=1) - return sorted_dist + sorted_dist, indices = torch.sort(dist, dim=1) + return sorted_dist, indices diff --git a/src/tdhook/latent/dimension_estimation/ca_pca.py b/src/tdhook/latent/dimension_estimation/ca_pca.py new file mode 100644 index 0000000..c96c095 --- /dev/null +++ b/src/tdhook/latent/dimension_estimation/ca_pca.py @@ -0,0 +1,125 @@ +from textwrap import indent +from typing import Literal, Union + +import numpy as np +import torch +from tensordict import TensorDict +from tensordict.nn import TensorDictModuleBase + +from tdhook._optional_deps import requires_sklearn + +from ._utils import sorted_neighbors +from .local_knn import _resolve_k + + +class CaPcaDimensionEstimator(TensorDictModuleBase): + """ + Curvature-adjusted intrinsic dimension estimation via local PCA :cite:`gilbert2023capca`. + + Extends local PCA by calibrating to a quadratic embedding instead of a flat unit ball, + accounting for manifold curvature. For each point, uses its k+1 nearest neighbors, + forms the local covariance, and selects dimension by comparing curvature-corrected + eigenvalues to the expected spectrum of a d-dimensional ball. + + Reads a data tensor from the input TensorDict. Expects (N, D) or (..., N, D). + Outputs per-point dimension estimates of shape (..., N). 
+ """ + + def __init__( + self, + k: Union[int, Literal["auto"]] = "auto", + in_key: str = "data", + out_key: str = "dimension", + eps: float = 1e-5, + ): + super().__init__() + if k != "auto": + if not isinstance(k, int): + raise TypeError("k must be an int or 'auto'") + if k < 1: + raise ValueError("k must be at least 1") + self.k = k + self.in_key = in_key + self.out_key = out_key + self.eps = eps + self.in_keys = [in_key] + self.out_keys = [out_key] + + @requires_sklearn + def forward(self, td: TensorDict) -> TensorDict: + from sklearn.decomposition import PCA + + data = td[self.in_key] + N = data.shape[-2] + k = _resolve_k(self.k, N) + if N < k + 2: + raise ValueError(f"At least k+2 points required for CA-PCA (k={k}), got {N}") + batch_shape = data.shape[:-2] + flat = data.reshape(-1, data.shape[-2], data.shape[-1]) + device = data.device + dtype = data.dtype + dims = [] + for i in range(flat.shape[0]): + d_i = _ca_pca(flat[i], k=k, eps=self.eps, pca_cls=PCA) + dims.append(d_i) + td[self.out_key] = torch.stack(dims).reshape(*batch_shape, N).to(device=device, dtype=dtype) + return td + + def __repr__(self): + fields = indent( + f"in_keys={self.in_keys},\nout_keys={self.out_keys},\nk={self.k},\neps={self.eps}", + 4 * " ", + ) + return f"{type(self).__name__}(\n{fields})" + + +def _ca_pca(data: torch.Tensor, k: int, eps: float, pca_cls: type) -> torch.Tensor: + """Compute per-point dimension via CA-PCA. data: (N, D). 
Returns (N,) dimension estimates.""" + sorted_dist, indices = sorted_neighbors(data, eps) + N, D = data.shape + dims = [] + for i in range(N): + dist_k = sorted_dist[i, k - 1] + dist_kp1 = sorted_dist[i, k] + r = (dist_k + dist_kp1) / 2.0 + if r <= 0 or not np.isfinite(float(r)): + dims.append(float("nan")) + continue + valid_mask = torch.isfinite(sorted_dist[i]) + valid_indices = indices[i][valid_mask] + neighbor_idx = valid_indices[: k + 1] + if len(neighbor_idx) < k + 1: + dims.append(float("nan")) + continue + neighborhood = data[neighbor_idx].detach().cpu().double().numpy() + if neighborhood.shape[0] < 2: + dims.append(float("nan")) + continue + pca = pca_cls(n_components=None).fit(neighborhood) + eigvals = pca.explained_variance_ + lambda_hat = np.zeros(D, dtype=np.float64) + n_eig = min(len(eigvals), D) + lambda_hat[:n_eig] = eigvals[:n_eig] / (r**2) + d_est = _dim_from_ca_pca(lambda_hat, D) + dims.append(float(d_est)) + return torch.tensor(dims, device=data.device, dtype=torch.float32) + + +def _dim_from_ca_pca(lambda_hat: np.ndarray, D: int) -> int: + """Select dimension via curvature-corrected eigenvalue matching.""" + best_d = 1 + best_score = np.inf + for d in range(1, D + 1): + tail_sum = lambda_hat[d:].sum() + coef = (3 * d + 4) / (d * (d + 4)) if d > 0 else 0.0 + lambda_d = np.zeros(D) + lambda_d[:d] = lambda_hat[:d] + coef * tail_sum + lambda_d[d:] = 0.0 + target = np.zeros(D) + target[:d] = 1.0 / (d + 2) + target[d:] = 0.0 + score = float(np.linalg.norm(target - lambda_d)) + 2.0 * tail_sum + if score < best_score: + best_score = score + best_d = d + return best_d diff --git a/src/tdhook/latent/dimension_estimation/local_knn.py b/src/tdhook/latent/dimension_estimation/local_knn.py index 8352c17..1bdf3f6 100644 --- a/src/tdhook/latent/dimension_estimation/local_knn.py +++ b/src/tdhook/latent/dimension_estimation/local_knn.py @@ -5,7 +5,7 @@ from tensordict import TensorDict from tensordict.nn import TensorDictModuleBase -from ._utils import 
sorted_neighbor_distances +from ._utils import sorted_neighbors def _resolve_k(k: Union[int, Literal["auto"]], n: int) -> int: @@ -71,7 +71,7 @@ def __repr__(self): def _local_knn(data: torch.Tensor, k: int, eps: float) -> torch.Tensor: """Compute per-point local dimension. data: (N, D). Returns (N,) dimension estimates.""" - sorted_dist = sorted_neighbor_distances(data, eps) + sorted_dist, _ = sorted_neighbors(data, eps) rk = sorted_dist[:, k - 1] r2k = sorted_dist[:, 2 * k - 1] diff --git a/src/tdhook/latent/dimension_estimation/local_pca.py b/src/tdhook/latent/dimension_estimation/local_pca.py new file mode 100644 index 0000000..6e9e201 --- /dev/null +++ b/src/tdhook/latent/dimension_estimation/local_pca.py @@ -0,0 +1,147 @@ +"""Local PCA dimension estimation via eigenvalues of local covariance.""" + +from textwrap import indent +from typing import Literal, Union + +import numpy as np +import torch +from tensordict import TensorDict +from tensordict.nn import TensorDictModuleBase + +from tdhook._optional_deps import requires_sklearn + +from ._utils import sorted_neighbors +from .local_knn import _resolve_k + + +class LocalPcaDimensionEstimator(TensorDictModuleBase): + """ + Local intrinsic dimension estimation via PCA on k-NN neighborhoods :cite:`fukunaga1971algorithm`. + + For each point, extracts its k+1 nearest neighbors (self + k neighbors), fits PCA, + and estimates dimension from eigenvalues using a configurable criterion (maxgap or ratio). + + Reads a data tensor from the input TensorDict. Expects (N, D) or (..., N, D). + Outputs per-point dimension estimates of shape (..., N). 
+ """ + + def __init__( + self, + k: Union[int, Literal["auto"]] = "auto", + criterion: Literal["maxgap", "ratio"] = "maxgap", + alpha: float = 0.05, + in_key: str = "data", + out_key: str = "dimension", + eps: float = 1e-5, + ): + super().__init__() + if k != "auto": + if not isinstance(k, int): + raise TypeError("k must be an int or 'auto'") + if k < 1: + raise ValueError("k must be at least 1") + self.k = k + self.criterion = criterion + self.alpha = alpha + self.in_key = in_key + self.out_key = out_key + self.eps = eps + self.in_keys = [in_key] + self.out_keys = [out_key] + + @requires_sklearn + def forward(self, td: TensorDict) -> TensorDict: + from sklearn.decomposition import PCA + + data = td[self.in_key] + N = data.shape[-2] + k = _resolve_k(self.k, N) + if N < k + 1: + raise ValueError(f"At least k+1 points required for local PCA (k={k}), got {N}") + batch_shape = data.shape[:-2] + flat = data.reshape(-1, data.shape[-2], data.shape[-1]) + device = data.device + dtype = data.dtype + dims = [] + for i in range(flat.shape[0]): + d_i = _local_pca( + flat[i], + k=k, + eps=self.eps, + criterion=self.criterion, + alpha=self.alpha, + pca_cls=PCA, + ) + dims.append(d_i) + td[self.out_key] = torch.stack(dims).reshape(*batch_shape, N).to(device=device, dtype=dtype) + return td + + def __repr__(self): + fields = indent( + f"in_keys={self.in_keys},\nout_keys={self.out_keys},\nk={self.k},\n" + f"criterion={self.criterion!r},\nalpha={self.alpha},\neps={self.eps}", + 4 * " ", + ) + return f"{type(self).__name__}(\n{fields})" + + +def _local_pca( + data: torch.Tensor, + k: int, + eps: float, + criterion: Literal["maxgap", "ratio"], + alpha: float, + pca_cls: type, +) -> torch.Tensor: + """Compute per-point local dimension via PCA. data: (N, D). 
Returns (N,) dimension estimates.""" + sorted_dist, indices = sorted_neighbors(data, eps) + N, D = data.shape + dims = [] + for i in range(N): + valid_mask = torch.isfinite(sorted_dist[i]) + valid_indices = indices[i][valid_mask] + neighbor_idx = valid_indices[:k] + if len(neighbor_idx) < k: + dims.append(float("nan")) + continue + neighborhood = torch.cat([data[i : i + 1], data[neighbor_idx]], dim=0) + X = neighborhood.detach().cpu().double().numpy() + if X.shape[0] < 2: + dims.append(float("nan")) + continue + pca = pca_cls(n_components=None).fit(X) + lambda_ = pca.explained_variance_ + if len(lambda_) == 0: + dims.append(1.0) + continue + if criterion == "maxgap": + d = float(_dim_from_eigenvalues_maxgap(lambda_)) + elif criterion == "ratio": + d = float(_dim_from_eigenvalues_ratio(lambda_, alpha)) + else: + raise ValueError(f"Unknown criterion: {criterion!r}") + dims.append(d) + return torch.tensor(dims, device=data.device, dtype=torch.float32) + + +def _dim_from_eigenvalues_maxgap(lambda_: np.ndarray) -> int: + """Estimate dimension from eigenvalues using the maximum gap criterion :cite:`bruske1998intrinsic`. + + de = argmax(lambda[i]/lambda[i+1]) + 1 (1-based dimension). + """ + if len(lambda_) < 2: + return 1 + gaps = lambda_[:-1] / (lambda_[1:] + 1e-15) + return int(np.argmax(gaps) + 1) + + +def _dim_from_eigenvalues_ratio(lambda_: np.ndarray, alpha: float) -> int: + """Estimate dimension using ratio criterion :cite:`fukunaga1971algorithm`. + + Count eigenvalues above alpha * lambda[0]. Clamped to at least 1. 
+ """ + if len(lambda_) == 0: + return 1 + threshold = alpha * lambda_[0] + de = int(np.sum(lambda_ > threshold)) + return max(1, de) diff --git a/src/tdhook/latent/dimension_estimation/twonn.py b/src/tdhook/latent/dimension_estimation/twonn.py index 78c0385..290b71d 100644 --- a/src/tdhook/latent/dimension_estimation/twonn.py +++ b/src/tdhook/latent/dimension_estimation/twonn.py @@ -4,7 +4,7 @@ from tensordict import TensorDict from tensordict.nn import TensorDictModuleBase -from ._utils import sorted_neighbor_distances +from ._utils import sorted_neighbors class TwoNnDimensionEstimator(TensorDictModuleBase): @@ -71,7 +71,7 @@ def _twonn(data: torch.Tensor, eps: float) -> tuple[torch.Tensor, torch.Tensor, Distances <= eps are treated as duplicates (excluded from nearest-neighbor selection). """ - sorted_dist = sorted_neighbor_distances(data, eps) + sorted_dist, _ = sorted_neighbors(data, eps) r1 = sorted_dist[:, 0] r2 = sorted_dist[:, 1] diff --git a/tests/latent/test_dimension_estimation.py b/tests/latent/test_dimension_estimation.py index f62d890..cf1bf9a 100644 --- a/tests/latent/test_dimension_estimation.py +++ b/tests/latent/test_dimension_estimation.py @@ -2,12 +2,26 @@ Tests for intrinsic dimension estimation. 
""" +import importlib.util +from unittest.mock import patch + +import numpy as np import pytest import torch from sklearn.metrics import r2_score from tensordict import TensorDict -from tdhook.latent.dimension_estimation import LocalKnnDimensionEstimator, TwoNnDimensionEstimator +from tdhook.latent.dimension_estimation import ( + CaPcaDimensionEstimator, + LocalKnnDimensionEstimator, + LocalPcaDimensionEstimator, + TwoNnDimensionEstimator, +) +from tdhook.latent.dimension_estimation.local_pca import ( + _dim_from_eigenvalues_maxgap, + _dim_from_eigenvalues_ratio, + _local_pca, +) @pytest.fixture @@ -240,3 +254,313 @@ def test_repr(self): assert "out_keys=['dimension']" in r assert "k=3" in r assert "eps=" in r + + +@pytest.fixture +def run_local_pca_estimator(): + torch.manual_seed(42) + + def _run(data, k=5, in_key="data", batch_size=None, **estimator_kwargs): + if batch_size is None: + batch_size = [] if data.ndim == 2 else data.shape[:-2] + td = TensorDict({in_key: data}, batch_size=batch_size) + return LocalPcaDimensionEstimator(k=k, in_key=in_key, **estimator_kwargs)(td) + + return _run + + +class TestLocalPcaDimensionEstimator: + """Test the LocalPcaDimensionEstimator class.""" + + def test_default_keys(self, run_local_pca_estimator): + """Test with default in_key and out_key.""" + data = torch.randn(50, 10) + result = run_local_pca_estimator(data, k=5) + assert "dimension" in result + assert result["dimension"].shape == (50,) + assert result["dimension"].dtype in (torch.float32, torch.float64) + valid = torch.isfinite(result["dimension"]) + assert valid.sum() > 0 + assert (result["dimension"][valid] >= 1).all() + + def test_custom_keys(self, run_local_pca_estimator): + """Test with custom in_key and out_key.""" + data = torch.randn(50, 8) + result = run_local_pca_estimator(data, k=5, in_key="linear2", out_key="intrinsic_dim") + assert "intrinsic_dim" in result + assert "linear2" in result + assert result["intrinsic_dim"].shape == (50,) + + def 
test_output_shape(self, run_local_pca_estimator): + """Test output shape (N,) for (N, D) input.""" + data = torch.randn(100, 5) + result = run_local_pca_estimator(data, k=5) + assert result["dimension"].shape == (100,) + + def test_known_dimension_2d_maxgap(self, run_local_pca_estimator, plane_data): + """Test on 2D manifold with maxgap criterion.""" + result = run_local_pca_estimator(plane_data, k=5, criterion="maxgap") + d = result["dimension"] + valid = torch.isfinite(d) + mean_d = d[valid].mean().item() + assert 1.0 < mean_d < 5.0 + + def test_known_dimension_2d_ratio(self, run_local_pca_estimator, plane_data): + """Test on 2D manifold with ratio criterion.""" + result = run_local_pca_estimator(plane_data, k=5, criterion="ratio") + d = result["dimension"] + valid = torch.isfinite(d) + mean_d = d[valid].mean().item() + assert 1.0 < mean_d < 5.0 + + def test_known_dimension_circle(self, run_local_pca_estimator, circle_data): + """Test on 1D manifold (circle) embedded in 2D.""" + result = run_local_pca_estimator(circle_data, k=5) + d = result["dimension"] + valid = torch.isfinite(d) + mean_d = d[valid].mean().item() + assert 0.5 < mean_d < 3.0 + + @pytest.mark.parametrize( + "shape", + [(1, 10, 8), (5, 10, 8), (2, 3, 10, 4)], + ids=["1x10x8", "5x10x8", "2x3x10x4"], + ) + def test_batch_shape_preservation(self, run_local_pca_estimator, shape): + """Test that (..., N, D) preserves batch shape, output is (..., N).""" + data = torch.randn(*shape) + batch_size = shape[:-2] + N = shape[-2] + result = run_local_pca_estimator(data, k=5, batch_size=batch_size) + assert result["dimension"].shape == (*batch_size, N) + + def test_too_few_points_raises(self, run_local_pca_estimator): + """Test that N < k+1 raises.""" + with pytest.raises(ValueError, match="At least k\\+1 points"): + run_local_pca_estimator(torch.randn(5, 5), k=5) # 5 < 5+1 + + def test_k_validation(self): + """Test that invalid k raises (k < 1 or wrong type).""" + with pytest.raises(ValueError, match="k must 
be at least 1"): + LocalPcaDimensionEstimator(k=0) + with pytest.raises(TypeError, match="k must be an int or 'auto'"): + LocalPcaDimensionEstimator(k=2.5) + + def test_determinism(self, run_local_pca_estimator): + """Test that same input yields same output.""" + torch.manual_seed(42) + data = torch.randn(80, 6) + r1 = run_local_pca_estimator(data.clone(), k=5)["dimension"] + r2 = run_local_pca_estimator(data.clone(), k=5)["dimension"] + assert torch.allclose(r1, r2, equal_nan=True) + + def test_k_auto(self, run_local_pca_estimator, plane_data): + """Test k='auto' uses n**0.5.""" + result = run_local_pca_estimator(plane_data, k="auto") + assert "dimension" in result + assert result["dimension"].shape == (100,) + + def test_repr(self): + """Test __repr__ includes class name, in_keys, out_keys, k, criterion, and eps.""" + est = LocalPcaDimensionEstimator(k=3, criterion="maxgap") + r = repr(est) + assert "LocalPcaDimensionEstimator" in r + assert "in_keys=['data']" in r + assert "out_keys=['dimension']" in r + assert "k=3" in r + assert "criterion='maxgap'" in r + assert "eps=" in r + + def test_sklearn_missing_raises(self): + """Test that ImportError is raised when sklearn is not installed.""" + with patch.object(importlib.util, "find_spec", return_value=None): + est = LocalPcaDimensionEstimator(k=5) + td = TensorDict({"data": torch.randn(20, 5)}, batch_size=[]) + with pytest.raises(ImportError, match="scikit-learn"): + est(td) + + def test_unknown_criterion_raises(self, run_local_pca_estimator): + """Test that unknown criterion raises ValueError.""" + data = torch.randn(20, 5) + est = LocalPcaDimensionEstimator(k=5, criterion="invalid") + td = TensorDict({"data": data}, batch_size=[]) + with pytest.raises(ValueError, match="Unknown criterion"): + est(td) + + def test_maxgap_single_eigenvalue(self, run_local_pca_estimator): + """Test maxgap with 1D data (single eigenvalue) returns 1.""" + data = torch.randn(15, 1) + result = run_local_pca_estimator(data, k=2, 
criterion="maxgap") + assert (result["dimension"] >= 1).all() + assert result["dimension"].shape == (15,) + + def test_ratio_empty_eigenvalues(self): + """Test ratio criterion with empty eigenvalues returns 1.""" + assert _dim_from_eigenvalues_ratio(np.array([]), alpha=0.05) == 1 + + def test_maxgap_few_eigenvalues(self): + """Test maxgap with single eigenvalue returns 1.""" + assert _dim_from_eigenvalues_maxgap(np.array([1.0])) == 1 + assert _dim_from_eigenvalues_maxgap(np.array([])) == 1 + + def test_constant_data_handled(self, run_local_pca_estimator): + """Test constant data (zero variance) does not crash. Returns NaN (no valid neighbors).""" + data = torch.ones(10, 5) + result = run_local_pca_estimator(data, k=2) + assert result["dimension"].shape == (10,) + assert torch.isnan(result["dimension"]).all() + + def test_few_neighborhood_points_returns_nan(self): + """Test that k=0 (single-point neighborhood) returns nan.""" + from sklearn.decomposition import PCA + + data = torch.randn(5, 3) + result = _local_pca(data, k=0, eps=1e-5, criterion="maxgap", alpha=0.05, pca_cls=PCA) + assert result.shape == (5,) + assert torch.isnan(result).all() + + def test_empty_eigenvalues_returns_one(self): + """Test that PCA returning empty eigenvalues yields 1.0.""" + + class MockPCA: + def __init__(self, n_components=None): + pass + + def fit(self, X): + self.explained_variance_ = np.array([]) + return self + + data = torch.randn(5, 3) + result = _local_pca(data, k=2, eps=1e-5, criterion="maxgap", alpha=0.05, pca_cls=MockPCA) + assert result.shape == (5,) + assert (result == 1.0).all() + + +@pytest.fixture +def run_ca_pca_estimator(): + torch.manual_seed(42) + + def _run(data, k=5, in_key="data", batch_size=None, **estimator_kwargs): + if batch_size is None: + batch_size = [] if data.ndim == 2 else data.shape[:-2] + td = TensorDict({in_key: data}, batch_size=batch_size) + return CaPcaDimensionEstimator(k=k, in_key=in_key, **estimator_kwargs)(td) + + return _run + + +class 
TestCaPcaDimensionEstimator: + """Test the CaPcaDimensionEstimator class.""" + + def test_default_keys(self, run_ca_pca_estimator): + """Test with default in_key and out_key.""" + data = torch.randn(50, 10) + result = run_ca_pca_estimator(data, k=5) + assert "dimension" in result + assert result["dimension"].shape == (50,) + assert result["dimension"].dtype in (torch.float32, torch.float64) + valid = torch.isfinite(result["dimension"]) + assert valid.sum() > 0 + assert (result["dimension"][valid] >= 1).all() + + def test_custom_keys(self, run_ca_pca_estimator): + """Test with custom in_key and out_key.""" + data = torch.randn(50, 8) + result = run_ca_pca_estimator(data, k=5, in_key="linear2", out_key="intrinsic_dim") + assert "intrinsic_dim" in result + assert "linear2" in result + assert result["intrinsic_dim"].shape == (50,) + + def test_output_shape(self, run_ca_pca_estimator): + """Test output shape (N,) for (N, D) input.""" + data = torch.randn(100, 5) + result = run_ca_pca_estimator(data, k=5) + assert result["dimension"].shape == (100,) + + def test_known_dimension_2d(self, run_ca_pca_estimator, plane_data): + """Test on 2D manifold embedded in higher space.""" + result = run_ca_pca_estimator(plane_data, k=5) + d = result["dimension"] + valid = torch.isfinite(d) + mean_d = d[valid].mean().item() + assert 1.0 < mean_d < 5.0 + + def test_known_dimension_circle(self, run_ca_pca_estimator, circle_data): + """Test on 1D manifold (circle) embedded in 2D.""" + result = run_ca_pca_estimator(circle_data, k=5) + d = result["dimension"] + valid = torch.isfinite(d) + mean_d = d[valid].mean().item() + assert 0.5 < mean_d < 3.0 + + @pytest.mark.parametrize( + "shape", + [(1, 10, 8), (5, 10, 8), (2, 3, 10, 4)], + ids=["1x10x8", "5x10x8", "2x3x10x4"], + ) + def test_batch_shape_preservation(self, run_ca_pca_estimator, shape): + """Test that (..., N, D) preserves batch shape, output is (..., N).""" + data = torch.randn(*shape) + batch_size = shape[:-2] + N = shape[-2] + 
result = run_ca_pca_estimator(data, k=5, batch_size=batch_size) + assert result["dimension"].shape == (*batch_size, N) + + def test_too_few_points_raises(self, run_ca_pca_estimator): + """Test that N < k+2 raises.""" + with pytest.raises(ValueError, match="At least k\\+2 points"): + run_ca_pca_estimator(torch.randn(6, 5), k=5) # 6 < 5+2 + + def test_k_validation(self): + """Test that invalid k raises (k < 1 or wrong type).""" + with pytest.raises(ValueError, match="k must be at least 1"): + CaPcaDimensionEstimator(k=0) + with pytest.raises(TypeError, match="k must be an int or 'auto'"): + CaPcaDimensionEstimator(k=2.5) + + def test_determinism(self, run_ca_pca_estimator): + """Test that same input yields same output.""" + torch.manual_seed(42) + data = torch.randn(80, 6) + r1 = run_ca_pca_estimator(data.clone(), k=5)["dimension"] + r2 = run_ca_pca_estimator(data.clone(), k=5)["dimension"] + assert torch.allclose(r1, r2, equal_nan=True) + + def test_k_auto(self, run_ca_pca_estimator, plane_data): + """Test k='auto' uses n**0.5.""" + result = run_ca_pca_estimator(plane_data, k="auto") + assert "dimension" in result + assert result["dimension"].shape == (100,) + + def test_repr(self): + """Test __repr__ includes class name, in_keys, out_keys, k, and eps.""" + est = CaPcaDimensionEstimator(k=3) + r = repr(est) + assert "CaPcaDimensionEstimator" in r + assert "in_keys=['data']" in r + assert "out_keys=['dimension']" in r + assert "k=3" in r + assert "eps=" in r + + def test_sklearn_missing_raises(self): + """Test that ImportError is raised when sklearn is not installed.""" + with patch.object(importlib.util, "find_spec", return_value=None): + est = CaPcaDimensionEstimator(k=5) + td = TensorDict({"data": torch.randn(20, 5)}, batch_size=[]) + with pytest.raises(ImportError, match="scikit-learn"): + est(td) + + def test_constant_data_handled(self, run_ca_pca_estimator): + """Test constant data (all duplicates) returns nan without crashing.""" + data = torch.ones(10, 5) + 
result = run_ca_pca_estimator(data, k=5) + assert result["dimension"].shape == (10,) + assert torch.isnan(result["dimension"]).all() + + def test_k1_uses_two_neighbors(self, run_ca_pca_estimator): + """Test k=1 uses k+1=2 neighbors and returns valid dimension estimates.""" + data = torch.randn(10, 5) + result = run_ca_pca_estimator(data, k=1) + assert result["dimension"].shape == (10,) + assert torch.isfinite(result["dimension"]).all() + assert (result["dimension"] >= 1).all() diff --git a/uv.lock b/uv.lock index 154c8bc..7ee46df 100644 --- a/uv.lock +++ b/uv.lock @@ -1361,6 +1361,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d9/33/1f075bf72b0b747cb3288d011319aaf64083cf2efef8354174e3ed4540e2/ipython_pygments_lexers-1.1.1-py3-none-any.whl", hash = "sha256:a9462224a505ade19a605f71f8fa63c2048833ce50abc86768a0d81d876dc81c", size = 8074, upload-time = "2025-01-17T11:24:33.271Z" }, ] +[[package]] +name = "ipywidgets" +version = "8.1.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "comm" }, + { name = "ipython", version = "8.37.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "ipython", version = "9.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "jupyterlab-widgets" }, + { name = "traitlets" }, + { name = "widgetsnbextension" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4c/ae/c5ce1edc1afe042eadb445e95b0671b03cee61895264357956e61c0d2ac0/ipywidgets-8.1.8.tar.gz", hash = "sha256:61f969306b95f85fba6b6986b7fe45d73124d1d9e3023a8068710d47a22ea668", size = 116739, upload-time = "2025-11-01T21:18:12.393Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/56/6d/0d9848617b9f753b87f214f1c682592f7ca42de085f564352f10f0843026/ipywidgets-8.1.8-py3-none-any.whl", hash = "sha256:ecaca67aed704a338f88f67b1181b58f821ab5dc89c1f0f5ef99db43c1c2921e", size = 139808, upload-time = 
"2025-11-01T21:18:10.956Z" }, +] + [[package]] name = "jaxtyping" version = "0.3.2" @@ -1472,6 +1489,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b1/dd/ead9d8ea85bf202d90cc513b533f9c363121c7792674f78e0d8a854b63b4/jupyterlab_pygments-0.3.0-py3-none-any.whl", hash = "sha256:841a89020971da1d8693f1a99997aefc5dc424bb1b251fd6322462a1b8842780", size = 15884, upload-time = "2023-11-23T09:26:34.325Z" }, ] +[[package]] +name = "jupyterlab-widgets" +version = "3.0.16" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/26/2d/ef58fed122b268c69c0aa099da20bc67657cdfb2e222688d5731bd5b971d/jupyterlab_widgets-3.0.16.tar.gz", hash = "sha256:423da05071d55cf27a9e602216d35a3a65a3e41cdf9c5d3b643b814ce38c19e0", size = 897423, upload-time = "2025-11-01T21:11:29.724Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ab/b5/36c712098e6191d1b4e349304ef73a8d06aed77e56ceaac8c0a306c7bda1/jupyterlab_widgets-3.0.16-py3-none-any.whl", hash = "sha256:45fa36d9c6422cf2559198e4db481aa243c7a32d9926b500781c830c80f7ecf8", size = 914926, upload-time = "2025-11-01T21:11:28.008Z" }, +] + [[package]] name = "kiwisolver" version = "1.4.8" @@ -4102,6 +4128,7 @@ dev = [ { name = "zennit", version = "1.0.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11.11'" }, ] docs = [ + { name = "ipywidgets" }, { name = "nbsphinx" }, { name = "pandoc" }, { name = "plotly" }, @@ -4115,6 +4142,7 @@ docs = [ ] notebooks = [ { name = "ipykernel" }, + { name = "ipywidgets" }, ] scripts = [ { name = "captum" }, @@ -4152,6 +4180,7 @@ dev = [ { name = "zennit", specifier = ">=0.5.1" }, ] docs = [ + { name = "ipywidgets", specifier = ">=8.0.0" }, { name = "nbsphinx", specifier = ">=0.9.6" }, { name = "pandoc", specifier = ">=2.4" }, { name = "plotly", specifier = ">=5.24.1" }, @@ -4163,7 +4192,10 @@ docs = [ { name = "sphinx-design", specifier = ">=0.6.1" }, { name = "sphinxcontrib-bibtex", 
specifier = ">=2.6.5" }, ] -notebooks = [{ name = "ipykernel", specifier = ">=6.29.5" }] +notebooks = [ + { name = "ipykernel", specifier = ">=6.29.5" }, + { name = "ipywidgets", specifier = ">=8.0.0" }, +] scripts = [ { name = "captum", specifier = ">=0.8.0" }, { name = "datasets", specifier = ">=4.0.0" }, @@ -4692,6 +4724,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5a/84/44687a29792a70e111c5c477230a72c4b957d88d16141199bf9acb7537a3/websocket_client-1.8.0-py3-none-any.whl", hash = "sha256:17b44cc997f5c498e809b22cdf2d9c7a9e71c02c8cc2b6c56e7c2d1239bfa526", size = 58826, upload-time = "2024-04-23T22:16:14.422Z" }, ] +[[package]] +name = "widgetsnbextension" +version = "4.0.15" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bd/f4/c67440c7fb409a71b7404b7aefcd7569a9c0d6bd071299bf4198ae7a5d95/widgetsnbextension-4.0.15.tar.gz", hash = "sha256:de8610639996f1567952d763a5a41af8af37f2575a41f9852a38f947eb82a3b9", size = 1097402, upload-time = "2025-11-01T21:15:55.178Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3f/0e/fa3b193432cfc60c93b42f3be03365f5f909d2b3ea410295cf36df739e31/widgetsnbextension-4.0.15-py3-none-any.whl", hash = "sha256:8156704e4346a571d9ce73b84bee86a29906c9abfd7223b7228a28899ccf3366", size = 2196503, upload-time = "2025-11-01T21:15:53.565Z" }, +] + [[package]] name = "win32-setctime" version = "1.2.0"