Skip to content

Commit

Permalink
Allow to pass multiple predicates in Cols and mix them with other s…
Browse files Browse the repository at this point in the history
…electors (#3279)
  • Loading branch information
bkamins authored Feb 5, 2023
1 parent cf893d2 commit 3368d85
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 10 deletions.
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
* Add support for `operator` keyword argument in `Cols`
to take a set operation to apply to passed selectors (`union` by default)
([3224](https://github.com/JuliaData/DataFrames.jl/pull/3224))
* Allow to pass multiple predicates in `Cols` and mix them with
other selectors
([3279](https://github.com/JuliaData/DataFrames.jl/pull/3279))
* Improve support for setting group order in `groupby`
([3253](https://github.com/JuliaData/DataFrames.jl/pull/3253))
* Joining functions now support `order` keyword argument allowing the user
Expand Down
12 changes: 6 additions & 6 deletions src/other/index.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,16 +230,16 @@ end
@inline Base.getindex(x::AbstractIndex, idx::All) =
isempty(idx.cols) ? (1:length(x)) : throw(ArgumentError("All(args...) is not supported: use Cols(args...) instead"))

@inline _getindex_cols(x::AbstractIndex, idx::Any) = x[idx]
@inline _getindex_cols(x::AbstractIndex, idx::Function) = findall(idx, names(x))
# the definition below is needed because `:` is a Function
@inline _getindex_cols(x::AbstractIndex, idx::Colon) = x[idx]

@inline function Base.getindex(x::AbstractIndex, idx::Cols)
isempty(idx.cols) && return Int[]
return idx.operator(getindex.(Ref(x), idx.cols)...)
return idx.operator(_getindex_cols.(Ref(x), idx.cols)...)
end

# the definition below is needed because `:` is a Function
@inline Base.getindex(x::AbstractIndex, idx::Cols{Tuple{typeof(:)}}) = x[:]
@inline Base.getindex(x::AbstractIndex, idx::Cols{<:Tuple{Function}}) =
findall(idx.cols[1], names(x))

@inline function Base.getindex(x::AbstractIndex, idx::AbstractVector{<:Integer})
if any(v -> v isa Bool, idx)
throw(ArgumentError("Bool values except for AbstractVector{Bool} are not " *
Expand Down
22 changes: 18 additions & 4 deletions test/index.jl
Original file line number Diff line number Diff line change
Expand Up @@ -477,8 +477,10 @@ end
@test df[:, Cols(x -> x[1] == 'a')] == df[:, [1, 2]]
@test df[:, Cols(x -> x[end] == '1')] == df[:, [1, 3]]
@test df[:, Cols(x -> x[end] == '3')] == DataFrame()
@test_throws ArgumentError df[:, Cols(x -> true, 1)]
@test_throws ArgumentError df[:, Cols(1, x -> true)]
@test df[:, Cols(x -> true, 1)] == df
@test df[:, Cols(1, x -> true)] == df
@test df[:, Cols(x -> true, 1, operator=intersect)] == DataFrame(a1=1)
@test df[:, Cols(1, x -> true, operator=intersect)] == DataFrame(a1=1)

@test ncol(select(df, Cols(operator=intersect))) == 0
@test ncol(df[:, Cols(operator=intersect)]) == 0
Expand Down Expand Up @@ -539,8 +541,20 @@ end
@test df[:, Cols(x -> x[1] == 'a', operator=intersect)] == df[:, [1, 2]]
@test df[:, Cols(x -> x[end] == '1', operator=intersect)] == df[:, [1, 3]]
@test df[:, Cols(x -> x[end] == '3', operator=intersect)] == DataFrame()
@test_throws ArgumentError df[:, Cols(x -> true, 1, operator=intersect)]
@test_throws ArgumentError df[:, Cols(1, x -> true, operator=intersect)]
@test df[:, Cols(x -> true, 1, operator=intersect)] == df[:, 1:1]
@test df[:, Cols(1, x -> true, operator=intersect)] == df[:, 1:1]

@test df[:, Cols(startswith("a"), endswith("2"))] ==
select(df, Cols(startswith("a"), endswith("2"))) ==
df[:, ["a1", "a2", "b2"]]
@test df[:, Cols(startswith("a"), endswith("2"), operator=intersect)] ==
df[:, Cols(startswith("a"), :, endswith("2"), operator=intersect)] ==
select(df, Cols(startswith("a"), endswith("2"), operator=intersect)) ==
df[:, ["a2"]]
@test df[:, Cols(startswith("a"), endswith("2"), operator=setdiff)] ==
select(df, Cols(startswith("a"), endswith("2"), operator=setdiff)) ==
df[:, ["a1"]]
@test df[:, Cols(startswith("a"), endswith("2"), :, operator=setdiff)] == DataFrame()
end

@testset "views" begin
Expand Down

0 comments on commit 3368d85

Please sign in to comment.