From f418434f2146204d163ac0efcc965f95751171e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kirill=20M=C3=BCller?= Date: Thu, 7 Mar 2024 05:32:16 +0100 Subject: [PATCH] group_keys --- NAMESPACE | 1 + R/dplyr.R | 1 + R/group_keys.R | 41 ++++++++++++++++++++++++++++ R/overwrite.R | 1 + R/restore.R | 1 + dplyr-methods/group_keys.txt | 12 ++++++++ tests/testthat/test-as_duckplyr_df.R | 16 +++++++++++ tests/testthat/test-filter.R | 2 +- tests/testthat/test-slice.R | 4 +-- tools/00-funs.R | 4 ++- 10 files changed, 79 insertions(+), 4 deletions(-) create mode 100644 R/group_keys.R create mode 100644 dplyr-methods/group_keys.txt diff --git a/NAMESPACE b/NAMESPACE index 7ac8a7c1..80dc3f93 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -16,6 +16,7 @@ S3method(explain,duckplyr_df) S3method(full_join,duckplyr_df) S3method(group_by,duckplyr_df) S3method(group_data,duckplyr_df) +S3method(group_keys,duckplyr_df) S3method(group_vars,duckplyr_df) S3method(head,duckplyr_df) S3method(inner_join,duckplyr_df) diff --git a/R/dplyr.R b/R/dplyr.R index d46ad533..83b83631 100644 --- a/R/dplyr.R +++ b/R/dplyr.R @@ -25,6 +25,7 @@ expand_if_across <- dplyr$expand_if_across expr_substitute <- dplyr$expr_substitute get_slice_size <- dplyr$get_slice_size group_by_drop_default <- dplyr$group_by_drop_default +group_keys0 <- dplyr$group_keys0 is_compatible <- dplyr$is_compatible is_cross_by <- dplyr$is_cross_by join_by_common <- dplyr$join_by_common diff --git a/R/group_keys.R b/R/group_keys.R new file mode 100644 index 00000000..88a5daa4 --- /dev/null +++ b/R/group_keys.R @@ -0,0 +1,41 @@ +# Generated by 02-duckplyr_df-methods.R +#' @export +group_keys.duckplyr_df <- function(.tbl, ...) { + # Our implementation + rel_try( + # Always fall back to dplyr + "No relational implementation for group_keys()" = TRUE, + { + return(out) + } + ) + + # dplyr forward + group_keys <- dplyr$group_keys.data.frame + out <- group_keys(.tbl, ...) + return(out) + + # dplyr implementation + if (dots_n(...) > 0) { + lifecycle::deprecate_warn( + "1.0.0", "group_keys(... = )", + details = "Please `group_by()` first", + always = TRUE + ) + .tbl <- group_by(.tbl, ...) + } + out <- group_data(.tbl) + group_keys0(out) +} + +duckplyr_group_keys <- function(.tbl, ...) { + try_fetch( + .tbl <- as_duckplyr_df(.tbl), + error = function(e) { + testthat::skip(conditionMessage(e)) + } + ) + out <- group_keys(.tbl, ...) + class(out) <- setdiff(class(out), "duckplyr_df") + out +} diff --git a/R/overwrite.R b/R/overwrite.R index aa7b6b7a..01c548a1 100644 --- a/R/overwrite.R +++ b/R/overwrite.R @@ -16,6 +16,7 @@ methods_overwrite <- function() { vctrs::s3_register("dplyr::full_join", "data.frame", full_join.duckplyr_df) vctrs::s3_register("dplyr::group_by", "data.frame", group_by.duckplyr_df) vctrs::s3_register("dplyr::group_data", "data.frame", group_data.duckplyr_df) + vctrs::s3_register("dplyr::group_keys", "data.frame", group_keys.duckplyr_df) vctrs::s3_register("dplyr::group_vars", "data.frame", group_vars.duckplyr_df) vctrs::s3_register("dplyr::inner_join", "data.frame", inner_join.duckplyr_df) vctrs::s3_register("dplyr::intersect", "data.frame", intersect.duckplyr_df) diff --git a/R/restore.R b/R/restore.R index ceda7370..f0b07be7 100644 --- a/R/restore.R +++ b/R/restore.R @@ -16,6 +16,7 @@ methods_restore <- function() { vctrs::s3_register("dplyr::full_join", "data.frame", dplyr$full_join.data.frame) vctrs::s3_register("dplyr::group_by", "data.frame", dplyr$group_by.data.frame) vctrs::s3_register("dplyr::group_data", "data.frame", dplyr$group_data.data.frame) + vctrs::s3_register("dplyr::group_keys", "data.frame", dplyr$group_keys.data.frame) vctrs::s3_register("dplyr::group_vars", "data.frame", dplyr$group_vars.data.frame) vctrs::s3_register("dplyr::inner_join", "data.frame", dplyr$inner_join.data.frame) vctrs::s3_register("dplyr::intersect", "data.frame", dplyr$intersect.data.frame) diff --git a/dplyr-methods/group_keys.txt b/dplyr-methods/group_keys.txt new file mode 100644 index 00000000..fe3d3542 --- /dev/null +++ b/dplyr-methods/group_keys.txt @@ -0,0 +1,12 @@ +group_keys.data.frame <- function(.tbl, ...) { + if (dots_n(...) > 0) { + lifecycle::deprecate_warn( + "1.0.0", "group_keys(... = )", + details = "Please `group_by()` first", + always = TRUE + ) + .tbl <- group_by(.tbl, ...) + } + out <- group_data(.tbl) + group_keys0(out) +} diff --git a/tests/testthat/test-as_duckplyr_df.R b/tests/testthat/test-as_duckplyr_df.R index daa885d7..c60f39bc 100644 --- a/tests/testthat/test-as_duckplyr_df.R +++ b/tests/testthat/test-as_duckplyr_df.R @@ -770,6 +770,22 @@ test_that("as_duckplyr_df() and group_data()", { expect_equal(pre, post) }) +test_that("as_duckplyr_df() and group_keys()", { + withr::local_envvar(DUCKPLYR_FORCE = "FALSE") + + skip("Special") + + # Data + test_df <- data.frame(a = 1:6 + 0, b = 2, g = rep(1:3, 1:3)) + + # Run + pre <- test_df %>% as_duckplyr_df() %>% group_keys() + post <- test_df %>% group_keys() %>% as_duckplyr_df() + + # Compare + expect_equal(pre, post) +}) + test_that("as_duckplyr_df() and group_vars()", { withr::local_envvar(DUCKPLYR_FALLBACK_FORCE = "TRUE") diff --git a/tests/testthat/test-filter.R b/tests/testthat/test-filter.R index f91a2a3a..1f2f0d38 100644 --- a/tests/testthat/test-filter.R +++ b/tests/testthat/test-filter.R @@ -230,7 +230,7 @@ test_that("grouped filter handles indices (#880)", { res2 <- duckplyr_mutate(res, Petal = Petal.Width * Petal.Length) expect_equal(nrow(res), nrow(res2)) expect_equal(group_rows(res), group_rows(res2)) - expect_equal(group_keys(res), group_keys(res2)) + expect_equal(duckplyr_group_keys(res), duckplyr_group_keys(res2)) }) test_that("duckplyr_filter(FALSE) handles indices", { diff --git a/tests/testthat/test-slice.R b/tests/testthat/test-slice.R index 7616feab..6152ad23 100644 --- a/tests/testthat/test-slice.R +++ b/tests/testthat/test-slice.R @@ -77,11 +77,11 @@ test_that("slice preserves groups iff requested", { gf <- duckplyr_group_by(tibble(g = c(1, 2, 2, 3, 3, 3), id = 1:6), g) out <- duckplyr_slice(gf, 2, 3) - expect_equal(group_keys(out), tibble(g = c(2, 3))) + expect_equal(duckplyr_group_keys(out), tibble(g = c(2, 3))) expect_equal(group_rows(out), list_of(1, c(2, 3))) out <- duckplyr_slice(gf, 2, 3, .preserve = TRUE) - expect_equal(group_keys(out), tibble(g = c(1, 2, 3))) + expect_equal(duckplyr_group_keys(out), tibble(g = c(1, 2, 3))) expect_equal(group_rows(out), list_of(integer(), 1, c(2, 3))) }) diff --git a/tools/00-funs.R b/tools/00-funs.R index 9937558f..11d934d8 100644 --- a/tools/00-funs.R +++ b/tools/00-funs.R @@ -16,7 +16,7 @@ df_methods <- filter(!grepl("_$|^as[.]tbl$", name)) %>% # special dplyr methods, won't implement filter(!(name %in% c( - "group_indices", "group_keys", "group_map", "group_modify", "group_nest", "group_size", "group_split", "group_trim", "groups", "n_groups", + "group_indices", "group_map", "group_modify", "group_nest", "group_size", "group_split", "group_trim", "groups", "n_groups", "same_src", # data frames can be copied into duck-frames with zero cost NULL ))) %>% @@ -26,6 +26,7 @@ df_methods <- "dplyr_row_slice", "group_by", "group_data", + "group_keys", "rowwise", NULL ))) %>% @@ -693,6 +694,7 @@ test_skip_map <- c( dplyr_reconstruct = "Hack", group_by = "Grouped", group_data = "Special", + group_keys = "Special", group_map = "WAT", group_modify = "Grouped", group_nest = "Always returns tibble",