Skip to content

Commit

Permalink
minor improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
dehann committed Apr 7, 2021
1 parent 550f3bf commit 1c031df
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions src/KDE01.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ end
Extract the marginal distribution from the given higher dimensional kernel density estimate object.
"""
function marginal(bd::BallTreeDensity, ind::Array{Int,1})
function marginal(bd::BallTreeDensity, ind::AbstractVector{<:Integer})
pts = getPoints(bd)
if size(bd.bandwidth,2) > 2*bd.bt.num_points
sig = getBW(bd)
Expand All @@ -152,12 +152,12 @@ function marginal(bd::BallTreeDensity, ind::Array{Int,1})
p = kde!(pts[ind,:],sig[ind], wts)
end

function randKernel(N::Int, M::Int, ::Type{KernelDensityEstimate.GaussianKer}) #t::Int)
function randKernel(N::Int, M::Int, ::Type{<:KernelDensityEstimate.GaussianKer}) #t::Int)
return randn(N,M)
end

"""
$(SIGNATURES)
$(SIGNATURES)
Randomly sample points from the KernelDensityEstimate object.
"""
Expand Down
6 changes: 3 additions & 3 deletions src/StringSerialization.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
function string(d::KernelDensityEstimate.BallTreeDensity)
function Base.string(d::KernelDensityEstimate.BallTreeDensity)
# TODO only supports single bandwidth per dimension at this point
pts = getPoints(d)
return "KDE:$(size(pts,2)):$(getBW(d)[:,1]):$(pts)"
end

function parsestringvector(str::AS; dlim=',') where {AS <: AbstractString}
function parsestringvector(str::AbstractString; dlim=',')
sstr = split(split(strip(str),'[')[end],']')[1]
ssstr = strip.(split(sstr,dlim))
parse.(Float64, ssstr)
end

function convert(::Type{BallTreeDensity}, str::AS) where {AS <: AbstractString}
function convert(::Type{<:BallTreeDensity}, str::AbstractString)
@assert occursin(r"KDE:", str) # ismatch
sstr = strip.(split(str, ':'))
N = parse(Int, sstr[2])
Expand Down

0 comments on commit 1c031df

Please sign in to comment.