From dd7f1f860c4f66d1fd93de0bc1ae167c357b5055 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Sat, 6 Jan 2024 20:04:28 +0800 Subject: [PATCH] use `Integer` during broadcast when possible. --- base/broadcast.jl | 20 ++++++++------------ test/broadcast.jl | 19 ++++++++++++++++--- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/base/broadcast.jl b/base/broadcast.jl index e653c4086e452..e63631c156c7e 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -571,15 +571,15 @@ an `Int`. Any remaining indices in `I` beyond the length of the `keep` tuple are truncated. The `keep` and `default` tuples may be created by `newindexer(argument)`. """ -Base.@propagate_inbounds newindex(arg, I::CartesianIndex) = CartesianIndex(_newindex(axes(arg), I.I)) -Base.@propagate_inbounds newindex(arg, I::Integer) = CartesianIndex(_newindex(axes(arg), (I,))) +Base.@propagate_inbounds newindex(arg, I::CartesianIndex) = to_index(_newindex(axes(arg), I.I)) +Base.@propagate_inbounds newindex(arg, I::Integer) = to_index(_newindex(axes(arg), (I,))) Base.@propagate_inbounds _newindex(ax::Tuple, I::Tuple) = (ifelse(length(ax[1]) == 1, ax[1][1], I[1]), _newindex(tail(ax), tail(I))...) Base.@propagate_inbounds _newindex(ax::Tuple{}, I::Tuple) = () Base.@propagate_inbounds _newindex(ax::Tuple, I::Tuple{}) = (ax[1][1], _newindex(tail(ax), ())...) Base.@propagate_inbounds _newindex(ax::Tuple{}, I::Tuple{}) = () # If dot-broadcasting were already defined, this would be `ifelse.(keep, I, Idefault)`. -@inline newindex(I::CartesianIndex, keep, Idefault) = CartesianIndex(_newindex(I.I, keep, Idefault)) +@inline newindex(I::CartesianIndex, keep, Idefault) = to_index(_newindex(I.I, keep, Idefault)) @inline newindex(i::Integer, keep::Tuple, idefault) = ifelse(keep[1], i, idefault[1]) @inline newindex(i::Integer, keep::Tuple{}, idefault) = CartesianIndex(()) @inline _newindex(I, keep, Idefault) = @@ -599,18 +599,14 @@ Base.@propagate_inbounds _newindex(ax::Tuple{}, I::Tuple{}) = () (Base.length(ind1)::Integer != 1, keep...), (first(ind1), Idefault...) end -@inline function Base.getindex(bc::Broadcasted, I::Union{Integer,CartesianIndex}) +@inline function Base.getindex(bc::Broadcasted, Is::Vararg{Union{Integer,CartesianIndex},N}) where {N} + I = to_index(Base.IteratorsMD.flatten(Is)) @boundscheck checkbounds(bc, I) @inbounds _broadcast_getindex(bc, I) end -Base.@propagate_inbounds Base.getindex( - bc::Broadcasted, - i1::Union{Integer,CartesianIndex}, - i2::Union{Integer,CartesianIndex}, - I::Union{Integer,CartesianIndex}..., -) = - bc[CartesianIndex((i1, i2, I...))] -Base.@propagate_inbounds Base.getindex(bc::Broadcasted) = bc[CartesianIndex(())] +to_index(::Tuple{}) = CartesianIndex() +to_index(Is::Tuple{Any}) = Is[1] +to_index(Is::Tuple) = CartesianIndex(Is) @inline Base.checkbounds(bc::Broadcasted, I::Union{Integer,CartesianIndex}) = Base.checkbounds_indices(Bool, axes(bc), (I,)) || Base.throw_boundserror(bc, (I,)) diff --git a/test/broadcast.jl b/test/broadcast.jl index ff0fc1401a703..1e5752921bc93 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -49,9 +49,9 @@ ci(x) = CartesianIndex(x) @test @inferred(newindex(ci((2,2)), (true, false), (-1,-1))) == ci((2,-1)) @test @inferred(newindex(ci((2,2)), (false, true), (-1,-1))) == ci((-1,2)) @test @inferred(newindex(ci((2,2)), (false, false), (-1,-1))) == ci((-1,-1)) -@test @inferred(newindex(ci((2,2)), (true,), (-1,-1))) == ci((2,)) -@test @inferred(newindex(ci((2,2)), (true,), (-1,))) == ci((2,)) -@test @inferred(newindex(ci((2,2)), (false,), (-1,))) == ci((-1,)) +@test @inferred(newindex(ci((2,2)), (true,), (-1,-1))) == 2 +@test @inferred(newindex(ci((2,2)), (true,), (-1,))) == 2 +@test @inferred(newindex(ci((2,2)), (false,), (-1,))) == -1 @test @inferred(newindex(ci((2,2)), (), ())) == ci(()) end @@ -1175,3 +1175,16 @@ import Base.Broadcast: BroadcastStyle, DefaultArrayStyle f51129(v, x) = (1 .- (v ./ x) .^ 2) @test @inferred(f51129([13.0], 6.5)) == [-3.0] + +@testset "broadcast for `AbstractArray` without `CartesianIndex` support" begin + struct BVec52775 <: AbstractVector{Int} + a::Vector{Int} + end + Base.size(a::BVec52775) = size(a.a) + Base.getindex(a::BVec52775, i::Real) = a.a[i] + Base.getindex(a::BVec52775, i) = error("unsupported index!") + a = BVec52775([1,2,3]) + bc = Base.broadcasted(identity, a) + @test bc[1] == bc[CartesianIndex(1)] == bc[1, CartesianIndex()] + @test a .+ [1 2] == a.a .+ [1 2] +end