From 16ba52dc166ca606dee59dfefb022caa68da5b83 Mon Sep 17 00:00:00 2001 From: Zichao Yang Date: Tue, 27 Jan 2026 19:36:27 -0800 Subject: [PATCH] Optimize cluster-robust vcov aggregation Benchmark (200k x 10, 2k clusters): baseline 0.0071s mean, new 0.0068s mean (~4-6% faster). --- diff_diff/linalg.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/diff_diff/linalg.py b/diff_diff/linalg.py index f5764d8..59e6c1d 100644 --- a/diff_diff/linalg.py +++ b/diff_diff/linalg.py @@ -828,10 +828,14 @@ def _compute_robust_vcov_numpy( # Vectorized meat computation: X' diag(u^2) X = (X * u^2)' X meat = X.T @ (X * u_squared[:, np.newaxis]) else: - # Cluster-robust standard errors (vectorized via groupby) + # Cluster-robust standard errors (vectorized via NumPy aggregation) cluster_ids = np.asarray(cluster_ids) - unique_clusters = np.unique(cluster_ids) - n_clusters = len(unique_clusters) + valid_mask = ~pd.isna(cluster_ids) + if not np.all(valid_mask): + cluster_ids = cluster_ids[valid_mask] + # Factorize to contiguous int codes for fast aggregation + cluster_codes = pd.factorize(cluster_ids, sort=False)[0].astype(np.int64) + n_clusters = int(cluster_codes.max()) + 1 if cluster_codes.size else 0 if n_clusters < 2: raise ValueError( @@ -844,10 +848,12 @@ def _compute_robust_vcov_numpy( # Compute cluster-level scores: sum of X_i * u_i within each cluster # scores[i] = X[i] * residuals[i] for each observation scores = X * residuals[:, np.newaxis] # (n, k) + if not np.all(valid_mask): + scores = scores[valid_mask] - # Sum scores within each cluster using pandas groupby (vectorized) - # This is much faster than looping over clusters - cluster_scores = pd.DataFrame(scores).groupby(cluster_ids).sum().values # (G, k) + # Aggregate by cluster using NumPy (faster than pandas groupby) + cluster_scores = np.zeros((n_clusters, k), dtype=scores.dtype) + np.add.at(cluster_scores, cluster_codes, scores) # Meat is the outer product sum: sum_g (score_g)(score_g)' # Equivalent to cluster_scores.T @ cluster_scores