From d24286b9324f02d092ac5265a1db77bf431e13e0 Mon Sep 17 00:00:00 2001 From: "Mattan S. Ben-Shachar" Date: Tue, 3 Sep 2024 09:02:49 +0300 Subject: [PATCH] #921 --- NAMESPACE | 5 + R/get_datagrid.R | 562 +++++++++++++++-------------- man/get_datagrid.Rd | 14 +- tests/testthat/test-get_datagrid.R | 86 ++++- 4 files changed, 397 insertions(+), 270 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index ee48a9f83..84e084ae4 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -485,14 +485,19 @@ S3method(get_data,zcpglm) S3method(get_data,zeroinfl) S3method(get_data,zerotrunc) S3method(get_datagrid,character) +S3method(get_datagrid,comparisons) S3method(get_datagrid,data.frame) S3method(get_datagrid,datagrid) S3method(get_datagrid,default) S3method(get_datagrid,double) +S3method(get_datagrid,emmGrid) +S3method(get_datagrid,emm_list) S3method(get_datagrid,factor) S3method(get_datagrid,logical) S3method(get_datagrid,logitr) S3method(get_datagrid,numeric) +S3method(get_datagrid,predictions) +S3method(get_datagrid,slopes) S3method(get_datagrid,visualisation_matrix) S3method(get_datagrid,wbm) S3method(get_deviance,MixMod) diff --git a/R/get_datagrid.R b/R/get_datagrid.R index 63f18b9bb..8fec20079 100644 --- a/R/get_datagrid.R +++ b/R/get_datagrid.R @@ -1,9 +1,13 @@ #' Create a reference grid #' #' Create a reference matrix, useful for visualisation, with evenly spread and -#' combined values. Usually used to make generate predictions using [get_predicted()]. -#' See this [vignette](https://easystats.github.io/modelbased/articles/visualisation_matrix.html) +#' combined values. Usually used to make generate predictions using +#' [get_predicted()]. See this +#' [vignette](https://easystats.github.io/modelbased/articles/visualisation_matrix.html) #' for a tutorial on how to create a visualisation matrix using this function. +#' \cr\cr +#' Alternatively, these can also be used to extract the "grid" columns from +#' objects generated by `emmeans` and `marginaleffects`. #' #' @param x An object from which to construct the reference grid. #' @param by Indicates the _focal predictors_ (variables) for the reference grid @@ -181,9 +185,7 @@ get_datagrid <- function(x, ...) { } -# ------------------------------------------------------------------------- -# Below are get_datagrid functions for DataFrames -# ------------------------------------------------------------------------- +# Functions for data.frames ----------------------------------------------- #' @rdname get_datagrid #' @export @@ -406,71 +408,10 @@ get_datagrid.data.frame <- function(x, - - - - - - - -# Utils ------------------------------------------------------------------- - -#' @keywords internal -.get_datagrid_summary <- function(x, numerics = "mean", factors = "reference", na.rm = TRUE, ...) { - if (na.rm) x <- stats::na.omit(x) - - if (is.numeric(x)) { - if (is.numeric(numerics)) { - out <- numerics - } else if (numerics %in% c("all", "combination")) { - out <- unique(x) - } else { - out <- eval(parse(text = paste0(numerics, "(x)"))) - } - } else if (factors %in% c("all", "combination")) { - out <- unique(x) - } else if (factors == "mode") { - # Get mode - out <- names(sort(table(x), decreasing = TRUE)[1]) - } else { - # Get reference - if (is.factor(x)) { - all_levels <- levels(x) - } else if (is.character(x) || is.logical(x)) { - all_levels <- unique(x) - } else { - format_error(paste0( - "Argument is not numeric nor factor but ", class(x), ".", - "Please report the bug at https://github.com/easystats/insight/issues" - )) - } - # see "get_modelmatrix()" and #626. Reference level is currently - # a character vector, which causes the error - # > Error in `contrasts<-`(`*tmp*`, value = contr.funs[1 + isOF[nn]]) : - # > contrasts can be applied only to factors with 2 or more levels - # this is usually avoided by calling ".pad_modelmatrix()", but this - # function ignores character vectors. so we need to make sure that this - # factor level is also of class factor. - out <- factor(all_levels[1]) - # although we have reference level only, we still need information - # about all levels, see #695 - levels(out) <- all_levels - } - out -} - - - - - -# ------------------------------------------------------------------------- -# Below are get_datagrid functions that work on a vector (a single column) +# Functions that work on a vector (a single column) ---------------------- # See tests/test-get_datagrid.R for examples -# ------------------------------------------------------------------------- - - -# Numeric ----------------------------------------------------------------- +## Numeric ------------------------------------ #' @rdname get_datagrid #' @export @@ -503,71 +444,10 @@ get_datagrid.numeric <- function(x, length = 10, range = "range", ...) { get_datagrid.double <- get_datagrid.numeric -#' @keywords internal -.create_spread <- function(x, length = 10, range = "range", ci = 0.95, ...) { - range <- match.arg(tolower(range), c("range", "iqr", "ci", "hdi", "eti", "sd", "mad", "grid")) - - # bayestestR only for some options - if (range %in% c("ci", "hdi", "eti")) { - check_if_installed("bayestestR") - } - - # check if range = "grid" - then use mean/sd for every numeric that - # is not first predictor... - if (range == "grid") { - range <- "sd" - if (isFALSE(list(...)$is_first_predictor)) { - length <- 3 - } - } - # If Range is a dispersion (e.g., SD or MAD) - if (range %in% c("sd", "mad")) { - spread <- -floor((length - 1) / 2):ceiling((length - 1) / 2) - if (range == "sd") { - disp <- stats::sd(x, na.rm = TRUE) - center <- mean(x, na.rm = TRUE) - labs <- ifelse(sign(spread) == -1, paste(spread, "SD"), - ifelse(sign(spread) == 1, paste0("+", spread, " SD"), "Mean") # nolint - ) - } else { - disp <- stats::mad(x, na.rm = TRUE) - center <- stats::median(x, na.rm = TRUE) - labs <- ifelse(sign(spread) == -1, paste(spread, "MAD"), - ifelse(sign(spread) == 1, paste0("+", spread, " MAD"), "Median") # nolint - ) - } - out <- center + spread * disp - names(out) <- labs - return(out) - } - # If Range is an interval - if (range == "iqr") { # nolint - mini <- stats::quantile(x, (1 - ci) / 2, ...) - maxi <- stats::quantile(x, (1 + ci) / 2, ...) - } else if (range == "ci") { - out <- bayestestR::ci(x, ci = ci, verbose = FALSE, ...) - mini <- out$CI_low - maxi <- out$CI_high - } else if (range == "eti") { - out <- bayestestR::eti(x, ci = ci, verbose = FALSE, ...) - mini <- out$CI_low - maxi <- out$CI_high - } else if (range == "hdi") { - out <- bayestestR::hdi(x, ci = ci, verbose = FALSE, ...) - mini <- out$CI_low - maxi <- out$CI_high - } else { - mini <- min(x, na.rm = TRUE) - maxi <- max(x, na.rm = TRUE) - } - seq(mini, maxi, length.out = length) -} - - -# Factors & Characters ---------------------------------------------------- +## Factors & Characters ---------------------------------------------------- #' @rdname get_datagrid @@ -599,136 +479,7 @@ get_datagrid.logical <- get_datagrid.character -# Utilities ----------------------------------------------------------------- - -#' @keywords internal -.get_datagrid_clean_target <- function(x, by = NULL, ...) { - by_expression <- NA - varname <- NA - original_target <- by - - if (!is.null(by)) { - if (is.data.frame(x) && by %in% names(x)) { - return(data.frame(varname = by, expression = NA)) - } - - # If there is an equal sign - if (grepl("length.out =", by, fixed = TRUE)) { - by_expression <- by # This is an edgecase - } else if (grepl("=", by, fixed = TRUE)) { - parts <- trim_ws(unlist(strsplit(by, "=", fixed = TRUE), use.names = FALSE)) # Split and clean - varname <- parts[1] # left-hand part is probably the name of the variable - by <- parts[2] # right-hand part is the real target - } - - if (is.na(by_expression) && is.data.frame(x)) { - if (is.na(varname)) { - format_error( - "Couldn't find which variable were selected in `by`. Check spelling and specification." - ) - } else { - x <- x[[varname]] - } - } - - # If brackets are detected [a, b] - if (is.na(by_expression) && grepl("\\[.*\\]", by)) { - # Clean -------------------- - # Keep the content - parts <- trim_ws(unlist(regmatches(by, gregexpr("\\[.+?\\]", by)), use.names = FALSE)) - # Drop the brackets - parts <- gsub("\\[|\\]", "", parts) - # Split by a separator like ',' - parts <- trim_ws(unlist(strsplit(parts, ",", fixed = TRUE), use.names = FALSE)) - # If the elements have quotes around them, drop them - if (all(grepl("\\'.*\\'", parts))) parts <- gsub("'", "", parts, fixed = TRUE) - if (all(grepl('\\".*\\"', parts))) parts <- gsub('"', "", parts, fixed = TRUE) - - # Make expression ---------- - if (is.factor(x) || is.character(x)) { - # Factor - # Add quotes around them - parts <- paste0("'", parts, "'") - # Convert to character - by_expression <- paste0("as.factor(c(", toString(parts), "))") - } else if (length(parts) == 1) { - # Numeric - # If one, might be a shortcut - shortcuts <- c( - "meansd", "sd", "mad", "quartiles", "quartiles2", "zeromax", - "minmax", "terciles", "terciles2", "fivenum" - ) - if (parts %in% shortcuts) { - if (parts %in% c("meansd", "sd")) { - center <- mean(x, na.rm = TRUE) - spread <- stats::sd(x, na.rm = TRUE) - by_expression <- paste0("c(", center - spread, ",", center, ",", center + spread, ")") - } else if (parts == "mad") { - center <- stats::median(x, na.rm = TRUE) - spread <- stats::mad(x, na.rm = TRUE) - by_expression <- paste0("c(", center - spread, ",", center, ",", center + spread, ")") - } else if (parts == "quartiles") { - by_expression <- paste0("c(", paste(as.vector(stats::quantile(x, na.rm = TRUE)), collapse = ","), ")") - } else if (parts == "quartiles2") { - by_expression <- paste0("c(", paste(as.vector(stats::quantile(x, na.rm = TRUE))[2:4], collapse = ","), ")") - } else if (parts == "terciles") { - by_expression <- paste0("c(", paste(as.vector(stats::quantile(x, probs = (1:2) / 3, na.rm = TRUE)), collapse = ","), ")") # nolint - } else if (parts == "terciles2") { - by_expression <- paste0("c(", paste(as.vector(stats::quantile(x, probs = (0:3) / 3, na.rm = TRUE)), collapse = ","), ")") # nolint - } else if (parts == "fivenum") { - by_expression <- paste0("c(", paste(as.vector(stats::fivenum(x, na.rm = TRUE)), collapse = ","), ")") - } else if (parts == "zeromax") { - by_expression <- paste0("c(0,", max(x, na.rm = TRUE), ")") - } else if (parts == "minmax") { - by_expression <- paste0("c(", min(x, na.rm = TRUE), ",", max(x, na.rm = TRUE), ")") - } - } else if (is.numeric(parts)) { - by_expression <- parts - } else { - format_error( - paste0( - "The `by` argument (", by, ") should either indicate the minimum and the maximum, or one of the following options: ", # nolint - toString(shortcuts), - "." - ) - ) - } - # If only two, it's probably the range - } else if (length(parts) == 2) { - by_expression <- paste0("seq(", parts[1], ", ", parts[2], ", length.out = length)") - # If more, it's probably the vector - } else if (length(parts) > 2L) { - parts <- as.numeric(parts) - by_expression <- paste0("c(", toString(parts), ")") - } - # Else, try to directly eval the content - } else { - by_expression <- by - # Try to eval and make sure it works - tryCatch( - { - # This is just to make sure that an expression with `length` in - # it doesn't fail because of this undefined var - length <- 10 # nolint - eval(parse(text = by)) - }, - error = function(r) { - format_error( - paste0("The `by` argument (`", original_target, "`) cannot be read and could be mispecified.") - ) - } - ) - } - } - data.frame(varname = varname, expression = by_expression, stringsAsFactors = FALSE) -} - - - - -# ------------------------------------------------------------------------- -# Below are get_datagrid functions that work on statistical models -# ------------------------------------------------------------------------- +# Functions that work on statistical models ------------------------------- #' @rdname get_datagrid #' @export @@ -886,9 +637,7 @@ get_datagrid.wbm <- function(x, -# ------------------------------------------------------------------------- -# Below are get_datagrid functions that work on get_datagrid -# ------------------------------------------------------------------------- +# Functions that work on get_datagrid ------------------------------------- #' @export get_datagrid.visualisation_matrix <- function(x, reference = attributes(x)$reference, ...) { @@ -908,9 +657,287 @@ get_datagrid.datagrid <- get_datagrid.visualisation_matrix +# Functions for emmeans/marginaleffects --------------- -# helper ----------------- +#' @rdname get_datagrid +#' @export +get_datagrid.emmGrid <- function(x, ...) { + suppressWarnings(s <- as.data.frame(x)) + # We want all the columns *before* the estimate column + est_col_idx <- which(colnames(s) == attr(s, "estName")) + which_cols <- seq_len(est_col_idx - 1) + + data.frame(s)[, which_cols, drop = FALSE] +} + +#' @export +get_datagrid.emm_list <- function(x, ...) { + k <- length(x) + res <- vector("list", length = k) + for (i in seq_len(k)) { + res[[i]] <- get_datagrid(x[[i]]) + } + all_cols <- Reduce(lapply(res, colnames), f = union) + for (i in seq_len(k)) { + res[[i]][,setdiff(all_cols, colnames(res[[i]]))] <- NA + } + out <- do.call("rbind", res) + + clear_cols <- colnames(out)[sapply(out, Negate(anyNA))] # these should be first + out[,c(clear_cols, setdiff(colnames(out), clear_cols)), drop = FALSE] +} + +#' @rdname get_datagrid +#' @export +get_datagrid.slopes <- function(x, ...) { + cols_newdata <- colnames(attr(x, "newdata")) + cols_contrast <- colnames(x)[grep("^contrast_?", colnames(x))] + cols_misc <- c("by", "hypothesis") + cols_grid <- union(union(cols_newdata, cols_contrast), cols_misc) + + data.frame(x)[, intersect(colnames(x), cols_grid), drop = FALSE] +} + +#' @export +get_datagrid.predictions <- get_datagrid.slopes + +#' @export +get_datagrid.comparisons <- get_datagrid.slopes + + +# Utilities ----------------------------------------------------------------- + +#' @keywords internal +.get_datagrid_clean_target <- function(x, by = NULL, ...) { + by_expression <- NA + varname <- NA + original_target <- by + + if (!is.null(by)) { + if (is.data.frame(x) && by %in% names(x)) { + return(data.frame(varname = by, expression = NA)) + } + + # If there is an equal sign + if (grepl("length.out =", by, fixed = TRUE)) { + by_expression <- by # This is an edgecase + } else if (grepl("=", by, fixed = TRUE)) { + parts <- trim_ws(unlist(strsplit(by, "=", fixed = TRUE), use.names = FALSE)) # Split and clean + varname <- parts[1] # left-hand part is probably the name of the variable + by <- parts[2] # right-hand part is the real target + } + + if (is.na(by_expression) && is.data.frame(x)) { + if (is.na(varname)) { + format_error( + "Couldn't find which variable were selected in `by`. Check spelling and specification." + ) + } else { + x <- x[[varname]] + } + } + + # If brackets are detected [a, b] + if (is.na(by_expression) && grepl("\\[.*\\]", by)) { + # Clean -------------------- + # Keep the content + parts <- trim_ws(unlist(regmatches(by, gregexpr("\\[.+?\\]", by)), use.names = FALSE)) + # Drop the brackets + parts <- gsub("\\[|\\]", "", parts) + # Split by a separator like ',' + parts <- trim_ws(unlist(strsplit(parts, ",", fixed = TRUE), use.names = FALSE)) + # If the elements have quotes around them, drop them + if (all(grepl("\\'.*\\'", parts))) parts <- gsub("'", "", parts, fixed = TRUE) + if (all(grepl('\\".*\\"', parts))) parts <- gsub('"', "", parts, fixed = TRUE) + + # Make expression ---------- + if (is.factor(x) || is.character(x)) { + # Factor + # Add quotes around them + parts <- paste0("'", parts, "'") + # Convert to character + by_expression <- paste0("as.factor(c(", toString(parts), "))") + } else if (length(parts) == 1) { + # Numeric + # If one, might be a shortcut + shortcuts <- c( + "meansd", "sd", "mad", "quartiles", "quartiles2", "zeromax", + "minmax", "terciles", "terciles2", "fivenum" + ) + if (parts %in% shortcuts) { + if (parts %in% c("meansd", "sd")) { + center <- mean(x, na.rm = TRUE) + spread <- stats::sd(x, na.rm = TRUE) + by_expression <- paste0("c(", center - spread, ",", center, ",", center + spread, ")") + } else if (parts == "mad") { + center <- stats::median(x, na.rm = TRUE) + spread <- stats::mad(x, na.rm = TRUE) + by_expression <- paste0("c(", center - spread, ",", center, ",", center + spread, ")") + } else if (parts == "quartiles") { + by_expression <- paste0("c(", paste(as.vector(stats::quantile(x, na.rm = TRUE)), collapse = ","), ")") + } else if (parts == "quartiles2") { + by_expression <- paste0("c(", paste(as.vector(stats::quantile(x, na.rm = TRUE))[2:4], collapse = ","), ")") + } else if (parts == "terciles") { + by_expression <- paste0("c(", paste(as.vector(stats::quantile(x, probs = (1:2) / 3, na.rm = TRUE)), collapse = ","), ")") # nolint + } else if (parts == "terciles2") { + by_expression <- paste0("c(", paste(as.vector(stats::quantile(x, probs = (0:3) / 3, na.rm = TRUE)), collapse = ","), ")") # nolint + } else if (parts == "fivenum") { + by_expression <- paste0("c(", paste(as.vector(stats::fivenum(x, na.rm = TRUE)), collapse = ","), ")") + } else if (parts == "zeromax") { + by_expression <- paste0("c(0,", max(x, na.rm = TRUE), ")") + } else if (parts == "minmax") { + by_expression <- paste0("c(", min(x, na.rm = TRUE), ",", max(x, na.rm = TRUE), ")") + } + } else if (is.numeric(parts)) { + by_expression <- parts + } else { + format_error( + paste0( + "The `by` argument (", by, ") should either indicate the minimum and the maximum, or one of the following options: ", # nolint + toString(shortcuts), + "." + ) + ) + } + # If only two, it's probably the range + } else if (length(parts) == 2) { + by_expression <- paste0("seq(", parts[1], ", ", parts[2], ", length.out = length)") + # If more, it's probably the vector + } else if (length(parts) > 2L) { + parts <- as.numeric(parts) + by_expression <- paste0("c(", toString(parts), ")") + } + # Else, try to directly eval the content + } else { + by_expression <- by + # Try to eval and make sure it works + tryCatch( + { + # This is just to make sure that an expression with `length` in + # it doesn't fail because of this undefined var + length <- 10 # nolint + eval(parse(text = by)) + }, + error = function(r) { + format_error( + paste0("The `by` argument (`", original_target, "`) cannot be read and could be mispecified.") + ) + } + ) + } + } + data.frame(varname = varname, expression = by_expression, stringsAsFactors = FALSE) +} + +#' @keywords internal +.get_datagrid_summary <- function(x, numerics = "mean", factors = "reference", na.rm = TRUE, ...) { + if (na.rm) x <- stats::na.omit(x) + + if (is.numeric(x)) { + if (is.numeric(numerics)) { + out <- numerics + } else if (numerics %in% c("all", "combination")) { + out <- unique(x) + } else { + out <- eval(parse(text = paste0(numerics, "(x)"))) + } + } else if (factors %in% c("all", "combination")) { + out <- unique(x) + } else if (factors == "mode") { + # Get mode + out <- names(sort(table(x), decreasing = TRUE)[1]) + } else { + # Get reference + if (is.factor(x)) { + all_levels <- levels(x) + } else if (is.character(x) || is.logical(x)) { + all_levels <- unique(x) + } else { + format_error(paste0( + "Argument is not numeric nor factor but ", class(x), ".", + "Please report the bug at https://github.com/easystats/insight/issues" + )) + } + # see "get_modelmatrix()" and #626. Reference level is currently + # a character vector, which causes the error + # > Error in `contrasts<-`(`*tmp*`, value = contr.funs[1 + isOF[nn]]) : + # > contrasts can be applied only to factors with 2 or more levels + # this is usually avoided by calling ".pad_modelmatrix()", but this + # function ignores character vectors. so we need to make sure that this + # factor level is also of class factor. + out <- factor(all_levels[1]) + # although we have reference level only, we still need information + # about all levels, see #695 + levels(out) <- all_levels + } + out +} + +#' @keywords internal +.create_spread <- function(x, length = 10, range = "range", ci = 0.95, ...) { + range <- match.arg(tolower(range), c("range", "iqr", "ci", "hdi", "eti", "sd", "mad", "grid")) + + # bayestestR only for some options + if (range %in% c("ci", "hdi", "eti")) { + check_if_installed("bayestestR") + } + + # check if range = "grid" - then use mean/sd for every numeric that + # is not first predictor... + if (range == "grid") { + range <- "sd" + if (isFALSE(list(...)$is_first_predictor)) { + length <- 3 + } + } + + # If Range is a dispersion (e.g., SD or MAD) + if (range %in% c("sd", "mad")) { + spread <- -floor((length - 1) / 2):ceiling((length - 1) / 2) + if (range == "sd") { + disp <- stats::sd(x, na.rm = TRUE) + center <- mean(x, na.rm = TRUE) + labs <- ifelse(sign(spread) == -1, paste(spread, "SD"), + ifelse(sign(spread) == 1, paste0("+", spread, " SD"), "Mean") # nolint + ) + } else { + disp <- stats::mad(x, na.rm = TRUE) + center <- stats::median(x, na.rm = TRUE) + labs <- ifelse(sign(spread) == -1, paste(spread, "MAD"), + ifelse(sign(spread) == 1, paste0("+", spread, " MAD"), "Median") # nolint + ) + } + out <- center + spread * disp + names(out) <- labs + + return(out) + } + + # If Range is an interval + if (range == "iqr") { # nolint + mini <- stats::quantile(x, (1 - ci) / 2, ...) + maxi <- stats::quantile(x, (1 + ci) / 2, ...) + } else if (range == "ci") { + out <- bayestestR::ci(x, ci = ci, verbose = FALSE, ...) + mini <- out$CI_low + maxi <- out$CI_high + } else if (range == "eti") { + out <- bayestestR::eti(x, ci = ci, verbose = FALSE, ...) + mini <- out$CI_low + maxi <- out$CI_high + } else if (range == "hdi") { + out <- bayestestR::hdi(x, ci = ci, verbose = FALSE, ...) + mini <- out$CI_low + maxi <- out$CI_high + } else { + mini <- min(x, na.rm = TRUE) + maxi <- max(x, na.rm = TRUE) + } + seq(mini, maxi, length.out = length) +} + +#' @keywords internal .data_match <- function(x, to, ...) { if (!is.data.frame(to)) { to <- as.data.frame(to) @@ -924,7 +951,7 @@ get_datagrid.datagrid <- get_datagrid.visualisation_matrix .to_numeric(row.names(x)[idx]) } - +#' @keywords internal .get_model_data_for_grid <- function(x, data) { # Retrieve data, based on variable names if (is.null(data)) { @@ -973,7 +1000,7 @@ get_datagrid.datagrid <- get_datagrid.visualisation_matrix } - +#' @keywords internal .extract_at_interactions <- function(by) { # get interaction terms, but only if these are not inside brackets (like "[4:8]") interaction_terms <- grepl("(:|\\*)(?![^\\[]*\\])", by, perl = TRUE) @@ -987,9 +1014,10 @@ get_datagrid.datagrid <- get_datagrid.visualisation_matrix } +#' @keywords internal .replace_attr <- function(data, custom_attr) { for (nm in setdiff(names(custom_attr), names(attributes(data.frame())))) { attr(data, which = nm) <- custom_attr[[nm]] } data -} +} \ No newline at end of file diff --git a/man/get_datagrid.Rd b/man/get_datagrid.Rd index cd3607334..6799424f5 100644 --- a/man/get_datagrid.Rd +++ b/man/get_datagrid.Rd @@ -6,6 +6,8 @@ \alias{get_datagrid.numeric} \alias{get_datagrid.factor} \alias{get_datagrid.default} +\alias{get_datagrid.emmGrid} +\alias{get_datagrid.slopes} \title{Create a reference grid} \usage{ get_datagrid(x, ...) @@ -42,6 +44,10 @@ get_datagrid(x, ...) at, ... ) + +\method{get_datagrid}{emmGrid}(x, ...) + +\method{get_datagrid}{slopes}(x, ...) } \arguments{ \item{x}{An object from which to construct the reference grid.} @@ -176,9 +182,13 @@ Reference grid data frame. } \description{ Create a reference matrix, useful for visualisation, with evenly spread and -combined values. Usually used to make generate predictions using \code{\link[=get_predicted]{get_predicted()}}. -See this \href{https://easystats.github.io/modelbased/articles/visualisation_matrix.html}{vignette} +combined values. Usually used to make generate predictions using +\code{\link[=get_predicted]{get_predicted()}}. See this +\href{https://easystats.github.io/modelbased/articles/visualisation_matrix.html}{vignette} for a tutorial on how to create a visualisation matrix using this function. +\cr\cr +Alternatively, these can also be used to extract the "grid" columns from +objects generated by \code{emmeans} and \code{marginaleffects}. } \examples{ \dontshow{if (require("bayestestR", quietly = TRUE) && require("datawizard", quietly = TRUE)) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} diff --git a/tests/testthat/test-get_datagrid.R b/tests/testthat/test-get_datagrid.R index d6137f545..a3a89417e 100644 --- a/tests/testthat/test-get_datagrid.R +++ b/tests/testthat/test-get_datagrid.R @@ -213,8 +213,92 @@ test_that("get_datagrid - models", { expect_identical(dim(get_datagrid(mod, include_random = FALSE, include_smooth = FALSE)), as.integer(c(10, 1))) }) +test_that("get_datagrid - emmeans", { + skip_if_not_installed("emmeans") + skip_on_cran() + + data("mtcars") + + mod1 <- glm(am ~ hp + factor(cyl), family = binomial("logit"), data = mtcars) + mod2 <- lm(mpg ~ hp + factor(cyl), data = mtcars) + + hp_vals <- c(50, 100) + + em1 <- emmeans::emmeans(mod1, ~ cyl | hp, at = list(hp = hp_vals)) + em2 <- emmeans::regrid(em1) + em3 <- emmeans::emmeans(mod2, ~ cyl | hp, at = list(hp = hp_vals)) + + res <- get_datagrid(em1) + expect_equal(res, get_datagrid(em2)) + expect_equal(res, get_datagrid(em3)) + expect_s3_class(res, "data.frame") + expect_equal(dim(res), c(6, 2)) + expect_true(all(c(4, 6, 8) %in% res[[1]])) + expect_true(all(hp_vals %in% res[[2]])) + + res <- get_datagrid(emmeans::contrast(em1, method = "poly", max.degree = 2)) + expect_s3_class(res, "data.frame") + expect_equal(dim(res), c(4, 2)) + expect_true("contrast" %in% colnames(res)) + expect_true(all(c("linear", "quadratic") %in% res[["contrast"]])) + expect_true(all(hp_vals %in% res[["hp"]])) + + # emm_list + em1 <- emmeans::emmeans(mod1, pairwise ~ cyl | hp, at = list(hp = hp_vals)) + em2 <- emmeans::emmeans(mod1, pairwise ~ cyl | hp, at = list(hp = hp_vals), regrid = TRUE) + em3 <- emmeans::emmeans(mod2, pairwise ~ cyl | hp, at = list(hp = hp_vals)) + + res <- get_datagrid(em1) + expect_equal(res, get_datagrid(em2)) + expect_equal(res, get_datagrid(em3)) + expect_s3_class(res, "data.frame") + expect_equal(dim(res), c(12, 3)) + expect_true("contrast" %in% colnames(res)) + expect_true(anyNA(res[["contrast"]])) + expect_true(all(c(4, 6, 8, NA) %in% res[["cyl"]])) + expect_true(all(hp_vals %in% res[["hp"]])) +}) + +test_that("get_datagrid - marginaleffects", { + skip_if_not_installed("marginaleffects") + skip_on_cran() + + data("mtcars") + + mod1 <- glm(am ~ hp + factor(cyl), family = binomial("logit"), data = mtcars) + mod2 <- lm(mpg ~ hp + factor(cyl), data = mtcars) + + mp1 <- marginaleffects::avg_predictions(mod1, variables = list("hp" = c(50, 100)), + by = c("cyl", "hp")) + mp2 <- marginaleffects::avg_predictions(mod2, variables = list("hp" = c(50, 100), + cyl = unique)) + + res <- get_datagrid(mp1) + expect_s3_class(res, "data.frame") + expect_equal(dim(res), c(6, 2)) + expect_true(all(c(4, 6, 8) %in% res[[1]])) + expect_true(all(c(50, 100) %in% res[[2]])) + + res2 <- get_datagrid(mp2) + expect_s3_class(res2, "data.frame") + expect_equal(dim(res2), c(6, 2)) + expect_true(all(c(4, 6, 8) %in% res2[[2]])) + expect_true(all(c(50, 100) %in% res2[[1]])) + + + + + mod <- lm(mpg ~ wt + hp + qsec, data = mtcars) + myme <- marginaleffects::comparisons(mod, + variables = c("wt", "hp"), + cross = TRUE, + newdata = marginaleffects::datagrid(qsec = range)) + res <- get_datagrid(myme) + expect_true(all(c("wt", "mpg", "hp", "qsec") %in% colnames(res))) + expect_true(all(c("contrast_hp", "contrast_wt") %in% colnames(res))) +}) -test_that("factor levels as reference / non-focal terms works", { +test_that("get_datagrid - factor levels as reference / non-focal terms works", { d <- structure(list( lfp = structure(c( 2L, 2L, 2L, 2L, 2L, 2L, 2L,