diff --git a/Project.toml b/Project.toml
index 0ef4e9bd..54a3b316 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "PSIS"
 uuid = "ce719bf2-d5d0-4fb9-925d-10a81b42ad04"
 authors = ["Seth Axen and contributors"]
-version = "0.8.0"
+version = "0.9.0"
 
 [deps]
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
diff --git a/src/core.jl b/src/core.jl
index 781e1476..0025f672 100644
--- a/src/core.jl
+++ b/src/core.jl
@@ -1,10 +1,10 @@
 # range, description, condition
 const SHAPE_DIAGNOSTIC_CATEGORIES = (
-    ("(-Inf, 0.5]", "good", x -> !ismissing(x) && x ≤ 0.5),
-    ("(0.5, 0.7]", "okay", x -> !ismissing(x) && 0.5 < x ≤ 0.7),
-    ("(0.7, 1]", "bad", x -> !ismissing(x) && 0.7 < x ≤ 1),
-    ("(1, Inf)", "very bad", x -> !ismissing(x) && x > 1),
-    ("——", "missing", ismissing),
+    ("(-Inf, 0.5]", "good", ≤(0.5)),
+    ("(0.5, 0.7]", "okay", x -> 0.5 < x ≤ 0.7),
+    ("(0.7, 1]", "bad", x -> 0.7 < x ≤ 1),
+    ("(1, Inf)", "very bad", >(1)),
+    ("——", "failed", isnan),
 )
 const BAD_SHAPE_SUMMARY = "Resulting importance sampling estimates are likely to be unstable."
 const VERY_BAD_SHAPE_SUMMARY = "Corresponding importance sampling estimates are likely to be unstable and are unlikely to converge with additional samples."
@@ -117,8 +117,8 @@ function _print_pareto_shape_summary(io::IO, r::PSISResult; kwargs...)
         inds = findall(cond, k)
         count = length(inds)
         perc = 100 * count / npoints
-        ess_min = if count == 0 || desc == "too few draws"
-            missing
+        ess_min = if count == 0 || desc == "failed"
+            oftype(first(ess), NaN)
         else
             minimum(view(ess, inds))
         end
@@ -130,7 +130,7 @@ function _print_pareto_shape_summary(io::IO, r::PSISResult; kwargs...)
         "okay" => (; color=:yellow),
         "bad" => (bold=true, color=:light_red),
         "very bad" => (bold=true, color=:red),
-        "missing" => (),
+        "failed" => (; color=:red),
     )
 
     col_padding = " "
@@ -165,7 +165,7 @@ function _print_pareto_shape_summary(io::IO, r::PSISResult; kwargs...)
         format = formats[r.desc]
         printstyled(io, _pad_left(count, col_widths[3]); format...)
         printstyled(io, " ", _pad_right(perc_str, col_widths[4]); format...)
-        print(io, col_delim_tot, r.ess_min === missing ? "——" : floor(Int, r.ess_min))
+        print(io, col_delim_tot, isfinite(r.ess_min) ? floor(Int, r.ess_min) : "——")
     end
     return nothing
 end
@@ -218,6 +218,7 @@ function psis(logr, reff=1; kwargs...)
 end
 
 function psis!(logw::AbstractVecOrMat, reff=1; normalize::Bool=true, warn::Bool=true)
+    T = typeof(float(one(eltype(logw))))
     S = length(logw)
     reff_val = first(reff)
     M = tail_length(reff_val, S)
@@ -225,7 +226,8 @@ function psis!(logw::AbstractVecOrMat, reff=1; normalize::Bool=true, warn::Bool=
         warn &&
             @warn "$M tail draws is insufficient to fit the generalized Pareto distribution. Total number of draws should in general exceed 25."
         _maybe_log_normalize!(logw, normalize)
-        return PSISResult(logw, reff_val, M, missing, normalize)
+        tail_dist_failed = GeneralizedPareto(0, T(NaN), T(NaN))
+        return PSISResult(logw, reff_val, M, tail_dist_failed, normalize)
     end
     perm = partialsortperm(logw, (S - M):S)
     cutoff_ind = perm[1]
@@ -236,7 +238,8 @@ function psis!(logw::AbstractVecOrMat, reff=1; normalize::Bool=true, warn::Bool=
         warn &&
             @warn "Tail contains non-finite values. Generalized Pareto distribution cannot be reliably fit."
         _maybe_log_normalize!(logw, normalize)
-        return PSISResult(logw, reff_val, M, missing, normalize)
+        tail_dist_failed = GeneralizedPareto(0, T(NaN), T(NaN))
+        return PSISResult(logw, reff_val, M, tail_dist_failed, normalize)
     end
     _, tail_dist = psis_tail!(logw_tail, logu)
     warn && check_pareto_shape(tail_dist)
@@ -259,7 +262,7 @@ function psis!(logw::AbstractArray, reff=1; normalize::Bool=true, warn::Bool=tru
     reffs = similar(logw, eltype(reff), param_axes)
     reffs .= reff
     tail_lengths = similar(logw, Int, param_axes)
-    tail_dists = similar(logw, Union{Missing,GeneralizedPareto{T}}, param_axes)
+    tail_dists = similar(logw, GeneralizedPareto{T}, param_axes)
 
     # call psis! in parallel for all parameters
     Threads.@threads for i in _eachparamindex(logw)
@@ -277,7 +280,6 @@ function psis!(logw::AbstractArray, reff=1; normalize::Bool=true, warn::Bool=tru
     return result
 end
 
-pareto_shape(::Missing) = missing
 pareto_shape(dist::GeneralizedPareto) = dist.k
 pareto_shape(r::PSISResult) = pareto_shape(getfield(r, :tail_dist))
 pareto_shape(dists) = map(pareto_shape, dists)
@@ -292,18 +294,18 @@ function check_pareto_shape(dist::GeneralizedPareto)
     end
     return nothing
 end
-function check_pareto_shape(dists::AbstractArray{<:Union{Missing,GeneralizedPareto}})
-    nmissing = count(ismissing, dists)
-    ngt07 = count(x -> !(ismissing(x)) && pareto_shape(x) > 0.7, dists)
-    ngt1 = iszero(ngt07) ? ngt07 : count(x -> !(ismissing(x)) && pareto_shape(x) > 1, dists)
+function check_pareto_shape(dists::AbstractArray{<:GeneralizedPareto})
+    nnan = count(isnan ∘ pareto_shape, dists)
+    ngt07 = count(>(0.7) ∘ pareto_shape, dists)
+    ngt1 = iszero(ngt07) ? ngt07 : count(>(1) ∘ pareto_shape, dists)
     if ngt07 > ngt1
         @warn "$(ngt07 - ngt1) parameters had Pareto shape values 0.7 < k ≤ 1. $BAD_SHAPE_SUMMARY"
     end
     if ngt1 > 0
         @warn "$ngt1 parameters had Pareto shape values k > 1. $VERY_BAD_SHAPE_SUMMARY"
     end
-    if nmissing > 0
-        @warn "For $nmissing parameters, the generalized Pareto distribution could not be fit to the tail draws. Total number of draws should in general exceed 25, and the tail draws must be finite."
+    if nnan > 0
+        @warn "For $nnan parameters, the generalized Pareto distribution could not be fit to the tail draws. Total number of draws should in general exceed 25, and the tail draws must be finite."
     end
     return nothing
 end
diff --git a/src/ess.jl b/src/ess.jl
index 595a66b0..9cc45416 100644
--- a/src/ess.jl
+++ b/src/ess.jl
@@ -11,32 +11,34 @@ Given normalized weights ``w_{1:n}``, the ESS is estimated using the L2-norm of
 
 where ``r_{\\mathrm{eff}}`` is the relative efficiency of the `log_weights`.
 
-    ess_is(result::PSISResult; bad_shape_missing=true)
+    ess_is(result::PSISResult; bad_shape_nan=true)
 
 Estimate ESS for Pareto-smoothed importance sampling.
 
 !!! note
     ESS estimates for Pareto shape values ``k > 0.7``, which are unreliable and misleadingly
-    high, are set to `missing`. To avoid this, set `bad_shape_missing=false`.
+    high, are set to `NaN`. To avoid this, set `bad_shape_nan=false`.
""" ess_is -function ess_is(r::PSISResult; bad_shape_missing::Bool=true) +function ess_is(r::PSISResult; bad_shape_nan::Bool=true) neff = ess_is(r.weights; reff=r.reff) - return _apply_missing(neff, r.tail_dist; bad_shape_missing=bad_shape_missing) + return _apply_nan(neff, r.tail_dist; bad_shape_nan=bad_shape_nan) end function ess_is(weights; reff=1) dims = _sample_dims(weights) return reff ./ dropdims(sum(abs2, weights; dims=dims); dims=dims) end -function _apply_missing(neff, dist; bad_shape_missing) - return bad_shape_missing && pareto_shape(dist) > 0.7 ? missing : neff +function _apply_nan(neff, dist; bad_shape_nan) + bad_shape_nan || return neff + k = pareto_shape(dist) + (isnan(k) || k > 0.7) && return oftype(neff, NaN) + return neff end -_apply_missing(neff, ::Missing; kwargs...) = missing -function _apply_missing(ess::AbstractArray, tail_dist::AbstractArray; kwargs...) +function _apply_nan(ess::AbstractArray, tail_dist::AbstractArray; kwargs...) return map(ess, tail_dist) do essᵢ, tail_distᵢ - return _apply_missing(essᵢ, tail_distᵢ; kwargs...) + return _apply_nan(essᵢ, tail_distᵢ; kwargs...) end end diff --git a/test/core.jl b/test/core.jl index e1408c03..42304573 100644 --- a/test/core.jl +++ b/test/core.jl @@ -90,7 +90,7 @@ using DimensionalData: Dimensions, DimArray (0.5, 0.7] okay 6 (20.0%) 92 (0.7, 1] bad 4 (13.3%) —— (1, Inf) very bad 17 (56.7%) —— - —— missing 1 (3.3%) ——""" + —— failed 1 (3.3%) ——""" end end end @@ -116,7 +116,7 @@ end x = rand(rng, proposal, sz) logr = logpdf.(target, x) .- logpdf.(proposal, x) - r = psis(logr) + r = @inferred psis(logr) @test r isa PSISResult logw = r.log_weights @test logw isa typeof(logr) @@ -160,8 +160,8 @@ end psis(logr; normalize=false) end @test result.log_weights == logr - @test ismissing(result.tail_dist) - @test ismissing(result.pareto_shape) + @test isnan(result.tail_dist.σ) + @test isnan(result.pareto_shape) msg = String(take!(io)) @test occursin( "Warning: 1 tail draws is insufficient to fit the generalized Pareto distribution.", @@ -180,8 +180,8 @@ end psis(logr; normalize=false) end @test skipnan(result.log_weights) == skipnan(logr) - @test ismissing(result.tail_dist) - @test ismissing(result.pareto_shape) + @test isnan(result.tail_dist.σ) + @test isnan(result.pareto_shape) msg = String(take!(io)) @test occursin("Warning: Tail contains non-finite values.", msg) end @@ -226,7 +226,7 @@ end @test isempty(msg) tail_dist = [ - missing, + PSIS.GeneralizedPareto(0, NaN, NaN), PSIS.GeneralizedPareto(0, 1, 0.69), PSIS.GeneralizedPareto(0, 1, 0.71), PSIS.GeneralizedPareto(0, 1, 1.1), @@ -260,7 +260,7 @@ end logr = permutedims(logr, (2, 3, 1)) @testset for r_eff in (0.7, 1.2) r_effs = fill(r_eff, sz[1]) - result = psis(logr, r_effs; normalize=false) + result = @inferred psis(logr, r_effs; normalize=false) logw = result.log_weights @test !isapprox(logw, logr) basename = "normal_to_cauchy_reff_$(r_eff)" @@ -295,7 +295,7 @@ end Dimensions.Dim{:param}(param_names), ), ) - result = psis(logr) + result = @inferred psis(logr) @test result.log_weights isa DimArray @test Dimensions.dims(result.log_weights) == Dimensions.dims(logr) for k in (:pareto_shape, :tail_length, :tail_dist, :reff) diff --git a/test/ess.jl b/test/ess.jl index 2b374cf5..05c7fcd3 100644 --- a/test/ess.jl +++ b/test/ess.jl @@ -20,14 +20,14 @@ using Test @test ess_is(result) ≈ ess_is(result.weights; reff=1.5) result = PSISResult(logw, 1.5, 20, PSIS.GeneralizedPareto(0.0, 1.0, 0.71), false) - @test ismissing(ess_is(result)) - @test ess_is(result; bad_shape_missing=false) ≈ 
diff --git a/test/ess.jl b/test/ess.jl
index 2b374cf5..05c7fcd3 100644
--- a/test/ess.jl
+++ b/test/ess.jl
@@ -20,14 +20,14 @@ using Test
    @test ess_is(result) ≈ ess_is(result.weights; reff=1.5)
 
    result = PSISResult(logw, 1.5, 20, PSIS.GeneralizedPareto(0.0, 1.0, 0.71), false)
-    @test ismissing(ess_is(result))
-    @test ess_is(result; bad_shape_missing=false) ≈ ess_is(result.weights; reff=1.5)
+    @test isnan(ess_is(result))
+    @test ess_is(result; bad_shape_nan=false) ≈ ess_is(result.weights; reff=1.5)
 
    logw = randn(100, 4, 3)
    tail_dists = [
        PSIS.GeneralizedPareto(0.0, 1.0, 0.69),
        PSIS.GeneralizedPareto(0.0, 1.0, 0.71),
-        missing,
+        PSIS.GeneralizedPareto(0.0, NaN, NaN),
    ]
    reff = [1.5, 0.8, 1.0]
    result = PSISResult(logw, reff, [20, 20, 20], tail_dists, false)
@@ -35,9 +35,8 @@
    @test ess isa Vector
    @test length(ess) == 3
    @test ess[1] ≈ ess_is(result.weights; reff=reff)[1]
-    @test ismissing(ess[2])
-    @test ismissing(ess[3])
-    ess = ess_is(result; bad_shape_missing=false)
-    @test ess[1:2] ≈ ess_is(result.weights; reff=reff)[1:2]
-    @test ismissing(ess[3])
+    @test isnan(ess[2])
+    @test isnan(ess[3])
+    ess = ess_is(result; bad_shape_nan=false)
+    @test ess ≈ ess_is(result.weights; reff=reff)[1:3]
 end
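Illustrative sketch (not part of the patch): the keyword is renamed from `bad_shape_missing` to `bad_shape_nan`. With the default `bad_shape_nan=true`, `ess_is` reports NaN whenever the tail fit failed or its Pareto shape exceeds 0.7; with `bad_shape_nan=false` it returns the plain importance-sampling estimate `reff ./ sum(abs2, weights)` regardless of the shape diagnostic. Hypothetical usage under those assumptions:

    using PSIS
    logr = randn(1_000)
    result = psis(logr)
    ess_is(result)                        # NaN if result.pareto_shape > 0.7 or the fit failed
    ess_is(result; bad_shape_nan=false)   # always reff ./ sum(abs2, weights), never masked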