Skip to content

Commit

Permalink
Add Julia implementation of r2_score (#285)
Browse files Browse the repository at this point in the history
* Add Julia implementation of r2_score

* Fix example

* Add method for InferenceData argument

* Move function to utils

* Update error messages

* Add smarter heuristic for getting observations and matching predictions

* Use smarter heuristic in loo_pit

* Use smarter heuristic in r2_score

* Rename variable

* Remove unused variables

* Don't refer to unexported function docstring

* Add citation

* Use different regression model in example

* Update docstrings

* Remove duplicated method

* Add citation

* Add GLM as test dependency

* Test r2_score

* Fixes for older Julia versions

* Fix filename

* Fix doctests
  • Loading branch information
sethaxen authored Jul 26, 2023
1 parent 95210fb commit 63d3cac
Show file tree
Hide file tree
Showing 10 changed files with 423 additions and 40 deletions.
4 changes: 1 addition & 3 deletions src/ArviZStats/ArviZStats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ using ArviZ: InferenceData, convert_to_dataset, ess
const INFORMATION_CRITERION_SCALES = (deviance=-2, log=1, negative_log=-1)

@forwardfun kde
@forwardfun r2_score

include("utils.jl")
include("hdi.jl")
Expand All @@ -53,8 +52,7 @@ include("waic.jl")
include("model_weights.jl")
include("compare.jl")
include("loo_pit.jl")

ArviZ.convert_result(::typeof(r2_score), result) = ArviZ.todataframes(result)
include("r2_score.jl")

@doc doc"""
summarystats(
Expand Down
9 changes: 5 additions & 4 deletions src/ArviZStats/hdi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ This implementation uses the algorithm of [^ChenShao1999].
Here we calculate the 83% HDI for a normal random variable:
```jldoctest; setup = :(using Random; Random.seed!(78))
```jldoctest hdi; setup = :(using Random; Random.seed!(78))
using ArviZ
x = randn(2_000)
hdi(x; prob=0.83)
Expand All @@ -48,7 +49,7 @@ hdi(x; prob=0.83)
We can also calculate the HDI for a 3-dimensional array of samples:
```jldoctest; setup = :(using Random; Random.seed!(67))
```jldoctest hdi; setup = :(using Random; Random.seed!(67))
x = randn(1_000, 1, 1) .+ reshape(0:5:10, 1, 1, :)
pairs(hdi(x))
Expand Down Expand Up @@ -121,8 +122,8 @@ hdi(idata)
# output
Dataset with dimensions:
Dim{:hdi_bound} Categorical{Symbol} Symbol[:lower, :upper] ForwardOrdered,
Dim{:school} Categorical{String} String[Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
Dim{:school} Categorical{String} String[Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered,
Dim{:hdi_bound} Categorical{Symbol} Symbol[:lower, :upper] ForwardOrdered
and 3 layers:
:mu Float64 dims: Dim{:hdi_bound} (2)
:theta Float64 dims: Dim{:school}, Dim{:hdi_bound} (8×2)
Expand Down
27 changes: 7 additions & 20 deletions src/ArviZStats/loo_pit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,8 @@ function loo_pit(
y_pred_name::Union{Symbol,Nothing}=nothing,
kwargs...,
)
_y_name = y_name === nothing ? _only_observed_data_key(idata) : y_name
_y_pred_name = y_pred_name === nothing ? _y_name : y_pred_name
haskey(idata, :posterior_predictive) ||
throw(ArgumentError("No `posterior_predictive` group"))
y = idata.observed_data[_y_name]
y_pred = _draw_chains_params_array(idata.posterior_predictive[_y_pred_name])
(_y_name, y), (_, _y_pred) = observations_and_predictions(idata, y_name, y_pred_name)
y_pred = _draw_chains_params_array(_y_pred)
pitvals = loo_pit(y, y_pred, log_weights; kwargs...)
return DimensionalData.rebuild(pitvals; name=Symbol("loo_pit_$(_y_name)"))
end
Expand Down Expand Up @@ -223,11 +219,13 @@ loo_pit(idata; y_name=:obs)
function loo_pit(
idata::InferenceObjects.InferenceData;
y_name::Union{Symbol,Nothing}=nothing,
y_pred_name::Union{Symbol,Nothing}=nothing,
log_likelihood_name::Union{Symbol,Nothing}=nothing,
reff=nothing,
kwargs...,
)
_y_name = y_name === nothing ? _only_observed_data_key(idata) : y_name
(_y_name, y), (_, _y_pred) = observations_and_predictions(idata, y_name, y_pred_name)
y_pred = _draw_chains_params_array(_y_pred)
if log_likelihood_name === nothing
if haskey(idata, :log_likelihood)
_log_like = log_likelihood(idata.log_likelihood, _y_name)
Expand All @@ -241,7 +239,8 @@ function loo_pit(
end
log_like = _draw_chains_params_array(_log_like)
psis_result = _psis_loo_setup(log_like, reff)
return loo_pit(idata, psis_result.log_weights; y_name=_y_name, kwargs...)
pitvals = loo_pit(y, y_pred, psis_result.log_weights; kwargs...)
return DimensionalData.rebuild(pitvals; name=Symbol("loo_pit_$(_y_name)"))
end

function _loo_pit(y::Number, y_pred, log_weights)
Expand Down Expand Up @@ -269,15 +268,3 @@ function _loo_pit(y::AbstractArray, y_pred, log_weights)
end
return pitvals
end

# Return the single variable name in `idata.observed_data`.
# Throws an `ArgumentError` when the group is missing or holds more than one variable,
# since then the observation name cannot be inferred automatically.
function _only_observed_data_key(idata::InferenceObjects.InferenceData)
    if !haskey(idata, :observed_data)
        throw(ArgumentError("No `observed_data` group in `idata`"))
    end
    data_keys = keys(idata.observed_data)
    if length(data_keys) != 1
        throw(
            ArgumentError(
                "More than one observed data variable: $(data_keys). `y_name` must be provided"
            ),
        )
    end
    return first(data_keys)
end
94 changes: 94 additions & 0 deletions src/ArviZStats/r2_score.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""
r2_score(y_true::AbstractVector, y_pred::AbstractVecOrMat) -> (; r2, r2_std)
``R²`` for linear Bayesian regression models.[^GelmanGoodrich2019]
# Arguments
- `y_true`: Observed data of length `noutputs`
- `y_pred`: Predicted data with size `(ndraws[, nchains], noutputs)`
[^GelmanGoodrich2019]: Andrew Gelman, Ben Goodrich, Jonah Gabry & Aki Vehtari (2019)
R-squared for Bayesian Regression Models, The American Statistician,
73:3, 307-9,
DOI: [10.1080/00031305.2018.1549100](https://doi.org/10.1080/00031305.2018.1549100).
# Examples
```jldoctest
using ArviZ, ArviZExampleData
idata = load_example_data("regression1d")
y_true = idata.observed_data.y
y_pred = PermutedDimsArray(idata.posterior_predictive.y, (:draw, :chain, :y_dim_0))
r2_score(y_true, y_pred)
# output
(r2 = 0.683196996216511, r2_std = 0.036883777654323734)
```
"""
function r2_score(y_true, y_pred)
r_squared = r2_samples(y_true, y_pred)
return NamedTuple{(:r2, :r2_std)}(StatsBase.mean_and_std(r_squared; corrected=false))
end

"""
r2_score(idata::InferenceData; y_name, y_pred_name) -> (; r2, r2_std)
Compute ``R²`` from `idata`, automatically formatting the predictions to the correct shape.
# Keywords
- `y_name`: Name of observed data variable in `idata.observed_data`. If not provided, then
the only observed data variable is used.
- `y_pred_name`: Name of posterior predictive variable in `idata.posterior_predictive`.
If not provided, then `y_name` is used.
# Examples
```jldoctest
using ArviZ, ArviZExampleData
idata = load_example_data("regression10d")
r2_score(idata)
# output
(r2 = 0.998384805658226, r2_std = 0.00010062063385452256)
```
"""
function r2_score(
idata::InferenceObjects.InferenceData;
y_name::Union{Symbol,Nothing}=nothing,
y_pred_name::Union{Symbol,Nothing}=nothing,
)
(_, y), (_, y_pred) = observations_and_predictions(idata, y_name, y_pred_name)
return r2_score(y, _draw_chains_params_array(y_pred))
end

"""
r2_samples(y_true::AbstractVector, y_pred::AbstractMatrix) -> AbstractVector
``R²`` samples for Bayesian regression models. Only valid for linear models.
See also [`r2_score`](@ref).
# Arguments
- `y_true`: Observed data of length `noutputs`
- `y_pred`: Predicted data with size `(ndraws[, nchains], noutputs)`
"""
function r2_samples(y_true::AbstractVector, y_pred::AbstractArray)
@assert ndims(y_pred) (2, 3)
corrected = false
dims = ndims(y_pred)

var_y_est = dropdims(Statistics.var(y_pred; corrected, dims); dims)
y_true_reshape = reshape(y_true, ntuple(one, ndims(y_pred) - 1)..., :)
var_residual = dropdims(Statistics.var(y_pred .- y_true_reshape; corrected, dims); dims)

# allocate storage for type-stability
T = typeof(first(var_y_est) / first(var_residual))
sample_axes = ntuple(Base.Fix1(axes, y_pred), ndims(y_pred) - 1)
r_squared = similar(y_pred, T, sample_axes)
r_squared .= var_y_est ./ (var_y_est .+ var_residual)
return r_squared
end
118 changes: 118 additions & 0 deletions src/ArviZStats/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,124 @@ function _check_log_likelihood(x)
return nothing
end

# Return the single variable name in `idata.observed_data`.
# When `var_name` is given, it names the keyword the caller should supply if the
# group is ambiguous (more than one variable), and it is mentioned in the error.
function _only_observed_data_key(idata::InferenceObjects.InferenceData; var_name=nothing)
    if !haskey(idata, :observed_data)
        throw(ArgumentError("Data must contain an `observed_data` group."))
    end
    data_keys = keys(idata.observed_data)
    if isempty(data_keys)
        throw(ArgumentError("`observed_data` group must not be empty."))
    end
    if length(data_keys) > 1
        msg = "More than one observed data variable: $(data_keys)."
        if var_name !== nothing
            msg = "$msg `$var_name` must be specified."
        end
        throw(ArgumentError(msg))
    end
    return first(data_keys)
end

# Return `group_name => group` for the group most likely to contain posterior
# predictive draws, preferring `posterior_predictive` over `posterior`.
function _post_pred_or_post_name_group(idata)
    for group_name in (:posterior_predictive, :posterior)
        haskey(idata, group_name) && return group_name => getproperty(idata, group_name)
    end
    throw(ArgumentError("No `posterior_predictive` or `posterior` group"))
end

"""
observations_and_predictions(data::InferenceData[, y_name[, y_pred_name]])
Get arrays of observations and predictions for the specified variable in `data`.
If `y_name` and/or `y_pred_name` is not provided, then they are inferred from the data.
Generally this requires that either there is a single variable in `observed_data` or that
there is only one variable in `posterior` or `posterior_predictive` that has a matching name
in `observed_data`, optionally with the suffix `_pred`.
The return value has the structure `(y_name => y, y_pred_name => y_pred)`, where `y_name`
and `y_pred_name` are the actual names found.
"""
function observations_and_predictions end
# Convenience methods: with no `y_pred_name` (and optionally no `y_name`), defer to the
# three-argument method, which infers the missing names from the data.
observations_and_predictions(idata::InferenceObjects.InferenceData) =
    observations_and_predictions(idata, nothing, nothing)
function observations_and_predictions(
    idata::InferenceObjects.InferenceData, y_name::Union{Symbol,Nothing}
)
    return observations_and_predictions(idata, y_name, nothing)
end
# Both names given explicitly: validate the groups and index straight into them.
function observations_and_predictions(
    idata::InferenceObjects.InferenceData, y_name::Symbol, y_pred_name::Symbol
)
    if !haskey(idata, :observed_data)
        throw(ArgumentError("Data must contain `observed_data` group"))
    end
    observations = idata.observed_data[y_name]
    _, pred_group = _post_pred_or_post_name_group(idata)
    return (y_name => observations, y_pred_name => pred_group[y_pred_name])
end
# Only the prediction name given: infer the observation name from `observed_data`.
function observations_and_predictions(
    idata::InferenceObjects.InferenceData, ::Nothing, y_pred_name::Symbol
)
    inferred_y_name = _only_observed_data_key(idata; var_name=:y_name)
    observations = idata.observed_data[inferred_y_name]
    _, pred_group = _post_pred_or_post_name_group(idata)
    return (inferred_y_name => observations, y_pred_name => pred_group[y_pred_name])
end
# Only the observation name given: predictions may be stored under the same name or
# with a `_pred` suffix; try both before giving up.
function observations_and_predictions(
    idata::InferenceObjects.InferenceData, y_name::Symbol, ::Nothing
)
    if !haskey(idata, :observed_data)
        throw(ArgumentError("Data must contain `observed_data` group"))
    end
    observations = idata.observed_data[y_name]
    group_name, pred_group = _post_pred_or_post_name_group(idata)
    y_pred_names = (y_name, Symbol("$(y_name)_pred"))
    for candidate in y_pred_names
        haskey(pred_group, candidate) || continue
        return (y_name => observations, candidate => pred_group[candidate])
    end
    throw(
        ArgumentError(
            "Could not find names $y_pred_names in group `$group_name`. `y_pred_name` must be specified.",
        ),
    )
end
# Neither name given: infer both. With a single observed variable, reuse the
# single-name method; otherwise search for exactly one observed variable whose name
# (optionally with a `_pred` suffix) also appears in the predictive group.
function observations_and_predictions(
    idata::InferenceObjects.InferenceData, ::Nothing, ::Nothing
)
    haskey(idata, :observed_data) ||
        throw(ArgumentError("Data must contain `observed_data` group"))
    observed_data = idata.observed_data
    obs_keys = keys(observed_data)
    if length(obs_keys) == 1
        # unambiguous observation name; delegate prediction-name matching
        y_name = first(obs_keys)
        return observations_and_predictions(idata, y_name, nothing)
    else
        _, post_pred = _post_pred_or_post_name_group(idata)
        # collect every (observed, predicted) name pair with a match in the
        # predictive group; `nothing` entries mark observed vars with no match
        var_name_pairs = filter(
            !isnothing,
            map(obs_keys) do k
                for k_pred in (k, Symbol("$(k)_pred"))
                    haskey(post_pred, k_pred) && return (k, k_pred)
                end
                return nothing
            end,
        )
        if length(var_name_pairs) == 1
            # exactly one match — the inference is unambiguous
            y_name, y_pred_name = first(var_name_pairs)
            y = observed_data[y_name]
            y_pred = post_pred[y_pred_name]
            return (y_name => y, y_pred_name => y_pred)
        else
            # zero or multiple matches — caller must disambiguate explicitly
            throw(
                ArgumentError(
                    "No unique pair of variable names. `y_name` and/or `y_pred_name` must be specified.",
                ),
            )
        end
    end
end

"""
smooth_data(y; dims=:, interp_method=CubicSpline, offset_frac=0.01)
Expand Down
4 changes: 2 additions & 2 deletions test/ArviZStats/loo_pit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ using StatsBase
posterior_predictive=Dataset((; y=y_pred)),
log_likelihood=Dataset((; y=log_like)),
)
@test_throws ArgumentError loo_pit(idata1; y_name=:z)
@test_throws Exception loo_pit(idata1; y_name=:z)
@test_throws Exception loo_pit(idata1; y_pred_name=:z)
@test_throws Exception loo_pit(idata1; log_likelihood_name=:z)
@test loo_pit(idata1) == pit_vals
Expand All @@ -114,8 +114,8 @@ using StatsBase
@test_throws ArgumentError loo_pit(
idata2; y_name=:z, y_pred_name=:y_pred, log_likelihood_name=:log_like
)
@test_throws Exception loo_pit(idata2; y_name=:y, log_likelihood_name=:log_like)
@test_throws ArgumentError loo_pit(idata2; y_name=:y, y_pred_name=:y_pred)
@test loo_pit(idata2; y_name=:y, log_likelihood_name=:log_like) == pit_vals
@test loo_pit(
idata2; y_name=:y, y_pred_name=:y_pred, log_likelihood_name=:log_like
) == pit_vals
Expand Down
49 changes: 49 additions & 0 deletions test/ArviZStats/r2_score.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
using ArviZ
using ArviZExampleData
using GLM
using Statistics
using Test

@testset "r2_score/r2_sample" begin
@testset "basic" begin
n = 100
@testset for T in (Float32, Float64),
sz in (300, (100, 3)),
σ in T.((2, 1, 0.5, 0.1))

x = range(T(0), T(1); length=n)
slope = T(2)
intercept = T(3)
y = @. slope * x + intercept + randn(T) * σ
x_reshape = length(sz) == 1 ? x' : reshape(x, 1, 1, :)
y_pred = slope .* x_reshape .+ intercept .+ randn(T, sz..., n) .* σ

r2_val = @inferred r2_score(y, y_pred)
@test r2_val isa NamedTuple{(:r2, :r2_std),NTuple{2,T}}
r2_draws = @inferred ArviZStats.r2_samples(y, y_pred)
@test r2_val.r2 mean(r2_draws)
@test r2_val.r2_std std(r2_draws; corrected=false)

# check rough consistency with GLM
res = lm(@formula(y ~ 1 + x), (; x=Float64.(x), y=Float64.(y)))
@test r2_val.r2 r2(res) rtol = 1
end
end

@testset "InferenceData inputs" begin
@testset for name in ("regression1d", "regression10d")
idata = load_example_data(name)
VERSION v"1.9" && @inferred r2_score(idata)
r2_val = r2_score(idata)
@test r2_val == r2_score(
idata.observed_data.y,
PermutedDimsArray(idata.posterior_predictive.y, (:draw, :chain, :y_dim_0)),
)
@test r2_val == r2_score(idata; y_name=:y)
@test r2_val == r2_score(idata; y_pred_name=:y)
@test r2_val == r2_score(idata; y_name=:y, y_pred_name=:y)
@test_throws Exception r2_score(idata; y_name=:z)
@test_throws Exception r2_score(idata; y_pred_name=:z)
end
end
end
Loading

0 comments on commit 63d3cac

Please sign in to comment.