From c11b9003ea615b3fd7742d16d435b0b1cd8710dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kirill=20M=C3=BCller?= Date: Thu, 7 Mar 2024 14:49:10 +0100 Subject: [PATCH] n_groups --- NAMESPACE | 2 ++ R/duckplyr-package.R | 1 + R/n_groups.R | 32 ++++++++++++++++++++++++++++ R/overwrite.R | 1 + R/restore.R | 1 + dplyr-methods/n_groups.txt | 3 +++ tests/testthat/test-as_duckplyr_df.R | 16 ++++++++++++++ tools/00-funs.R | 3 ++- 8 files changed, 58 insertions(+), 1 deletion(-) create mode 100644 R/n_groups.R create mode 100644 dplyr-methods/n_groups.txt diff --git a/NAMESPACE b/NAMESPACE index 774a343b..87ccf4d7 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -31,6 +31,7 @@ S3method(inner_join,duckplyr_df) S3method(intersect,duckplyr_df) S3method(left_join,duckplyr_df) S3method(mutate,duckplyr_df) +S3method(n_groups,duckplyr_df) S3method(nest_by,duckplyr_df) S3method(nest_join,duckplyr_df) S3method(print,relational_relexpr) @@ -283,6 +284,7 @@ importFrom(dplyr,mutate) importFrom(dplyr,mutate_all) importFrom(dplyr,n) importFrom(dplyr,n_distinct) +importFrom(dplyr,n_groups) importFrom(dplyr,nest_by) importFrom(dplyr,nest_join) importFrom(dplyr,nth) diff --git a/R/duckplyr-package.R b/R/duckplyr-package.R index 62a1ea25..d725b5b1 100644 --- a/R/duckplyr-package.R +++ b/R/duckplyr-package.R @@ -15,6 +15,7 @@ #' @importFrom dplyr group_trim #' @importFrom dplyr grouped_df #' @importFrom dplyr if_else +#' @importFrom dplyr n_groups #' @importFrom glue glue #' @importFrom lifecycle deprecated #' @importFrom tibble as_tibble diff --git a/R/n_groups.R b/R/n_groups.R new file mode 100644 index 00000000..dedcc5e8 --- /dev/null +++ b/R/n_groups.R @@ -0,0 +1,32 @@ +# Generated by 02-duckplyr_df-methods.R +#' @export +n_groups.duckplyr_df <- function(x) { + # Our implementation + rel_try( + # Always fall back to dplyr + "No relational implementation for n_groups()" = TRUE, + { + return(out) + } + ) + + # dplyr forward + n_groups <- dplyr$n_groups.data.frame + out <- n_groups(x) + return(out) + + # dplyr implementation + nrow(group_data(x)) +} + +duckplyr_n_groups <- function(x, ...) { + try_fetch( + x <- as_duckplyr_df(x), + error = function(e) { + testthat::skip(conditionMessage(e)) + } + ) + out <- n_groups(x, ...) + class(out) <- setdiff(class(out), "duckplyr_df") + out +} diff --git a/R/overwrite.R b/R/overwrite.R index b75fe840..5b42c3b0 100644 --- a/R/overwrite.R +++ b/R/overwrite.R @@ -30,6 +30,7 @@ methods_overwrite <- function() { vctrs::s3_register("dplyr::intersect", "data.frame", intersect.duckplyr_df) vctrs::s3_register("dplyr::left_join", "data.frame", left_join.duckplyr_df) vctrs::s3_register("dplyr::mutate", "data.frame", mutate.duckplyr_df) + vctrs::s3_register("dplyr::n_groups", "data.frame", n_groups.duckplyr_df) vctrs::s3_register("dplyr::nest_by", "data.frame", nest_by.duckplyr_df) vctrs::s3_register("dplyr::nest_join", "data.frame", nest_join.duckplyr_df) vctrs::s3_register("dplyr::pull", "data.frame", pull.duckplyr_df) diff --git a/R/restore.R b/R/restore.R index 8640755b..c0109556 100644 --- a/R/restore.R +++ b/R/restore.R @@ -30,6 +30,7 @@ methods_restore <- function() { vctrs::s3_register("dplyr::intersect", "data.frame", dplyr$intersect.data.frame) vctrs::s3_register("dplyr::left_join", "data.frame", dplyr$left_join.data.frame) vctrs::s3_register("dplyr::mutate", "data.frame", dplyr$mutate.data.frame) + vctrs::s3_register("dplyr::n_groups", "data.frame", dplyr$n_groups.data.frame) vctrs::s3_register("dplyr::nest_by", "data.frame", dplyr$nest_by.data.frame) vctrs::s3_register("dplyr::nest_join", "data.frame", dplyr$nest_join.data.frame) vctrs::s3_register("dplyr::pull", "data.frame", dplyr$pull.data.frame) diff --git a/dplyr-methods/n_groups.txt b/dplyr-methods/n_groups.txt new file mode 100644 index 00000000..62d3d86c --- /dev/null +++ b/dplyr-methods/n_groups.txt @@ -0,0 +1,3 @@ +n_groups.data.frame <- function(x) { + nrow(group_data(x)) +} diff --git a/tests/testthat/test-as_duckplyr_df.R b/tests/testthat/test-as_duckplyr_df.R index d764303e..c80c812e 100644 --- a/tests/testthat/test-as_duckplyr_df.R +++ b/tests/testthat/test-as_duckplyr_df.R @@ -1629,6 +1629,22 @@ test_that("as_duckplyr_df() and mutate(c = .data$b)", { expect_equal(pre, post) }) +test_that("as_duckplyr_df() and n_groups()", { + withr::local_envvar(DUCKPLYR_FORCE = "FALSE") + + skip("Special") + + # Data + test_df <- data.frame(a = 1:6 + 0, b = 2, g = rep(1:3, 1:3)) + + # Run + pre <- test_df %>% as_duckplyr_df() %>% n_groups() + post <- test_df %>% n_groups() %>% as_duckplyr_df() + + # Compare + expect_equal(pre, post) +}) + test_that("as_duckplyr_df() and nest_by()", { withr::local_envvar(DUCKPLYR_FORCE = "FALSE") diff --git a/tools/00-funs.R b/tools/00-funs.R index 49eb3c59..bdbdc1d0 100644 --- a/tools/00-funs.R +++ b/tools/00-funs.R @@ -16,7 +16,6 @@ df_methods <- filter(!grepl("_$|^as[.]tbl$", name)) %>% # special dplyr methods, won't implement filter(!(name %in% c( - "n_groups", "same_src", # data frames can be copied into duck-frames with zero cost NULL ))) %>% @@ -35,6 +34,7 @@ df_methods <- "group_split", "group_trim", "groups", + "n_groups", "rowwise", NULL ))) %>% @@ -711,6 +711,7 @@ test_skip_map <- c( group_split = "WAT", group_trim = "Grouped", groups = "Special", + n_groups = "Special", nest_by = "WAT", # FIXME: Fail with rowwise() rowwise = "Stack overflow",