Skip to content

Commit

Permalink
Various fixes to byte / bytearray search (#54579)
Browse files Browse the repository at this point in the history
This was originally intended as a targeted fix to #54578, but I ran into
a bunch of smaller issues with this code that also needed to be solved
and it turned out to be difficult to fix them with small, trivial PRs.

I would also like to refactor this whole file, but I want these
correctness fixes to be merged first, because a larger refactoring has
higher risk of getting stuck without getting reviewed and merged.

## Larger things that needs decisions
* The internal union `Base.ByteArray` has been deleted. Instead, the
unions `DenseInt8` and `DenseUInt8` have been added. These more
comprehensively cover the types that was meant, e.g. `Memory{UInt8}` was
incorrectly not covered by the former. As stated in the TODO, the
concept of a "memory backed dense byte array" is needed throughout
Julia, so this ideally needs to be implemented as a single type and used
throughout Base. The fix here is a decent temporary solution. See #53178
#54581
* The `findall` docstring between two arrays was incorrectly not
attached to the method - now it is. **Note that this change _changes_
the documentation** since it includes a docstring that was previously
missed. Hence, it's an API addition.
* Added a new minimal `testhelpers/OffsetDenseArrays.jl` which provide a
`DenseVector` with offset axes for testing purposes.

## Trivial fixes
* `findfirst(==(Int8(-1)), [0xff])` and similar findlast, findnext and
findprev is no longer buggy, see #54578
* `findfirst([0x0ff], Int8[-1])` is similarly no longer buggy, see
#54578
* `findnext(==('\xa6'), "æ", 1)` and `findprev(==('\xa6'), "æa", 2)` no
longer incorrectly throws an error
* The byte-oriented find* functions now work correctly with offset
arrays
* Fixed incorrect use of `GC.@preserve`, where the pointer was taken
before the preserve block.
* More of the optimised string methods now also apply to
`SubString{String}`


Closes #54578
Co-authored-by: Martin Holters <[email protected]>
  • Loading branch information
jakobnissen authored and kshyatt committed Sep 12, 2024
1 parent 5a816bd commit f82fd8e
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 47 deletions.
1 change: 1 addition & 0 deletions base/char.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ hash(x::Char, h::UInt) =
hash_uint64(((bitcast(UInt32, x) + UInt64(0xd4d64234)) << 32) UInt64(h))

first_utf8_byte(c::Char) = (bitcast(UInt32, c) >> 24) % UInt8
first_utf8_byte(c::AbstractChar) = first_utf8_byte(Char(c)::Char)

# fallbacks:
isless(x::AbstractChar, y::AbstractChar) = isless(Char(x), Char(y))
Expand Down
136 changes: 89 additions & 47 deletions base/strings/search.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,29 @@ match strings with [`match`](@ref).
"""
abstract type AbstractPattern end

nothing_sentinel(i) = i == 0 ? nothing : i
# TODO: These unions represent bytes in memory that can be accessed via a pointer.
# this property is used throughout Julia, e.g. also in IO code.
# This deserves a better solution - see #53178.
# If such a better solution comes in place, these unions should be replaced.
const DenseInt8 = Union{
DenseArray{Int8},
FastContiguousSubArray{Int8,N,<:DenseArray} where N
}

# Note: This union is different from that above in that it includes CodeUnits.
# Currently, this is redundant as CodeUnits <: DenseVector, but this subtyping
# is buggy and may be removed in the future, see #54002
const DenseUInt8 = Union{
DenseArray{UInt8},
FastContiguousSubArray{UInt8,N,<:DenseArray} where N,
CodeUnits{UInt8, <:Union{String, SubString{String}}},
FastContiguousSubArray{UInt8,N,<:CodeUnits{UInt8, <:Union{String, SubString{String}}}} where N,
}

const DenseUInt8OrInt8 = Union{DenseUInt8, DenseInt8}

last_byteindex(x::Union{String, SubString{String}}) = ncodeunits(x)
last_byteindex(x::DenseUInt8OrInt8) = lastindex(x)

function last_utf8_byte(c::Char)
u = reinterpret(UInt32, c)
Expand All @@ -30,11 +52,11 @@ function findnext(pred::Fix2{<:Union{typeof(isequal),typeof(==)},<:AbstractChar}
end
@inbounds isvalid(s, i) || string_index_err(s, i)
c = pred.x
c '\x7f' && return nothing_sentinel(_search(s, c % UInt8, i))
c '\x7f' && return _search(s, first_utf8_byte(c), i)
while true
i = _search(s, first_utf8_byte(c), i)
i == 0 && return nothing
pred(s[i]) && return i
i === nothing && return nothing
isvalid(s, i) && pred(s[i]) && return i
i = nextind(s, i)
end
end
Expand All @@ -47,31 +69,41 @@ const DenseBytes = Union{
CodeUnits{UInt8, <:Union{String, SubString{String}}},
}

const ByteArray = Union{DenseBytes, DenseArrayType{Int8}}
function findfirst(pred::Fix2{<:Union{typeof(isequal),typeof(==)},<:Union{UInt8, Int8}}, a::Union{DenseInt8, DenseUInt8})
findnext(pred, a, firstindex(a))
end

findfirst(pred::Fix2{<:Union{typeof(isequal),typeof(==)},<:Union{Int8,UInt8}}, a::ByteArray) =
nothing_sentinel(_search(a, pred.x))
function findnext(pred::Fix2{<:Union{typeof(isequal),typeof(==)},UInt8}, a::DenseUInt8, i::Integer)
_search(a, pred.x, i)
end

findnext(pred::Fix2{<:Union{typeof(isequal),typeof(==)},<:Union{Int8,UInt8}}, a::ByteArray, i::Integer) =
nothing_sentinel(_search(a, pred.x, i))
function findnext(pred::Fix2{<:Union{typeof(isequal),typeof(==)},Int8}, a::DenseInt8, i::Integer)
_search(a, pred.x, i)
end

findfirst(::typeof(iszero), a::ByteArray) = nothing_sentinel(_search(a, zero(UInt8)))
findnext(::typeof(iszero), a::ByteArray, i::Integer) = nothing_sentinel(_search(a, zero(UInt8), i))
# iszero is special, in that the bitpattern for zero for Int8 and UInt8 is the same,
# so we can use memchr even if we search for an Int8 in an UInt8 array or vice versa
findfirst(::typeof(iszero), a::DenseUInt8OrInt8) = _search(a, zero(UInt8))
findnext(::typeof(iszero), a::DenseUInt8OrInt8, i::Integer) = _search(a, zero(UInt8), i)

function _search(a::Union{String,SubString{String},<:ByteArray}, b::Union{Int8,UInt8}, i::Integer = 1)
if i < 1
function _search(a::Union{String,SubString{String},DenseUInt8OrInt8}, b::Union{Int8,UInt8}, i::Integer = firstindex(a))
fst = firstindex(a)
lst = last_byteindex(a)
if i < fst
throw(BoundsError(a, i))
end
n = sizeof(a)
if i > n
return i == n+1 ? 0 : throw(BoundsError(a, i))
n_bytes = lst - i + 1
if i > lst
return i == lst+1 ? nothing : throw(BoundsError(a, i))
end
p = pointer(a)
q = GC.@preserve a ccall(:memchr, Ptr{UInt8}, (Ptr{UInt8}, Int32, Csize_t), p+i-1, b, n-i+1)
return q == C_NULL ? 0 : Int(q-p+1)
GC.@preserve a begin
p = pointer(a)
q = ccall(:memchr, Ptr{UInt8}, (Ptr{UInt8}, Int32, Csize_t), p+i-fst, b, n_bytes)
end
return q == C_NULL ? nothing : (q-p+fst) % Int
end

function _search(a::ByteArray, b::AbstractChar, i::Integer = 1)
function _search(a::DenseUInt8, b::AbstractChar, i::Integer = firstindex(a))
if isascii(b)
_search(a,UInt8(b),i)
else
Expand All @@ -80,41 +112,51 @@ function _search(a::ByteArray, b::AbstractChar, i::Integer = 1)
end

function findprev(pred::Fix2{<:Union{typeof(isequal),typeof(==)},<:AbstractChar},
s::String, i::Integer)
s::Union{String, SubString{String}}, i::Integer)
c = pred.x
c '\x7f' && return nothing_sentinel(_rsearch(s, c % UInt8, i))
c '\x7f' && return _rsearch(s, first_utf8_byte(c), i)
b = first_utf8_byte(c)
while true
i = _rsearch(s, b, i)
i == 0 && return nothing
pred(s[i]) && return i
i == nothing && return nothing
isvalid(s, i) && pred(s[i]) && return i
i = prevind(s, i)
end
end

findlast(pred::Fix2{<:Union{typeof(isequal),typeof(==)},<:Union{Int8,UInt8}}, a::ByteArray) =
nothing_sentinel(_rsearch(a, pred.x))
function findlast(pred::Fix2{<:Union{typeof(isequal),typeof(==)},<:Union{Int8,UInt8}}, a::DenseUInt8OrInt8)
findprev(pred, a, lastindex(a))
end

findprev(pred::Fix2{<:Union{typeof(isequal),typeof(==)},<:Union{Int8,UInt8}}, a::ByteArray, i::Integer) =
nothing_sentinel(_rsearch(a, pred.x, i))
function findprev(pred::Fix2{<:Union{typeof(isequal),typeof(==)},Int8}, a::DenseInt8, i::Integer)
_rsearch(a, pred.x, i)
end

findlast(::typeof(iszero), a::ByteArray) = nothing_sentinel(_rsearch(a, zero(UInt8)))
findprev(::typeof(iszero), a::ByteArray, i::Integer) = nothing_sentinel(_rsearch(a, zero(UInt8), i))
function findprev(pred::Fix2{<:Union{typeof(isequal),typeof(==)},UInt8}, a::DenseUInt8, i::Integer)
_rsearch(a, pred.x, i)
end

function _rsearch(a::Union{String,ByteArray}, b::Union{Int8,UInt8}, i::Integer = sizeof(a))
if i < 1
return i == 0 ? 0 : throw(BoundsError(a, i))
# See comments above for findfirst(::typeof(iszero)) methods
findlast(::typeof(iszero), a::DenseUInt8OrInt8) = _rsearch(a, zero(UInt8))
findprev(::typeof(iszero), a::DenseUInt8OrInt8, i::Integer) = _rsearch(a, zero(UInt8), i)

function _rsearch(a::Union{String,SubString{String},DenseUInt8OrInt8}, b::Union{Int8,UInt8}, i::Integer = last_byteindex(a))
fst = firstindex(a)
lst = last_byteindex(a)
if i < fst
return i == fst - 1 ? nothing : throw(BoundsError(a, i))
end
if i > lst
return i == lst+1 ? nothing : throw(BoundsError(a, i))
end
n = sizeof(a)
if i > n
return i == n+1 ? 0 : throw(BoundsError(a, i))
GC.@preserve a begin
p = pointer(a)
q = ccall(:memrchr, Ptr{UInt8}, (Ptr{UInt8}, Int32, Csize_t), p, b, i-fst+1)
end
p = pointer(a)
q = GC.@preserve a ccall(:memrchr, Ptr{UInt8}, (Ptr{UInt8}, Int32, Csize_t), p, b, i)
return q == C_NULL ? 0 : Int(q-p+1)
return q == C_NULL ? nothing : (q-p+fst) % Int
end

function _rsearch(a::ByteArray, b::AbstractChar, i::Integer = length(a))
function _rsearch(a::DenseUInt8, b::AbstractChar, i::Integer = length(a))
if isascii(b)
_rsearch(a,UInt8(b),i)
else
Expand Down Expand Up @@ -224,18 +266,19 @@ end

in(c::AbstractChar, s::AbstractString) = (findfirst(isequal(c),s)!==nothing)

function _searchindex(s::Union{AbstractString,ByteArray},
function _searchindex(s::Union{AbstractString,DenseUInt8OrInt8},
t::Union{AbstractString,AbstractChar,Int8,UInt8},
i::Integer)
sentinel = firstindex(s) - 1
x = Iterators.peel(t)
if isnothing(x)
return 1 <= i <= nextind(s,lastindex(s))::Int ? i :
return firstindex(s) <= i <= nextind(s,lastindex(s))::Int ? i :
throw(BoundsError(s, i))
end
t1, trest = x
while true
i = findnext(isequal(t1),s,i)
if i === nothing return 0 end
if i === nothing return sentinel end
ii = nextind(s, i)::Int
a = Iterators.Stateful(trest)
matched = all(splat(==), zip(SubString(s, ii), a))
Expand Down Expand Up @@ -509,9 +552,8 @@ julia> findall(UInt8[1,2], UInt8[1,2,3,1,2])
!!! compat "Julia 1.3"
This method requires at least Julia 1.3.
"""

function findall(t::Union{AbstractString, AbstractPattern, AbstractVector{<:Union{Int8,UInt8}}},
s::Union{AbstractString, AbstractPattern, AbstractVector{<:Union{Int8,UInt8}}},
function findall(t::Union{AbstractString, AbstractPattern, AbstractVector{UInt8}},
s::Union{AbstractString, AbstractPattern, AbstractVector{UInt8}},
; overlap::Bool=false)
found = UnitRange{Int}[]
i, e = firstindex(s), lastindex(s)
Expand Down Expand Up @@ -564,7 +606,7 @@ function _rsearchindex(s::AbstractString,
end
end

function _rsearchindex(s::String, t::String, i::Integer)
function _rsearchindex(s::Union{String, SubString{String}}, t::Union{String, SubString{String}}, i::Integer)
# Check for fast case of a single byte
if lastindex(t) == 1
return something(findprev(isequal(t[1]), s, i), 0)
Expand Down
49 changes: 49 additions & 0 deletions test/strings/search.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,16 @@ for str in [u8str]
@test findprev(isequal('ε'), str, 4) === nothing
end

# See the comments in #54579
@testset "Search for invalid chars" begin
@test findfirst(==('\xff'), "abc\xffde") == 4
@test findprev(isequal('\xa6'), "abc\xa69", 5) == 4
@test isnothing(findfirst(==('\xff'), "abcdeæd"))

@test isnothing(findnext(==('\xa6'), "æ", 1))
@test isnothing(findprev(==('\xa6'), "æa", 2))
end

# string forward search with a single-char string
@test findfirst("x", astr) === nothing
@test findfirst("H", astr) == 1:1
Expand Down Expand Up @@ -445,6 +455,45 @@ end
@test_throws BoundsError findprev(pattern, A, -3)
end
end

@test findall([0x01, 0x02], [0x03, 0x01, 0x02, 0x01, 0x02, 0x06]) == [2:3, 4:5]
@test isempty(findall([0x04, 0x05], [0x03, 0x04, 0x06]))
end

# Issue 54578
@testset "No conflation of Int8 and UInt8" begin
# Work for mixed types if the values are the same
@test findfirst(==(Int8(1)), [0x01]) == 1
@test findnext(iszero, Int8[0, -2, 0, -3], 2) == 3
@test findfirst(Int8[1,4], UInt8[0, 2, 4, 1, 8, 1, 4, 2]) == 6:7
@test findprev(UInt8[5, 6], Int8[1, 9, 2, 5, 6, 3], 6) == 4:5

# Returns nothing for the same methods if the values are different,
# even if the bitpatterns are the same
@test isnothing(findfirst(==(Int8(-1)), [0xff]))
@test isnothing(findnext(isequal(0xff), Int8[-1, -2, -1], 2))
@test isnothing(findfirst(UInt8[0xff, 0xfe], Int8[0, -1, -2, 1, 8, 1, 4, 2]))
@test isnothing(findprev(UInt8[0xff, 0xfe], Int8[1, 9, 2, -1, -2, 3], 6))
end

@testset "DenseArray with offsets" begin
isdefined(Main, :OffsetDenseArrays) || @eval Main include("../testhelpers/OffsetDenseArrays.jl")
OffsetDenseArrays = Main.OffsetDenseArrays

A = OffsetDenseArrays.OffsetDenseArray(collect(0x61:0x69), 100)
@test findfirst(==(0x61), A) == 101
@test findlast(==(0x61), A) == 101
@test findfirst(==(0x00), A) === nothing

@test findfirst([0x62, 0x63, 0x64], A) == 102:104
@test findlast([0x63, 0x64], A) == 103:104
@test findall([0x62, 0x63], A) == [102:103]

@test findfirst(iszero, A) === nothing
A = OffsetDenseArrays.OffsetDenseArray([0x01, 0x02, 0x00, 0x03], -100)
@test findfirst(iszero, A) == -97
@test findnext(==(0x02), A, -99) == -98
@test findnext(==(0x02), A, -97) === nothing
end

# issue 32568
Expand Down
31 changes: 31 additions & 0 deletions test/testhelpers/OffsetDenseArrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""
module OffsetDenseArrays
A minimal implementation of an offset array which is also <: DenseArray.
"""
module OffsetDenseArrays

struct OffsetDenseArray{A <: DenseVector, T} <: DenseVector{T}
x::A
offset::Int
end
OffsetDenseArray(x::AbstractVector{T}, i::Integer) where {T} = OffsetDenseArray{typeof(x), T}(x, Int(i))

Base.size(x::OffsetDenseArray) = size(x.x)
Base.pointer(x::OffsetDenseArray) = pointer(x.x)

function Base.getindex(x::OffsetDenseArray, i::Integer)
@boundscheck checkbounds(x.x, i - x.offset)
x.x[i - x.offset]
end

function Base.setindex(x::OffsetDenseArray, v, i::Integer)
@boundscheck checkbounds(x.x, i - x.offset)
x.x[i - x.offset] = v
end

IndexStyle(::Type{<:OffsetDenseArray}) = Base.IndexLinear()
Base.axes(x::OffsetDenseArray) = (x.offset + 1 : x.offset + length(x.x),)
Base.keys(x::OffsetDenseArray) = only(axes(x))

end # module

0 comments on commit f82fd8e

Please sign in to comment.