Skip to content

Commit

Permalink
Add Julia implementation of LOO-PIT (#277)
Browse files Browse the repository at this point in the history
* Add Julia implementation of loopit

* Make test approximate

* Add `smooth_data` to API docs

* Handle case where condition is never met

* Delete methods and update docstrings

* Rearrange code

* Update docstrings

* Reduce allocations

* Rename variables

* Run formatter

* Split methods

* Add loo_pit tests

* Support older Julia versions

* Remove docs manifest

* Add tests and raise informative warnings

* Remove comment

* Test loo_pit with discrete data

* Change to jldoctests

* Add Distributions to test environment

* Reduce docstring errors

* Increase epsilon

* Test type-inference only on recent Julia versions

* Add warning when smoothing automatic

* Replace Interpolations with DataInterpolations

This reduces load time by 0.5s due to invalidations in Interpolations.

* Add a backward-compatible _eachslice

* Use DataInterpolations in smoothing
  • Loading branch information
sethaxen authored Jul 16, 2023
1 parent 7f49e14 commit ebd2b0e
Show file tree
Hide file tree
Showing 9 changed files with 533 additions and 6 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ ArviZExampleData = "2f96bb34-afd9-46ae-bcd0-9b2d4372fe3c"
Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Expand All @@ -33,6 +34,7 @@ ArviZExampleData = "0.1"
Conda = "1.0"
DataDeps = "0.7"
DataFrames = "0.20, 0.21, 0.22, 1.0"
DataInterpolations = "4"
DimensionalData = "0.23, 0.24"
DocStringExtensions = "0.8, 0.9"
InferenceObjects = "0.3.9"
Expand Down
7 changes: 7 additions & 0 deletions docs/src/api/stats.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,10 @@ waic
compare
loo_pit
```


### Utilities

```@docs
ArviZStats.smooth_data
```
6 changes: 5 additions & 1 deletion src/ArviZStats/ArviZStats.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module ArviZStats

using ArviZ: ArviZ, arviz, @forwardfun
using DataInterpolations: DataInterpolations
using DimensionalData: DimensionalData, Dimensions
using DocStringExtensions: FIELDS, FUNCTIONNAME, TYPEDEF, SIGNATURES
using InferenceObjects: InferenceObjects
Expand All @@ -24,17 +25,20 @@ export elpd_estimates, information_criterion, loo, waic
# Others
export compare, hdi, kde, loo_pit, r2_score, summary, summarystats

# load for docstrings
using ArviZ: InferenceData, convert_to_dataset, ess

const INFORMATION_CRITERION_SCALES = (deviance=-2, log=1, negative_log=-1)

@forwardfun compare
@forwardfun hdi
@forwardfun kde
@forwardfun loo_pit
@forwardfun r2_score

include("utils.jl")
include("elpdresult.jl")
include("loo.jl")
include("loo_pit.jl")
include("waic.jl")

function ArviZ.convert_arguments(::typeof(compare), data, args...; kwargs...)
Expand Down
283 changes: 283 additions & 0 deletions src/ArviZStats/loo_pit.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,283 @@
"""
loo_pit(y, y_pred, log_weights; kwargs...) -> Union{Real,AbstractArray}
Compute leave-one-out probability integral transform (LOO-PIT) checks.
# Arguments
- `y`: array of observations with shape `(params...,)`
- `y_pred`: array of posterior predictive samples with shape `(draws, chains, params...)`.
- `log_weights`: array of normalized log LOO importance weights with shape
`(draws, chains, params...)`.
# Keywords
- `is_discrete`: If not provided, then it is set to `true` iff elements of `y` and `y_pred`
are all integer-valued. If `true`, then data are smoothed using [`smooth_data`](@ref) to
make them non-discrete before estimating LOO-PIT values.
- `kwargs`: Remaining keywords are forwarded to `smooth_data` if data is discrete.
# Returns
- `pitvals`: LOO-PIT values with same size as `y`. If `y` is a scalar, then `pitvals` is a
scalar.
LOO-PIT is a marginal posterior predictive check. If ``y_{-i}`` is the array ``y`` of
observations with the ``i``th observation left out, and ``y_i^*`` is a posterior prediction
of the ``i``th observation, then the LOO-PIT value for the ``i``th observation is defined as
```math
P(y_i^* \\le y_i \\mid y_{-i}) = \\int_{-\\infty}^{y_i} p(y_i^* \\mid y_{-i}) \\mathrm{d} y_i^*
```
The LOO posterior predictions and the corresponding observations should have similar
distributions, so if conditional predictive distributions are well-calibrated, then all
LOO-PIT values should be approximately uniformly distributed on ``[0, 1]``.[^Gabry2019]
[^Gabry2019]: Gabry, J., Simpson, D., Vehtari, A., Betancourt, M. & Gelman, A.
Visualization in Bayesian Workflow.
J. R. Stat. Soc. Ser. A Stat. Soc. 182, 389–402 (2019).
doi: [10.1111/rssa.12378](https://doi.org/10.1111/rssa.12378)
arXiv: [1709.01449](https://arxiv.org/abs/1709.01449)
# Examples
Calculate LOO-PIT values using as test quantity the observed values themselves.
```jldoctest loo_pit1
using ArviZ
idata = load_example_data("centered_eight")
log_weights = loo(idata; var_name=:obs).psis_result.log_weights
loo_pit(
idata.observed_data.obs,
permutedims(idata.posterior_predictive.obs, (:draw, :chain, :school)),
log_weights,
)
# output
8-element DimArray{Float64,1} with dimensions:
Dim{:school} Categorical{String} String[Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
"Choate" 0.943511
"Deerfield" 0.63797
"Phillips Andover" 0.316697
"Phillips Exeter" 0.582252
"Hotchkiss" 0.295321
"Lawrenceville" 0.403318
"St. Paul's" 0.902508
"Mt. Hermon" 0.655275
```
Calculate LOO-PIT values using as test quantity the square of the difference between
each observation and `mu`.
```jldoctest loo_pit1
using DimensionalData, Statistics
T = idata.observed_data.obs .- only(median(idata.posterior.mu; dims=(:draw, :chain)))
T_pred = permutedims(
broadcast_dims(-, idata.posterior_predictive.obs, idata.posterior.mu),
(:draw, :chain, :school),
)
loo_pit(T .^ 2, T_pred .^ 2, log_weights)
# output
8-element DimArray{Float64,1} with dimensions:
Dim{:school} Categorical{String} String[Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
"Choate" 0.873577
"Deerfield" 0.243686
"Phillips Andover" 0.357563
"Phillips Exeter" 0.149908
"Hotchkiss" 0.435094
"Lawrenceville" 0.220627
"St. Paul's" 0.775086
"Mt. Hermon" 0.296706
```
"""
function loo_pit(
    y::Union{AbstractArray,Number},
    y_pred::AbstractArray,
    log_weights::AbstractArray;
    is_discrete::Union{Bool,Nothing}=nothing,
    kwargs...,
)
    # Draws and chains occupy the first two dimensions of `y_pred`/`log_weights`;
    # the remaining dimensions are the data (parameter) dimensions matching `y`.
    sample_dims = (1, 2)
    size(y) == size(y_pred)[3:end] ||
        throw(ArgumentError("data dimensions of `y` and `y_pred` must have the same size"))
    size(log_weights) == size(y_pred) ||
        throw(ArgumentError("`log_weights` and `y_pred` must have same size"))
    # Auto-detect discrete data when the user did not specify `is_discrete`.
    _is_discrete = if is_discrete === nothing
        all(isinteger, y) && all(isinteger, y_pred)
    else
        is_discrete
    end
    if _is_discrete
        # Only warn when smoothing was triggered automatically, not when requested.
        is_discrete === nothing &&
            @warn "All data and predictions are integer-valued. Smoothing data before running `loo_pit`."
        y_smooth = smooth_data(y; kwargs...)
        y_pred_smooth = smooth_data(y_pred; dims=_otherdims(y_pred, sample_dims), kwargs...)
        return _loo_pit(y_smooth, y_pred_smooth, log_weights)
    else
        return _loo_pit(y, y_pred, log_weights)
    end
end

"""
loo_pit(idata::InferenceData, log_weights; kwargs...) -> DimArray
Compute LOO-PIT values using existing normalized log LOO importance weights.
# Keywords
- `y_name`: Name of observed data variable in `idata.observed_data`. If not provided, then
the only observed data variable is used.
- `y_pred_name`: Name of posterior predictive variable in `idata.posterior_predictive`.
If not provided, then `y_name` is used.
- `kwargs`: Remaining keywords are forwarded to [`loo_pit`](@ref).
# Examples
Calculate LOO-PIT values using already computed log weights.
```jldoctest
using ArviZ
idata = load_example_data("centered_eight")
loo_result = loo(idata; var_name=:obs)
loo_pit(idata, loo_result.psis_result.log_weights; y_name=:obs)
# output
8-element DimArray{Float64,1} loo_pit_obs with dimensions:
Dim{:school} Categorical{String} String[Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
"Choate" 0.943511
"Deerfield" 0.63797
"Phillips Andover" 0.316697
"Phillips Exeter" 0.582252
"Hotchkiss" 0.295321
"Lawrenceville" 0.403318
"St. Paul's" 0.902508
"Mt. Hermon" 0.655275
```
"""
function loo_pit(
    idata::InferenceObjects.InferenceData,
    log_weights::AbstractArray;
    y_name::Union{Symbol,Nothing}=nothing,
    y_pred_name::Union{Symbol,Nothing}=nothing,
    kwargs...,
)
    # Resolve the observed-data variable first; errors when unspecified and ambiguous.
    observed_name = if y_name === nothing
        _only_observed_data_key(idata)
    else
        y_name
    end
    predicted_name = y_pred_name === nothing ? observed_name : y_pred_name
    if !haskey(idata, :posterior_predictive)
        throw(ArgumentError("No `posterior_predictive` group"))
    end
    observed = idata.observed_data[observed_name]
    # Reorder predictions to `(draw, chain, params...)` for the array method.
    predicted = _draw_chains_params_array(idata.posterior_predictive[predicted_name])
    pit_values = loo_pit(observed, predicted, log_weights; kwargs...)
    # Name the result after the observed variable, e.g. `loo_pit_obs`.
    return DimensionalData.rebuild(pit_values; name=Symbol("loo_pit_$(observed_name)"))
end

"""
loo_pit(idata::InferenceData; kwargs...) -> DimArray
Compute LOO-PIT from groups in `idata` using PSIS-LOO.
See also: [`loo`](@ref), [`psis`](@ref)
# Keywords
- `y_name`: Name of observed data variable in `idata.observed_data`. If not provided, then
the only observed data variable is used.
- `y_pred_name`: Name of posterior predictive variable in `idata.posterior_predictive`.
If not provided, then `y_name` is used.
- `log_likelihood_name`: Name of log-likelihood variable in `idata.log_likelihood`.
If not provided, then `y_name` is used if `idata` has a `log_likelihood` group,
otherwise the only variable is used.
- `reff::Union{Real,AbstractArray{<:Real}}`: The relative effective sample size(s) of the
_likelihood_ values. If an array, it must have the same data dimensions as the
corresponding log-likelihood variable. If not provided, then this is estimated using
[`ess`](@ref).
- `kwargs`: Remaining keywords are forwarded to [`loo_pit`](@ref).
# Examples
Calculate LOO-PIT values using as test quantity the observed values themselves.
```jldoctest
using ArviZ
idata = load_example_data("centered_eight")
loo_pit(idata; y_name=:obs)
# output
8-element DimArray{Float64,1} loo_pit_obs with dimensions:
Dim{:school} Categorical{String} String[Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
"Choate" 0.943511
"Deerfield" 0.63797
"Phillips Andover" 0.316697
"Phillips Exeter" 0.582252
"Hotchkiss" 0.295321
"Lawrenceville" 0.403318
"St. Paul's" 0.902508
"Mt. Hermon" 0.655275
```
"""
function loo_pit(
    idata::InferenceObjects.InferenceData;
    y_name::Union{Symbol,Nothing}=nothing,
    log_likelihood_name::Union{Symbol,Nothing}=nothing,
    reff=nothing,
    kwargs...,
)
    # End-to-end LOO-PIT: run PSIS-LOO on the log-likelihood draws to obtain
    # normalized log importance weights, then delegate to the log-weights method.
    _y_name = y_name === nothing ? _only_observed_data_key(idata) : y_name
    if log_likelihood_name === nothing
        # Prefer a dedicated `log_likelihood` group; fall back to the legacy
        # location `sample_stats.log_likelihood` when that group is absent.
        if haskey(idata, :log_likelihood)
            _log_like = log_likelihood(idata.log_likelihood, _y_name)
        elseif haskey(idata, :sample_stats) && haskey(idata.sample_stats, :log_likelihood)
            _log_like = idata.sample_stats.log_likelihood
        else
            throw(ArgumentError("There must be a `log_likelihood` group in `idata`"))
        end
    else
        _log_like = log_likelihood(idata.log_likelihood, log_likelihood_name)
    end
    # Reorder to `(draw, chain, params...)` as expected by the PSIS machinery.
    log_like = _draw_chains_params_array(_log_like)
    psis_result = _psis_loo_setup(log_like, reff)
    return loo_pit(idata, psis_result.log_weights; y_name=_y_name, kwargs...)
end

# Scalar observation: LOO-PIT is exp(logsumexp(log-weights where prediction ≤ y)),
# i.e. the importance-weighted probability mass of predictions not exceeding `y`.
function _loo_pit(y::Number, y_pred, log_weights)
    selected = @view log_weights[y_pred .≤ y]
    return exp(LogExpFunctions.logsumexp(selected))
end
function _loo_pit(y::AbstractArray, y_pred, log_weights)
    sample_dims = (1, 2)
    # Element type of the result after exponentiating float log-weights.
    T = typeof(exp(zero(float(eltype(log_weights)))))
    pitvals = similar(y, T)
    param_dims = _otherdims(log_weights, sample_dims)
    # workaround for `eachslice` not supporting multiple dims in older Julia versions
    map!(
        pitvals,
        y,
        CartesianIndices(map(Base.Fix1(axes, y_pred), param_dims)),
        CartesianIndices(map(Base.Fix1(axes, log_weights), param_dims)),
    ) do yi, i1, i2
        yi_pred = @views y_pred[:, :, i1]
        lwi = @views log_weights[:, :, i2]
        # Seed with -Inf (log of 0) so an empty selection yields a PIT value of 0.
        init = T(-Inf)
        # Lazily select log-weights whose predictions do not exceed the observation;
        # `≤` mirrors the selection in the scalar method (`y_pred .≤ y`).
        sel_iter = Iterators.flatten((
            init, (lwi_j for (lwi_j, yi_pred_j) in zip(lwi, yi_pred) if yi_pred_j ≤ yi)
        ))
        return exp(LogExpFunctions.logsumexp(sel_iter))
    end
    return pitvals
end

# Return the name of the sole variable in `idata.observed_data`, erroring when the
# group is missing or holds several variables (then `y_name` must disambiguate).
function _only_observed_data_key(idata::InferenceObjects.InferenceData)
    if !haskey(idata, :observed_data)
        throw(ArgumentError("No `observed_data` group in `idata`"))
    end
    var_names = keys(idata.observed_data)
    if length(var_names) != 1
        throw(
            ArgumentError(
                "More than one observed data variable: $(var_names). `y_name` must be provided"
            ),
        )
    end
    return first(var_names)
end
Loading

0 comments on commit ebd2b0e

Please sign in to comment.