diff --git a/NAMESPACE b/NAMESPACE index cdc8e1f6..9a51505f 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -412,6 +412,9 @@ importFrom(vctrs,vec_check_size) importFrom(vctrs,vec_data) importFrom(vctrs,vec_in) importFrom(vctrs,vec_match) +importFrom(vctrs,vec_ptype) +importFrom(vctrs,vec_ptype2) +importFrom(vctrs,vec_ptype_finalise) importFrom(vctrs,vec_rbind) importFrom(vctrs,vec_recycle_common) importFrom(vctrs,vec_rep) diff --git a/R/dplyr.R b/R/dplyr.R index e7862b9b..595109e3 100644 --- a/R/dplyr.R +++ b/R/dplyr.R @@ -67,6 +67,7 @@ join_rows <- dplyr$join_rows mutate_cols <- dplyr$mutate_cols mutate_keep <- dplyr$mutate_keep named_args <- dplyr$named_args +rethrow_error_join_incompatible_type <- dplyr$rethrow_error_join_incompatible_type rows_bind <- dplyr$rows_bind rows_cast_y <- dplyr$rows_cast_y rows_check_by <- dplyr$rows_check_by diff --git a/R/duckplyr-package.R b/R/duckplyr-package.R index d677d338..e352bc2b 100644 --- a/R/duckplyr-package.R +++ b/R/duckplyr-package.R @@ -74,6 +74,9 @@ #' @importFrom vctrs vec_data #' @importFrom vctrs vec_in #' @importFrom vctrs vec_match +#' @importFrom vctrs vec_ptype +#' @importFrom vctrs vec_ptype_finalise +#' @importFrom vctrs vec_ptype2 #' @importFrom vctrs vec_rbind #' @importFrom vctrs vec_recycle_common #' @importFrom vctrs vec_rep diff --git a/R/join.R b/R/join.R index e58b3750..b13fc53c 100644 --- a/R/join.R +++ b/R/join.R @@ -1,4 +1,13 @@ -rel_join_impl <- function(x, y, by, join, na_matches, suffix, keep, error_call = caller_env()) { +rel_join_impl <- function( + x, + y, + by, + join, + na_matches, + suffix = c(".x", ".y"), + keep = NULL, + error_call = caller_env() +) { mutating <- !(join %in% c("semi", "anti")) if (mutating) { @@ -23,6 +32,25 @@ rel_join_impl <- function(x, y, by, join, na_matches, suffix, keep, error_call = y_rel <- duckdb_rel_from_df(y) y_rel <- rel_set_alias(y_rel, "rhs") + # FIXME: Split join_cols, https://github.com/tidyverse/dplyr/issues/7050 + vars <- join_cols( + x_names = x_names, + y_names = y_names, + by = by, + suffix = suffix, + keep = keep, + error_call = error_call + ) + + x_in <- vec_ptype(x) + y_in <- vec_ptype(y) + + x_key <- set_names(x_in[vars$x$key], names(vars$x$key)) + y_key <- set_names(y_in[vars$y$key], names(vars$x$key)) + + # Side effect: check join compatibility + join_ptype_common(x_key, y_key, vars, error_call = error_call) + # Rename if non-unique column names if (mutating) { if (length(intersect(x_names, y_names)) != 0) { @@ -74,15 +102,6 @@ rel_join_impl <- function(x, y, by, join, na_matches, suffix, keep, error_call = list(x_rel, y_rel) ) - vars <- join_cols( - x_names = x_names, - y_names = y_names, - by = by, - suffix = suffix, - keep = keep, - error_call = error_call - ) - exprs <- c( nexprs_from_loc(x_names_remap, vars$x$out), nexprs_from_loc(y_names_remap, vars$y$out) diff --git a/R/join_ptype_common.R b/R/join_ptype_common.R new file mode 100644 index 00000000..1bea5197 --- /dev/null +++ b/R/join_ptype_common.R @@ -0,0 +1,15 @@ +# https://github.com/tidyverse/dplyr/pull/7029 + +join_ptype_common <- function(x, y, vars, error_call = caller_env()) { + # Explicit `x/y_arg = ""` to avoid auto naming in `cnd$x_arg` + ptype <- try_fetch( + vec_ptype2(x, y, x_arg = "", y_arg = "", call = error_call), + vctrs_error_incompatible_type = function(cnd) { + rethrow_error_join_incompatible_type(cnd, vars, error_call) + } + ) + # Finalize unspecified columns (#6804) + ptype <- vec_ptype_finalise(ptype) + + ptype +} diff --git a/tests/testthat/_snaps/join.md b/tests/testthat/_snaps/join.md index 224e48a9..73111ec5 100644 --- a/tests/testthat/_snaps/join.md +++ b/tests/testthat/_snaps/join.md @@ -75,6 +75,17 @@ i `x$a` is a . i `y$b` is a . +# filtering joins reference original column in `y` when there are type errors (#6465) + + Code + (expect_error(duckplyr_semi_join(x, y, by = join_by(a == b)))) + Output + + Error in `semi_join()`: + ! Can't join `x$a` with `y$b` due to incompatible types. + i `x$a` is a . + i `y$b` is a . + # error if passed additional arguments Code diff --git a/tests/testthat/test-join.R b/tests/testthat/test-join.R index 94b59b6e..8e047639 100644 --- a/tests/testthat/test-join.R +++ b/tests/testthat/test-join.R @@ -378,7 +378,6 @@ test_that("mutating joins don't trigger many-to-many warning when called indirec }) test_that("mutating joins compute common columns", { - skip("TODO duckdb") df1 <- tibble(x = c(1, 2), y = c(2, 3)) df2 <- tibble(x = c(1, 3), z = c(2, 3)) expect_snapshot(out <- duckplyr_left_join(df1, df2)) @@ -450,7 +449,6 @@ test_that("mutating joins reference original column in `y` when there are type e }) test_that("filtering joins reference original column in `y` when there are type errors (#6465)", { - skip("TODO duckdb") x <- tibble(a = 1) y <- tibble(b = "1") diff --git a/tools/00-funs.R b/tools/00-funs.R index 52af75e6..3613d6b2 100644 --- a/tools/00-funs.R +++ b/tools/00-funs.R @@ -112,10 +112,7 @@ duckplyr_tests <- head(n = -1, list( "mutating joins trigger multiple match warning", "mutating joins don't trigger multiple match warning when called indirectly", - "filtering joins reference original column in `y` when there are type errors (#6465)", - "mutating joins trigger many-to-many warning", - "mutating joins compute common columns", "mutating joins don't trigger many-to-many warning when called indirectly", NULL ),