Skip to content

Commit

Permalink
Stop returning missing for failed GPD fits (#51)
Browse files Browse the repository at this point in the history
* Don't set tail_dist to missing with failures

* Update ess accordingly

* Update tests

* Update minor version number

* Add checks for type-inference
  • Loading branch information
sethaxen authored Jul 1, 2023
1 parent e07c715 commit e0f8756
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 46 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PSIS"
uuid = "ce719bf2-d5d0-4fb9-925d-10a81b42ad04"
authors = ["Seth Axen <[email protected]> and contributors"]
version = "0.8.0"
version = "0.9.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
40 changes: 21 additions & 19 deletions src/core.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# range, description, condition
const SHAPE_DIAGNOSTIC_CATEGORIES = (
("(-Inf, 0.5]", "good", x -> !ismissing(x) && x ≤ 0.5),
("(0.5, 0.7]", "okay", x -> !ismissing(x) && 0.5 < x ≤ 0.7),
("(0.7, 1]", "bad", x -> !ismissing(x) && 0.7 < x ≤ 1),
("(1, Inf)", "very bad", x -> !ismissing(x) && x > 1),
("——", "missing", ismissing),
("(-Inf, 0.5]", "good", ≤(0.5)),
("(0.5, 0.7]", "okay", x -> 0.5 < x ≤ 0.7),
("(0.7, 1]", "bad", x -> 0.7 < x ≤ 1),
("(1, Inf)", "very bad", >(1)),
("——", "failed", isnan),
)
const BAD_SHAPE_SUMMARY = "Resulting importance sampling estimates are likely to be unstable."
const VERY_BAD_SHAPE_SUMMARY = "Corresponding importance sampling estimates are likely to be unstable and are unlikely to converge with additional samples."
Expand Down Expand Up @@ -117,8 +117,8 @@ function _print_pareto_shape_summary(io::IO, r::PSISResult; kwargs...)
inds = findall(cond, k)
count = length(inds)
perc = 100 * count / npoints
ess_min = if count == 0 || desc == "too few draws"
missing
ess_min = if count == 0 || desc == "failed"
oftype(first(ess), NaN)
else
minimum(view(ess, inds))
end
Expand All @@ -130,7 +130,7 @@ function _print_pareto_shape_summary(io::IO, r::PSISResult; kwargs...)
"okay" => (; color=:yellow),
"bad" => (bold=true, color=:light_red),
"very bad" => (bold=true, color=:red),
"missing" => (),
"failed" => (; color=:red),
)

col_padding = " "
Expand Down Expand Up @@ -165,7 +165,7 @@ function _print_pareto_shape_summary(io::IO, r::PSISResult; kwargs...)
format = formats[r.desc]
printstyled(io, _pad_left(count, col_widths[3]); format...)
printstyled(io, " ", _pad_right(perc_str, col_widths[4]); format...)
print(io, col_delim_tot, r.ess_min === missing ? "——" : floor(Int, r.ess_min))
print(io, col_delim_tot, isfinite(r.ess_min) ? floor(Int, r.ess_min) : "——")
end
return nothing
end
Expand Down Expand Up @@ -218,14 +218,16 @@ function psis(logr, reff=1; kwargs...)
end

function psis!(logw::AbstractVecOrMat, reff=1; normalize::Bool=true, warn::Bool=true)
T = typeof(float(one(eltype(logw))))
S = length(logw)
reff_val = first(reff)
M = tail_length(reff_val, S)
if M < 5
warn &&
@warn "$M tail draws is insufficient to fit the generalized Pareto distribution. Total number of draws should in general exceed 25."
_maybe_log_normalize!(logw, normalize)
return PSISResult(logw, reff_val, M, missing, normalize)
tail_dist_failed = GeneralizedPareto(0, T(NaN), T(NaN))
return PSISResult(logw, reff_val, M, tail_dist_failed, normalize)
end
perm = partialsortperm(logw, (S - M):S)
cutoff_ind = perm[1]
Expand All @@ -236,7 +238,8 @@ function psis!(logw::AbstractVecOrMat, reff=1; normalize::Bool=true, warn::Bool=
warn &&
@warn "Tail contains non-finite values. Generalized Pareto distribution cannot be reliably fit."
_maybe_log_normalize!(logw, normalize)
return PSISResult(logw, reff_val, M, missing, normalize)
tail_dist_failed = GeneralizedPareto(0, T(NaN), T(NaN))
return PSISResult(logw, reff_val, M, tail_dist_failed, normalize)
end
_, tail_dist = psis_tail!(logw_tail, logu)
warn && check_pareto_shape(tail_dist)
Expand All @@ -259,7 +262,7 @@ function psis!(logw::AbstractArray, reff=1; normalize::Bool=true, warn::Bool=tru
reffs = similar(logw, eltype(reff), param_axes)
reffs .= reff
tail_lengths = similar(logw, Int, param_axes)
tail_dists = similar(logw, Union{Missing,GeneralizedPareto{T}}, param_axes)
tail_dists = similar(logw, GeneralizedPareto{T}, param_axes)

# call psis! in parallel for all parameters
Threads.@threads for i in _eachparamindex(logw)
Expand All @@ -277,7 +280,6 @@ function psis!(logw::AbstractArray, reff=1; normalize::Bool=true, warn::Bool=tru
return result
end

pareto_shape(::Missing) = missing
pareto_shape(dist::GeneralizedPareto) = dist.k
pareto_shape(r::PSISResult) = pareto_shape(getfield(r, :tail_dist))
pareto_shape(dists) = map(pareto_shape, dists)
Expand All @@ -292,18 +294,18 @@ function check_pareto_shape(dist::GeneralizedPareto)
end
return nothing
end
function check_pareto_shape(dists::AbstractArray{<:Union{Missing,GeneralizedPareto}})
nmissing = count(ismissing, dists)
ngt07 = count(x -> !(ismissing(x)) && pareto_shape(x) > 0.7, dists)
ngt1 = iszero(ngt07) ? ngt07 : count(x -> !(ismissing(x)) && pareto_shape(x) > 1, dists)
function check_pareto_shape(dists::AbstractArray{<:GeneralizedPareto})
nnan = count(isnan ∘ pareto_shape, dists)
ngt07 = count(>(0.7) ∘ pareto_shape, dists)
ngt1 = iszero(ngt07) ? ngt07 : count(>(1) ∘ pareto_shape, dists)
if ngt07 > ngt1
@warn "$(ngt07 - ngt1) parameters had Pareto shape values 0.7 < k ≤ 1. $BAD_SHAPE_SUMMARY"
end
if ngt1 > 0
@warn "$ngt1 parameters had Pareto shape values k > 1. $VERY_BAD_SHAPE_SUMMARY"
end
if nmissing > 0
@warn "For $nmissing parameters, the generalized Pareto distribution could not be fit to the tail draws. Total number of draws should in general exceed 25, and the tail draws must be finite."
if nnan > 0
@warn "For $nnan parameters, the generalized Pareto distribution could not be fit to the tail draws. Total number of draws should in general exceed 25, and the tail draws must be finite."
end
return nothing
end
Expand Down
20 changes: 11 additions & 9 deletions src/ess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,34 @@ Given normalized weights ``w_{1:n}``, the ESS is estimated using the L2-norm of
where ``r_{\\mathrm{eff}}`` is the relative efficiency of the `log_weights`.
ess_is(result::PSISResult; bad_shape_missing=true)
ess_is(result::PSISResult; bad_shape_nan=true)
Estimate ESS for Pareto-smoothed importance sampling.
!!! note
ESS estimates for Pareto shape values ``k > 0.7``, which are unreliable and misleadingly
high, are set to `missing`. To avoid this, set `bad_shape_missing=false`.
high, are set to `NaN`. To avoid this, set `bad_shape_nan=false`.
"""
ess_is

function ess_is(r::PSISResult; bad_shape_missing::Bool=true)
function ess_is(r::PSISResult; bad_shape_nan::Bool=true)
neff = ess_is(r.weights; reff=r.reff)
return _apply_missing(neff, r.tail_dist; bad_shape_missing=bad_shape_missing)
return _apply_nan(neff, r.tail_dist; bad_shape_nan=bad_shape_nan)
end
function ess_is(weights; reff=1)
dims = _sample_dims(weights)
return reff ./ dropdims(sum(abs2, weights; dims=dims); dims=dims)
end

function _apply_missing(neff, dist; bad_shape_missing)
return bad_shape_missing && pareto_shape(dist) > 0.7 ? missing : neff
function _apply_nan(neff, dist; bad_shape_nan)
bad_shape_nan || return neff
k = pareto_shape(dist)
(isnan(k) || k > 0.7) && return oftype(neff, NaN)
return neff
end
_apply_missing(neff, ::Missing; kwargs...) = missing
function _apply_missing(ess::AbstractArray, tail_dist::AbstractArray; kwargs...)
function _apply_nan(ess::AbstractArray, tail_dist::AbstractArray; kwargs...)
return map(ess, tail_dist) do essᵢ, tail_distᵢ
return _apply_missing(essᵢ, tail_distᵢ; kwargs...)
return _apply_nan(essᵢ, tail_distᵢ; kwargs...)
end
end
18 changes: 9 additions & 9 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ using DimensionalData: Dimensions, DimArray
(0.5, 0.7] okay 6 (20.0%) 92
(0.7, 1] bad 4 (13.3%) ——
(1, Inf) very bad 17 (56.7%) ——
—— missing 1 (3.3%) ——"""
—— failed 1 (3.3%) ——"""
end
end
end
Expand All @@ -116,7 +116,7 @@ end
x = rand(rng, proposal, sz)
logr = logpdf.(target, x) .- logpdf.(proposal, x)

r = psis(logr)
r = @inferred psis(logr)
@test r isa PSISResult
logw = r.log_weights
@test logw isa typeof(logr)
Expand Down Expand Up @@ -160,8 +160,8 @@ end
psis(logr; normalize=false)
end
@test result.log_weights == logr
@test ismissing(result.tail_dist)
@test ismissing(result.pareto_shape)
@test isnan(result.tail_dist.σ)
@test isnan(result.pareto_shape)
msg = String(take!(io))
@test occursin(
"Warning: 1 tail draws is insufficient to fit the generalized Pareto distribution.",
Expand All @@ -180,8 +180,8 @@ end
psis(logr; normalize=false)
end
@test skipnan(result.log_weights) == skipnan(logr)
@test ismissing(result.tail_dist)
@test ismissing(result.pareto_shape)
@test isnan(result.tail_dist.σ)
@test isnan(result.pareto_shape)
msg = String(take!(io))
@test occursin("Warning: Tail contains non-finite values.", msg)
end
Expand Down Expand Up @@ -226,7 +226,7 @@ end
@test isempty(msg)

tail_dist = [
missing,
PSIS.GeneralizedPareto(0, NaN, NaN),
PSIS.GeneralizedPareto(0, 1, 0.69),
PSIS.GeneralizedPareto(0, 1, 0.71),
PSIS.GeneralizedPareto(0, 1, 1.1),
Expand Down Expand Up @@ -260,7 +260,7 @@ end
logr = permutedims(logr, (2, 3, 1))
@testset for r_eff in (0.7, 1.2)
r_effs = fill(r_eff, sz[1])
result = psis(logr, r_effs; normalize=false)
result = @inferred psis(logr, r_effs; normalize=false)
logw = result.log_weights
@test !isapprox(logw, logr)
basename = "normal_to_cauchy_reff_$(r_eff)"
Expand Down Expand Up @@ -295,7 +295,7 @@ end
Dimensions.Dim{:param}(param_names),
),
)
result = psis(logr)
result = @inferred psis(logr)
@test result.log_weights isa DimArray
@test Dimensions.dims(result.log_weights) == Dimensions.dims(logr)
for k in (:pareto_shape, :tail_length, :tail_dist, :reff)
Expand Down
15 changes: 7 additions & 8 deletions test/ess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,23 @@ using Test
@test ess_is(result) ≈ ess_is(result.weights; reff=1.5)

result = PSISResult(logw, 1.5, 20, PSIS.GeneralizedPareto(0.0, 1.0, 0.71), false)
@test ismissing(ess_is(result))
@test ess_is(result; bad_shape_missing=false) ≈ ess_is(result.weights; reff=1.5)
@test isnan(ess_is(result))
@test ess_is(result; bad_shape_nan=false) ≈ ess_is(result.weights; reff=1.5)

logw = randn(100, 4, 3)
tail_dists = [
PSIS.GeneralizedPareto(0.0, 1.0, 0.69),
PSIS.GeneralizedPareto(0.0, 1.0, 0.71),
missing,
PSIS.GeneralizedPareto(0.0, NaN, NaN),
]
reff = [1.5, 0.8, 1.0]
result = PSISResult(logw, reff, [20, 20, 20], tail_dists, false)
ess = ess_is(result)
@test ess isa Vector
@test length(ess) == 3
@test ess[1] ≈ ess_is(result.weights; reff=reff)[1]
@test ismissing(ess[2])
@test ismissing(ess[3])
ess = ess_is(result; bad_shape_missing=false)
@test ess[1:2] ≈ ess_is(result.weights; reff=reff)[1:2]
@test ismissing(ess[3])
@test isnan(ess[2])
@test isnan(ess[3])
ess = ess_is(result; bad_shape_nan=false)
@test ess ≈ ess_is(result.weights; reff=reff)[1:3]
end

2 comments on commit e0f8756

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/86683

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.9.0 -m "<description of version>" e0f8756c9766e1a78d1fab36835d9bb647f89db6
git push origin v0.9.0

Please sign in to comment.