Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add plot-method for check_dag() #352

Merged
merged 7 commits into from
Aug 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Type: Package
Package: see
Title: Model Visualisation Toolbox for 'easystats' and 'ggplot2'
Version: 0.8.5
Version: 0.8.5.1
Authors@R:
c(person(given = "Daniel",
family = "Lüdecke",
Expand Down Expand Up @@ -78,6 +78,7 @@ Suggests:
DHARMa,
emmeans,
factoextra,
ggdag,
ggdist,
ggraph,
ggrepel,
Expand Down Expand Up @@ -119,3 +120,4 @@ Config/testthat/edition: 3
Config/testthat/parallel: true
Config/Needs/website: easystats/easystatstemplate
Config/rcmdcheck/ignore-inconsequential-notes: true
Remotes: easystats/performance#761
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ S3method(plot,see_bayesfactor_parameters)
S3method(plot,see_bayesfactor_savagedickey)
S3method(plot,see_binned_residuals)
S3method(plot,see_check_collinearity)
S3method(plot,see_check_dag)
S3method(plot,see_check_distribution)
S3method(plot,see_check_distribution_numeric)
S3method(plot,see_check_heteroscedasticity)
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# see (development)

## Changes

- New `plot()` method for `performance::check_dag()`.

# see 0.8.5

## Major Changes
Expand Down
132 changes: 132 additions & 0 deletions R/plot.check_dag.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
#' Plot method for check DAGs
#'
#' The `plot()` method for the `performance::check_dag()` function.
#'
#' @param x A `check_dag` object.
#' @param size_point Numeric value specifying size of point geoms.
#' @param colors Character vector of length five, indicating the colors (in
#' hex-format) for different types of variables.
#' @param which Character string indicating which plot to show. Can be either
#' `"all"`, `"current"` or `"required"`.
#' @param check_colliders Logical indicating whether to highlight colliders.
#' Set to `FALSE` if the algorithm to detect colliders is very slow.
#' @param ... Not used.
#'
#' @return A ggplot2-object.
#'
#' @examplesIf require("ggdag", quietly = TRUE)
#' library(performance)
#' # incorrect adjustment
#' dag <- check_dag(
#' y ~ x + b + c,
#' x ~ b,
#' outcome = "y",
#' exposure = "x"
#' )
#' dag
#' plot(dag)
#'
#' # plot only model with required adjustments
#' plot(dag, which = "required")
#'
#' # collider-bias?
#' dag <- check_dag(
#' y ~ x + c + d,
#' x ~ c + d,
#' b ~ x,
#' b ~ y,
#' outcome = "y",
#' exposure = "x",
#' adjusted = "c"
#' )
#' plot(dag)
#' @export
plot.see_check_dag <- function(x,
size_point = 15,
colors = NULL,
which = "all",
check_colliders = TRUE,
...) {
.data <- NULL
insight::check_if_installed(c("ggdag", "ggplot2"))
which <- match.arg(which, choices = c("all", "current", "required"))

# get plot data
p1 <- p2 <- suppressWarnings(ggdag::dag_adjustment_sets(x))
adjusted_for <- attributes(x)$adjusted

# for current plot, we need to update the "adjusted" column
p1$data$adjusted <- "unadjusted"
if (!is.null(adjusted_for)) {
p1$data$adjusted[p1$data$name %in% adjusted_for] <- "adjusted"
}

# tweak data
p1$data$type <- as.character(p1$data$adjusted)
if (check_colliders) {
p1$data$type[vapply(p1$data$name, ggdag::is_collider, logical(1), .dag = x)] <- "collider"
}
p1$data$type[p1$data$name == attributes(x)$outcome] <- "outcome"
p1$data$type[p1$data$name %in% attributes(x)$exposure] <- "exposure"
p1$data$type <- factor(p1$data$type, levels = c("outcome", "exposure", "adjusted", "unadjusted", "collider"))

p2$data$type <- as.character(p2$data$adjusted)
if (check_colliders) {
p2$data$type[vapply(p2$data$name, ggdag::is_collider, logical(1), .dag = x)] <- "collider"
}
p2$data$type[p2$data$name == attributes(x)$outcome] <- "outcome"
p2$data$type[p2$data$name %in% attributes(x)$exposure] <- "exposure"
p2$data$type <- factor(p2$data$type, levels = c("outcome", "exposure", "adjusted", "unadjusted", "collider"))

if (is.null(colors)) {
point_colors <- see_colors(c("yellow", "cyan", "blue grey", "red", "orange"))
} else if (length(colors) != 5) {
insight::format_error("`colors` must be a character vector with five color-values.")
} else {
point_colors <- colors
}
names(point_colors) <- c("outcome", "exposure", "adjusted", "unadjusted", "collider")

plot1 <- ggplot2::ggplot(p1$data, ggplot2::aes(x = .data$x, y = .data$y)) +
geom_point_borderless(ggplot2::aes(fill = .data$type), size = size_point) +
ggdag::geom_dag_edges(
ggplot2::aes(
xend = .data$xend,
yend = .data$yend,
edge_alpha = .data$adjusted
)
) +
ggdag::scale_adjusted() +
ggdag::geom_dag_label(ggplot2::aes(label = .data$name)) +
ggdag::theme_dag() +
ggplot2::scale_fill_manual(values = point_colors) +
ggplot2::ggtitle("Current model") +
ggplot2::guides(edge_alpha = "none")

plot2 <- ggplot2::ggplot(p2$data, ggplot2::aes(x = .data$x, y = .data$y)) +
geom_point_borderless(ggplot2::aes(fill = .data$type), size = size_point) +
ggdag::geom_dag_edges(
ggplot2::aes(
xend = .data$xend,
yend = .data$yend,
edge_alpha = .data$adjusted
)
) +
ggdag::scale_adjusted() +
ggdag::geom_dag_label(ggplot2::aes(label = .data$name)) +
ggdag::theme_dag() +
ggplot2::scale_fill_manual(values = point_colors) +
ggplot2::ggtitle("Required model") +
ggplot2::guides(edge_alpha = "none")

if (which == "all") {
# fix legends
plot2 <- plot2 + ggplot2::theme(legend.position = "none")
# plot
plots(plot1, plot2, n_rows = 1)
} else if (which == "current") {
plot1
} else {
plot2
}
}
66 changes: 66 additions & 0 deletions man/plot.see_check_dag.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading