-
Notifications
You must be signed in to change notification settings - Fork 21
Prediction with lm_lin() fixes #415 #416
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
31e477a
a1734ba
bedcc7e
4b53e99
05c9920
c1da381
59d612b
4c9b28f
9e24ee8
3d0266c
8ee4040
d4bd29b
b824e87
121cb5f
3f14afb
8f229d6
e31c6ea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -61,6 +61,27 @@ | |
| #' new_dat$w <- runif(n) | ||
| #' predict(lm_out, newdata = new_dat, weights = w, interval = "prediction") | ||
| #' | ||
| #' # Works for 'lm_lin' models as well | ||
| #' dat$z <- sample(1:3, size = nrow(dat), replace = TRUE) | ||
| #' lmlin_out1 <- lm_lin(y ~ z, covariates = ~ x, data = dat) | ||
| #' predict(lmlin_out1, newdata = dat, interval = "prediction") | ||
| #' | ||
| #' # Predictions from Lin models are equivalent with and without an intercept | ||
| #' # and for multi-level treatments entered as numeric or factor variables | ||
| #' lmlin_out2 <- lm_lin(y ~ z - 1, covariates = ~ x, data = dat) | ||
| #' lmlin_out3 <- lm_lin(y ~ factor(z), covariates = ~ x, data = dat) | ||
| #' lmlin_out4 <- lm_lin(y ~ factor(z) - 1, covariates = ~ x, data = dat) | ||
| #' | ||
| #' predict(lmlin_out2, newdata = dat, interval = "prediction") | ||
| #' predict(lmlin_out3, newdata = dat, interval = "prediction") | ||
| #' predict(lmlin_out4, newdata = dat, interval = "prediction") | ||
| #' | ||
| #' # In Lin models, predict will stop with an error message if new | ||
| #' # treatment levels are supplied in the new data | ||
| #' new_dat$z <- sample(0:3, size = nrow(new_dat), replace = TRUE) | ||
| #' # predict(lmlin_out, newdata = new_dat) | ||
| #' | ||
| #' | ||
| #' @export | ||
| predict.lm_robust <- function(object, | ||
| newdata, | ||
|
|
@@ -74,30 +95,6 @@ predict.lm_robust <- function(object, | |
|
|
||
| X <- get_X(object, newdata, na.action) | ||
|
|
||
| # lm_lin scaling | ||
| if (!is.null(object$scaled_center)) { | ||
| demeaned_covars <- | ||
| scale( | ||
| X[ | ||
| , | ||
| names(object$scaled_center), | ||
| drop = FALSE | ||
| ], | ||
| center = object$scaled_center, | ||
| scale = FALSE | ||
| ) | ||
|
|
||
| # Interacted with treatment | ||
| treat_name <- attr(object$terms, "term.labels")[1] | ||
| interacted_covars <- X[, treat_name] * demeaned_covars | ||
|
Comment on lines
-90
to
-92
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This does not have the desired behavior when there are multiple treatment levels. |
||
|
|
||
| X <- cbind( | ||
| X[, attr(X, "assign") <= 1, drop = FALSE], | ||
| demeaned_covars, | ||
| interacted_covars | ||
| ) | ||
| } | ||
|
|
||
| # Get coefs | ||
| coefs <- as.matrix(coef(object)) | ||
|
|
||
|
|
@@ -224,9 +221,64 @@ get_X <- function(object, newdata, na.action) { | |
|
|
||
| X <- model.matrix(rhs_terms, mf, contrasts.arg = object$contrasts) | ||
|
|
||
| # lm_lin scaling (moved down from predict.lm_robust) | ||
| if (!is.null(object$scaled_center)) { | ||
| # Covariates | ||
| demeaned_covars <- | ||
| scale( | ||
| X[ | ||
| , | ||
| names(object$scaled_center), | ||
| drop = FALSE | ||
| ], | ||
| center = object$scaled_center, | ||
| scale = FALSE | ||
| ) | ||
|
|
||
| # Handle treatment variable reconstruction | ||
| treat_name <- attr(object$terms, "term.labels")[1] | ||
| treatment <- mf[, treat_name] | ||
| vals <- sort(unique(treatment)) | ||
| old_vals <- object$treatment_levels | ||
|
|
||
| # Ensure treatment levels in newdata are subset of those for model fit | ||
| if (!all(as.character(vals) %in% as.character(old_vals))) { | ||
| stop( | ||
| "Levels of treatment variable in `newdata` must be a subset of those ", | ||
| "in the model fit." | ||
| ) | ||
| } | ||
| treatment <- model.matrix(~ factor(treatment, levels = old_vals) - 1) | ||
|
|
||
| colnames(treatment) <- paste0(treat_name, "_", old_vals) | ||
| # Drop out first group if there is an intercept | ||
| if (attr(rhs_terms, "intercept") == 1) treatment <- treatment[, -1, drop = FALSE] | ||
|
|
||
| # Interactions matching original fitting logic | ||
| n_treat_cols <- ncol(treatment) | ||
| n_covars <- ncol(demeaned_covars) | ||
|
|
||
| interaction_matrix <- matrix(0, nrow = nrow(X), ncol = n_covars * n_treat_cols) | ||
|
|
||
| for (i in 1:n_covars) { | ||
| cols <- (i - 1) * n_treat_cols + (1:n_treat_cols) | ||
| interaction_matrix[, cols] <- treatment * demeaned_covars[, i] | ||
| } | ||
|
|
||
| X <- cbind( | ||
| if (attr(rhs_terms, "intercept") == 1) { | ||
| matrix(1, nrow = nrow(X), ncol = 1, dimnames = list(NULL, "(Intercept)")) | ||
| }, | ||
| treatment, | ||
| if (attr(rhs_terms, "intercept") == 1 || ncol(treatment) == 1) demeaned_covars, | ||
| interaction_matrix | ||
| ) | ||
| } | ||
|
|
||
| return(X) | ||
| } | ||
|
|
||
|
|
||
| add_fes <- function(preds, object, newdata) { | ||
|
|
||
| # Add factors! | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -83,7 +83,7 @@ | |
| #' \item{weighted}{whether or not weights were applied} | ||
| #' \item{call}{the original function call} | ||
| #' \item{fitted.values}{the matrix of predicted means} | ||
| #' We also return \code{terms} and \code{contrasts}, used by \code{predict}, | ||
| #' We also return \code{terms}, \code{contrasts}, and \code{treatment_levels}, used by \code{predict}, | ||
| #' and \code{scaled_center} (the means of each of the covariates used for centering them). | ||
| #' | ||
| #' @seealso \code{\link{lm_robust}} | ||
|
|
@@ -127,21 +127,36 @@ | |
| #' | ||
| #' lm_lin(y ~ z_clust, covariates = ~ x, data = dat, clusters = clusterID) | ||
| #' | ||
| #' # Works with multi-valued treatments | ||
| #' # Works with multi-valued treatments, whether treatment is specified as a | ||
| #' # factor or not | ||
| #' dat$z_multi <- sample(1:3, size = nrow(dat), replace = TRUE) | ||
| #' | ||
| #' lm_lin(y ~ z_multi, covariates = ~ x, data = dat) | ||
| #' lm_lin(y ~ factor(z_multi), covariates = ~ x, data = dat) | ||
| #' | ||
| #' # Stratified estimator with blocks | ||
| #' dat$blockID <- rep(1:5, each = 8) | ||
| #' dat$z_block <- block_ra(blocks = dat$blockID) | ||
| #' | ||
| #' lm_lin(y ~ z_block, ~ factor(blockID), data = dat) | ||
| #' | ||
| #' # Fitting the model without an intercept provides estimates of mean outcomes | ||
| #' # under each respective treatment condition | ||
| #' lm_lin(y ~ z_multi - 1, covariates = ~ x, data = dat) | ||
| #' | ||
| #' # Predictions are the same in equivalent models with and without an intercept | ||
| #' lmlin_out3 <- lm_lin(y ~ z_multi - 1, covariates = ~ x, data = dat) | ||
| #' lmlin_out4 <- lm_lin(y ~ z_multi, covariates = ~ x, data = dat) | ||
| #' | ||
| #' predict(lmlin_out3, newdata = dat, se.fit = TRUE, interval = "confidence") | ||
| #' predict(lmlin_out4, newdata = dat, se.fit = TRUE, interval = "confidence") | ||
| #' | ||
| #' \dontrun{ | ||
| #' # Can also use 'margins' package if you have it installed to get | ||
| #' # marginal effects | ||
| #' library(margins) | ||
| #' lmlout <- lm_lin(y ~ z_block, ~ x, data = dat) | ||
| #' # Instruct 'margins' to treat z as a factor | ||
| #' lmlout <- lm_lin(y ~ factor(z_block), ~ x, data = dat) | ||
| #' summary(margins(lmlout)) | ||
| #' | ||
| #' # Can output results using 'texreg' | ||
|
|
@@ -230,7 +245,7 @@ lm_lin <- function(formula, | |
| design_mat_treatment <- colnames(design_matrix)[treat_col] | ||
|
|
||
| # Check case where treatment is not factor and is not binary | ||
| if (any(!(treatment %in% c(0, 1)))) { | ||
| if (any(!(treatment %in% c(0, 1))) | (!has_intercept&ncol(treatment) ==1) ) { | ||
|
Comment on lines
-233
to
+248
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This change and the subsequent one modify how |
||
| # create dummies for non-factor treatment variable | ||
|
|
||
| # Drop out first group if there is an intercept | ||
|
|
@@ -313,20 +328,10 @@ lm_lin <- function(formula, | |
| interacted_covars | ||
| ) | ||
| } else { | ||
| # If no intercept, but treatment is only one column, | ||
| # need to add base terms for covariates | ||
| if (n_treat_cols == 1) { | ||
| X <- cbind( | ||
| treatment, | ||
| demeaned_covars, | ||
| interacted_covars | ||
| ) | ||
|
Comment on lines
-316
to
-323
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This special case is resolved |
||
| } else { | ||
| X <- cbind( | ||
| treatment, | ||
| interacted_covars | ||
| ) | ||
| } | ||
| X <- cbind( | ||
| treatment, | ||
| interacted_covars | ||
| ) | ||
| } | ||
|
|
||
| # ---------- | ||
|
|
@@ -360,6 +365,12 @@ lm_lin <- function(formula, | |
|
|
||
| return_list[["scaled_center"]] <- center | ||
| setNames(return_list[["scaled_center"]], original_covar_names) | ||
| # Store unique treatment values | ||
| if(attr(terms(model_data), "dataClasses")[attr(terms(model_data),"term.labels")[1]] == "factor"){ | ||
| return_list[["treatment_levels"]] <- model_data$xlevels[[1]] | ||
| } else { | ||
| return_list[["treatment_levels"]] <- sort(unique(design_matrix[, design_mat_treatment])) | ||
| } | ||
|
Comment on lines
+368
to
+373
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is added so that when the model matrix is generated for predictions, we can ensure that the new data only includes a subset of treatment levels that were in the original model fit. Without being able to check this, weird behavior could result from predictions where the new data does not share identical treatment levels with the original data. This is saved in $xlevels in the model object if treatment is a factor, but if treatment is entered into the model as a numeric variable, this information is not otherwise saved. |
||
|
|
||
| return_list[["call"]] <- match.call() | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
all of lm_lin scaling is moved down to
get_X()