From 813c88fe6b5f7688d6ac1431f22c5c9ab12d43f3 Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Thu, 29 May 2025 17:49:17 -0400 Subject: [PATCH] Cleaning and updates --- R/aux_functions.R | 34 +++++++++++++++++++--------------- R/dist_functions.R | 5 ++--- R/get_weights_from_mm.R | 10 +++++----- R/get_weights_from_subclass.R | 16 ++++++++-------- R/input_processing.R | 2 +- R/match.qoi.R | 6 +++--- R/match_data.R | 6 +++--- R/summary.matchit.R | 25 ++++++++++++------------- src/nn_matchC_distmat.cpp | 8 ++++++-- src/nn_matchC_mahcovs.cpp | 9 +++++++-- src/nn_matchC_vec.cpp | 9 +++++++-- src/subclass_scootC.cpp | 13 ++++++------- vignettes/matching-methods.Rmd | 2 +- 13 files changed, 80 insertions(+), 65 deletions(-) diff --git a/R/aux_functions.R b/R/aux_functions.R index 141f8b15..f9a91631 100644 --- a/R/aux_functions.R +++ b/R/aux_functions.R @@ -1,6 +1,6 @@ #Function to ensure no subclass is devoid of both treated and control units by "scooting" units #from other subclasses. -subclass_scoot <- function(sub, treat, x, min.n = 1) { +subclass_scoot <- function(sub, treat, x, min.n = 1L) { #Reassigns subclasses so there are no empty subclasses #for each treatment group. 
subtab <- table(treat, sub) @@ -45,7 +45,9 @@ info_to_method <- function(info) { out.list[["kto1"]] <- { if (is_null(info$ratio)) NULL - else paste0(if (is_not_null(info$max.controls)) "variable ratio ", round(info$ratio, 2L), ":1") + else sprintf("%s%s:1", + if (is_not_null(info$max.controls)) "variable ratio " else "", + round(info$ratio, 2L)) } out.list[["type"]] <- { @@ -84,7 +86,7 @@ info_to_distance <- function(info) { linear <- FALSE } - dist <- switch(distance, + .dist <- switch(distance, "glm" = switch(link, "logit" = "logistic regression", "probit" = "probit regression", @@ -105,10 +107,10 @@ info_to_distance <- function(info) { "randomforest" = "a random forest") if (linear) { - dist <- paste(dist, "and linearized") + .dist <- sprintf("%s and linearized", .dist) } - dist + .dist } #Make interaction vector out of matrix of covs; similar to interaction() @@ -124,7 +126,7 @@ exactify <- function(X, nam = NULL, sep = "|", include_vars = FALSE, justify = " stop("X must be a matrix, data frame, or list.") } - X <- X[lengths(X) > 0] + X <- X[lengths(X) > 0L] if (is_null(X)) { return(NULL) @@ -329,11 +331,11 @@ ESS <- function(w) { nn <- function(treat, weights, discarded = NULL, s.weights = NULL) { if (is_null(discarded)) { - discarded <- rep.int(FALSE, length(treat)) + discarded <- rep_with(FALSE, treat) } if (is_null(s.weights)) { - s.weights <- rep.int(1, length(treat)) + s.weights <- rep_with(1, treat) } weights <- weights * s.weights @@ -343,13 +345,15 @@ nn <- function(treat, weights, discarded = NULL, s.weights = NULL) { "Matched", "Unmatched", "Discarded"), c("Control", "Treated"))) - # Control Treated - n["All (ESS)", ] <- c(ESS(s.weights[treat == 0]), ESS(s.weights[treat == 1])) - n["All", ] <- c(sum(treat == 0), sum(treat == 1)) - n["Matched (ESS)", ] <- c(ESS(weights[treat == 0]), ESS(weights[treat == 1])) - n["Matched", ] <- c(sum(treat == 0 & weights > 0), sum(treat == 1 & weights > 0)) - n["Unmatched", ] <- c(sum(treat == 0 & weights == 0 & 
!discarded), sum(treat == 1 & weights == 0 & !discarded)) - n["Discarded", ] <- c(sum(treat == 0 & discarded), sum(treat == 1 & discarded)) + t1 <- treat == 1 + + # Control Treated + n["All (ESS)", ] <- c(ESS(s.weights[!t1]), ESS(s.weights[t1])) + n["All", ] <- c(sum(!t1), sum(t1)) + n["Matched (ESS)", ] <- c(ESS(weights[!t1]), ESS(weights[t1])) + n["Matched", ] <- c(sum(!t1 & weights > 0), sum(t1 & weights > 0)) + n["Unmatched", ] <- c(sum(!t1 & weights == 0 & !discarded), sum(t1 & weights == 0 & !discarded)) + n["Discarded", ] <- c(sum(!t1 & discarded), sum(t1 & discarded)) n } diff --git a/R/dist_functions.R b/R/dist_functions.R index e8eccc43..5f62322d 100644 --- a/R/dist_functions.R +++ b/R/dist_functions.R @@ -418,12 +418,11 @@ get_covs_matrix_for_dist <- function(formula = NULL, data = NULL) { function(x) contrasts(x, contrasts = FALSE) / sqrt(2))) if (ncol(X) > 1L) { - assign <- attr(X, "assign")[-1L] + .assign <- attr(X, "assign")[-1L] X <- X[, -1L, drop = FALSE] + attr(X, "assign") <- .assign } - attr(X, "assign") <- assign - attr(X, "treat") <- model.response(mf) X diff --git a/R/get_weights_from_mm.R b/R/get_weights_from_mm.R index 699de9e3..49afbdf3 100644 --- a/R/get_weights_from_mm.R +++ b/R/get_weights_from_mm.R @@ -4,19 +4,19 @@ get_weights_from_mm <- function(match.matrix, treat, focal = NULL) { match.matrix <- charmm2nummm(match.matrix, treat) } - weights <- weights_matrixC(match.matrix, treat, focal) + w <- weights_matrixC(match.matrix, treat, focal) - if (all_equal_to(weights, 0)) { + if (all_equal_to(w, 0)) { .err("no units were matched") } - if (all_equal_to(weights[treat == 1], 0)) { + if (all_equal_to(w[treat == 1], 0)) { .err("no treated units were matched") } - if (all_equal_to(weights[treat == 0], 0)) { + if (all_equal_to(w[treat == 0], 0)) { .err("no control units were matched") } - setNames(weights, names(treat)) + setNames(w, names(treat)) } \ No newline at end of file diff --git a/R/get_weights_from_subclass.R 
b/R/get_weights_from_subclass.R index 84892a75..de84cde6 100644 --- a/R/get_weights_from_subclass.R +++ b/R/get_weights_from_subclass.R @@ -16,7 +16,7 @@ get_weights_from_subclass <- function(subclass, treat, estimand = "ATT") { .err("No control units were matched") } - weights <- rep_with(0, treat) + w <- rep_with(0, treat) if (!is.factor(subclass)) { subclass <- factor(subclass, nmax = min(length(i1), length(i0))) @@ -28,19 +28,19 @@ get_weights_from_subclass <- function(subclass, treat, estimand = "ATT") { subclass <- unclass(subclass) if (estimand == "ATT") { - weights[i1] <- 1 - weights[i0] <- (treated_by_sub / control_by_sub)[subclass[i0]] + w[i1] <- 1 + w[i0] <- (treated_by_sub / control_by_sub)[subclass[i0]] } else if (estimand == "ATC") { - weights[i1] <- (control_by_sub / treated_by_sub)[subclass[i1]] - weights[i0] <- 1 + w[i1] <- (control_by_sub / treated_by_sub)[subclass[i1]] + w[i0] <- 1 } else if (estimand == "ATE") { - weights[i1] <- 1 + (control_by_sub / treated_by_sub)[subclass[i1]] - weights[i0] <- 1 + (treated_by_sub / control_by_sub)[subclass[i0]] + w[i1] <- 1 + (control_by_sub / treated_by_sub)[subclass[i1]] + w[i0] <- 1 + (treated_by_sub / control_by_sub)[subclass[i0]] } - weights + w } # get_weights_from_subclass2 <- function(subclass, treat, estimand = "ATT") { diff --git a/R/input_processing.R b/R/input_processing.R index 1bfc4377..097419e1 100644 --- a/R/input_processing.R +++ b/R/input_processing.R @@ -255,7 +255,7 @@ process.distance <- function(distance, method = NULL, treat) { attr(distance, "link") <- link } else if (tolower(distance) %in% tolower(c("GAMcloglog", "GAMlog", "GAMlogit", "GAMprobit"))) { - link <- tolower(substr(distance, 4, nchar(distance))) + link <- tolower(substr(distance, 4L, nchar(distance))) .wrn(sprintf('`distance = "%s"` will be deprecated; please use `distance = "gam", link = "%s"` in the future', distance, link)) diff --git a/R/match.qoi.R b/R/match.qoi.R index 6dda0043..f02eaf4a 100644 --- a/R/match.qoi.R +++ 
b/R/match.qoi.R @@ -20,7 +20,7 @@ bal1var <- function(xx, tt, ww = NULL, s.weights, subclass = NULL, mm = NULL, i1 <- which(tt == 1) i0 <- which(tt == 0) - too.small <- sum(ww[i1] != 0) < 2 && sum(ww[i0] != 0) < 2 + too.small <- sum(ww[i1] != 0) < 2L && sum(ww[i0] != 0) < 2L xsum["Means Treated"] <- wm(xx[i1], ww[i1], na.rm = TRUE) xsum["Means Control"] <- wm(xx[i0], ww[i0], na.rm = TRUE) @@ -215,7 +215,7 @@ qqsum <- function(x, t, w = NULL, standardize = FALSE) { wn0 <- length(w0) if (wn1 < wn0) { - if (length(u) <= 5) { + if (length(u) <= 5L) { x0probs <- vapply(u, function(u_) wm(x0 == u_, w0), numeric(1L)) x0cumprobs <- c(0, .cumsum_prob(x0probs)) x0 <- u[findInterval(.cumsum_prob(w1), x0cumprobs, rightmost.closed = TRUE)] @@ -226,7 +226,7 @@ qqsum <- function(x, t, w = NULL, standardize = FALSE) { } } else if (wn1 > wn0) { - if (length(u) <= 5) { + if (length(u) <= 5L) { x1probs <- vapply(u, function(u_) wm(x1 == u_, w1), numeric(1L)) x1cumprobs <- c(0, .cumsum_prob(x1probs)) x1 <- u[findInterval(.cumsum_prob(w0), x1cumprobs, rightmost.closed = TRUE)] diff --git a/R/match_data.R b/R/match_data.R index a02d52e8..8dd40f72 100644 --- a/R/match_data.R +++ b/R/match_data.R @@ -191,13 +191,13 @@ match_data <- function(object, data.found <- FALSE for (i in 1:4) { - if (i == 2) { + if (i == 2L) { data <- try(eval(object$call$data, envir = environment(object$formula)), silent = TRUE) } - else if (i == 3) { + else if (i == 3L) { data <- try(eval(object$call$data, envir = parent.frame()), silent = TRUE) } - else if (i == 4) { + else if (i == 4L) { data <- object[["model"]][["data"]] } diff --git a/R/summary.matchit.R b/R/summary.matchit.R index 412ab855..d194568a 100644 --- a/R/summary.matchit.R +++ b/R/summary.matchit.R @@ -231,7 +231,7 @@ summary.matchit <- function(object, #Remove tics has_tics <- which(startsWith(nam, "`") & endsWith(nam, "`")) - nam[has_tics] <- substr(nam[has_tics], 2, nchar(nam[has_tics]) - 1) + nam[has_tics] <- substr(nam[has_tics], 2L, 
nchar(nam[has_tics]) - 1L) } kk <- ncol(X) @@ -289,7 +289,7 @@ summary.matchit <- function(object, to.remove <- rep.int(FALSE, n.int) int.names <- character(n.int) - k <- 1 + k <- 1L for (i in seq_len(kk)) { for (j in i:kk) { x2 <- X[, i] * X[, j] @@ -316,7 +316,7 @@ summary.matchit <- function(object, int.names[k] <- paste(nam[i], nam[j], sep = " * ") } } - k <- k + 1 + k <- k + 1L } } @@ -360,17 +360,17 @@ summary.matchit <- function(object, ## Imbalance Reduction if (matched && un && improvement) { - reduction <- matrix(NA_real_, nrow = nrow(sum.all), ncol = ncol(sum.all) - 2, + reduction <- matrix(NA_real_, nrow = nrow(sum.all), ncol = ncol(sum.all) - 2L, dimnames = list(rownames(sum.all), colnames(sum.all)[-(1:2)])) stat.all <- abs(sum.all[, -(1:2), drop = FALSE]) stat.matched <- abs(sum.matched[, -(1:2), drop = FALSE]) #Everything but variance ratios - reduction[, -2] <- 100 * (stat.all[, -2] - stat.matched[, -2]) / stat.all[, -2] + reduction[, -2L] <- 100 * (stat.all[, -2L] - stat.matched[, -2L]) / stat.all[, -2L] #Just variance ratios; turn to log first - vr.all <- abs(log(stat.all[, 2])) - vr.matched <- abs(log(stat.matched[, 2])) + vr.all <- abs(log(stat.all[, 2L])) + vr.matched <- abs(log(stat.matched[, 2L])) reduction[, 2] <- 100 * (vr.all - vr.matched) / vr.all reduction[stat.all == 0 & stat.matched == 0] <- 0 @@ -595,7 +595,7 @@ summary.matchit.subclass <- function(object, int.names[k] <- paste(nam[i], nam[j], sep = " * ") } } - k <- k + 1 + k <- k + 1L } } rownames(sum.sub.int) <- int.names @@ -644,19 +644,19 @@ print.summary.matchit <- function(x, digits = max(3, getOption("digits") - 3), if (is_not_null(x$sum.all)) { cat("\nSummary of Balance for All Data:\n") - print(round_df_char(x$sum.all[, -7, drop = FALSE], digits, pad = "0", na_vals = "."), + print(round_df_char(x$sum.all[, -7L, drop = FALSE], digits, pad = "0", na_vals = "."), right = TRUE, quote = FALSE) } if (is_not_null(x$sum.matched)) { cat("\nSummary of Balance for Matched Data:\n") - 
if (all(is.na(x$sum.matched[, 7]))) x$sum.matched <- x$sum.matched[, -7, drop = FALSE] #Remove pair dist if empty + if (all(is.na(x$sum.matched[, 7L]))) x$sum.matched <- x$sum.matched[, -7L, drop = FALSE] #Remove pair dist if empty print(round_df_char(x$sum.matched, digits, pad = "0", na_vals = "."), right = TRUE, quote = FALSE) } if (is_not_null(x$reduction)) { cat("\nPercent Balance Improvement:\n") - print(round_df_char(x$reduction[, -5, drop = FALSE], 1, pad = "0", na_vals = "."), right = TRUE, + print(round_df_char(x$reduction[, -5L, drop = FALSE], 1, pad = "0", na_vals = "."), right = TRUE, quote = FALSE) } if (is_not_null(x$nn)) { @@ -737,7 +737,7 @@ print.summary.matchit.subclass <- function(x, digits = max(3L, getOption("digits .process_X <- function(object, addlvariables = NULL, data = NULL) { X <- { - if (is_null(object$X)) matrix(nrow = length(object$treat), ncol = 0) + if (is_null(object$X)) matrix(nrow = length(object$treat), ncol = 0L) else get_covs_matrix(data = object$X) } @@ -827,5 +827,4 @@ print.summary.matchit.subclass <- function(x, digits = max(3L, getOption("digits # addl_assign <- get_assign(addlvariables) cbind(X, addlvariables[, setdiff(colnames(addlvariables), colnames(X)), drop = FALSE]) - } diff --git a/src/nn_matchC_distmat.cpp b/src/nn_matchC_distmat.cpp index bc00b30b..67749105 100644 --- a/src/nn_matchC_distmat.cpp +++ b/src/nn_matchC_distmat.cpp @@ -167,8 +167,9 @@ IntegerMatrix nn_matchC_distmat(const IntegerVector& treat_, for (r = 1; r <= max_ratio; r++) { ord_r = ord[as(ratio[ord - 1]) >= r]; + ord_r = ord_r - 1; - for (int t_id_t_i : ord_r - 1) { + for (int t_id_t_i : ord_r) { // t_id_t_i; index of treated unit to match among treated units // t_id_i: index of treated unit to match among all units counter++; @@ -254,7 +255,8 @@ IntegerMatrix nn_matchC_distmat(const IntegerVector& treat_, } } else { - for (int t_id_t_i : ord - 1) { + int t_id_t_i; + for (int t_id_t_i_ : ord) { // t_id_t_i; index of treated unit to match among 
treated units // t_id_i: index of treated unit to match among all units counter++; @@ -263,6 +265,8 @@ IntegerMatrix nn_matchC_distmat(const IntegerVector& treat_, Rcpp::checkUserInterrupt(); } + t_id_t_i = t_id_t_i_ - 1; + t_id_i = ind_focal[t_id_t_i]; p.increment(); diff --git a/src/nn_matchC_mahcovs.cpp b/src/nn_matchC_mahcovs.cpp index 66d510d8..2993c295 100644 --- a/src/nn_matchC_mahcovs.cpp +++ b/src/nn_matchC_mahcovs.cpp @@ -200,8 +200,9 @@ IntegerMatrix nn_matchC_mahcovs(const IntegerVector& treat_, for (r = 1; r <= max_ratio; r++) { ord_r = ord[as(ratio[ord - 1]) >= r]; + ord_r = ord_r - 1; - for (int t_id_t_i : ord_r - 1) { + for (int t_id_t_i : ord_r) { // t_id_t_i; index of treated unit to match among treated units // t_id_i: index of treated unit to match among all units counter++; @@ -292,7 +293,9 @@ IntegerMatrix nn_matchC_mahcovs(const IntegerVector& treat_, } } else { - for (int t_id_t_i : ord - 1) { + int t_id_t_i; + + for (int t_id_t_i_ : ord) { // t_id_t_i; index of treated unit to match among treated units // t_id_i: index of treated unit to match among all units counter++; @@ -301,6 +304,8 @@ IntegerMatrix nn_matchC_mahcovs(const IntegerVector& treat_, Rcpp::checkUserInterrupt(); } + t_id_t_i = t_id_t_i_ - 1; + t_id_i = ind_focal[t_id_t_i]; p.increment(); diff --git a/src/nn_matchC_vec.cpp b/src/nn_matchC_vec.cpp index 590af386..e32537d6 100644 --- a/src/nn_matchC_vec.cpp +++ b/src/nn_matchC_vec.cpp @@ -188,8 +188,9 @@ IntegerMatrix nn_matchC_vec(const IntegerVector& treat_, for (r = 1; r <= max_ratio; r++) { ord_r = ord[as(ratio[ord - 1]) >= r]; + ord_r = ord_r - 1; - for (int t_id_t_i : ord_r - 1) { + for (int t_id_t_i : ord_r) { // t_id_t_i; index of treated unit to match among treated units // t_id_i: index of treated unit to match among all units counter++; @@ -285,7 +286,9 @@ IntegerMatrix nn_matchC_vec(const IntegerVector& treat_, } } else { - for (int t_id_t_i : ord - 1) { + int t_id_t_i; + + for (int t_id_t_i_ : ord) { // t_id_t_i; 
index of treated unit to match among treated units // t_id_i: index of treated unit to match among all units counter++; @@ -294,6 +297,8 @@ IntegerMatrix nn_matchC_vec(const IntegerVector& treat_, Rcpp::checkUserInterrupt(); } + t_id_t_i = t_id_t_i_ - 1; + t_id_i = ind_focal[t_id_t_i]; p.increment(); diff --git a/src/subclass_scootC.cpp b/src/subclass_scootC.cpp index 67c7669d..99404437 100644 --- a/src/subclass_scootC.cpp +++ b/src/subclass_scootC.cpp @@ -100,13 +100,12 @@ IntegerVector subclass_scootC(const IntegerVector& subclass_, } //Find unit with closest x in that subclass to take - for (i = 0; i < nt; i++) { - if (subclass[indt[i]] == s2) { - best_i = i; - best_x = x[indt[i]]; - break; - } - } + best_i = std::distance(indt.begin(), + std::find_if(indt.begin(), indt.end(), + [&subclass, &s2](int a){ + return subclass[a] == s2; + })); + best_x = x[indt[best_i]]; for (i = best_i + 1; i < nt; i++) { if (subclass[indt[i]] != s2) { diff --git a/vignettes/matching-methods.Rmd b/vignettes/matching-methods.Rmd index 691458b8..10892d20 100644 --- a/vignettes/matching-methods.Rmd +++ b/vignettes/matching-methods.Rmd @@ -161,7 +161,7 @@ To perform exact matching on all supplied covariates, the `method` argument can ### Anti-exact matching (`antiexact`) -Anti-exact matching adds a restriction such that a treated and control unit with same values of any of the specified anti-exact matching variables cannot be paired. This can be useful when finding comparison units outside of a unit's group, such as when matching units in one group to units in another when units within the same group might otherwise be close matches. See examples [here](https://stackoverflow.com/questions/66526115/propensity-score-matching-with-panel-data) and [here](https://stackoverflow.com/questions/61120201/avoiding-duplicates-from-propensity-score-matching?rq=1). A similar effect can be implemented by supplying negative caliper values. 
+Anti-exact matching adds a restriction such that a treated and control unit with the same values of any of the specified anti-exact matching variables cannot be paired. This can be useful when finding comparison units outside of a unit's group, such as when matching units in one group to units in another when units within the same group might otherwise be close matches. See examples [here](https://stackoverflow.com/q/66526115/6348551) and [here](https://stackoverflow.com/q/61120201/6348551). A similar effect can be implemented by supplying negative caliper values. ### Matching with replacement (`replace`)