Add Julia implementation of r2_score #285

Merged · 22 commits · Jul 26, 2023
Changes from all commits
4 changes: 1 addition & 3 deletions src/ArviZStats/ArviZStats.jl
@@ -43,7 +43,6 @@ using ArviZ: InferenceData, convert_to_dataset, ess
const INFORMATION_CRITERION_SCALES = (deviance=-2, log=1, negative_log=-1)

@forwardfun kde
@forwardfun r2_score

include("utils.jl")
include("hdi.jl")
@@ -53,8 +52,7 @@ include("waic.jl")
include("model_weights.jl")
include("compare.jl")
include("loo_pit.jl")

ArviZ.convert_result(::typeof(r2_score), result) = ArviZ.todataframes(result)
include("r2_score.jl")

@doc doc"""
summarystats(
9 changes: 5 additions & 4 deletions src/ArviZStats/hdi.jl
@@ -37,7 +37,8 @@

Here we calculate the 83% HDI for a normal random variable:

```jldoctest; setup = :(using Random; Random.seed!(78))
```jldoctest hdi; setup = :(using Random; Random.seed!(78))

[CI annotation (GitHub Actions / Documentation): doctest failure at hdi.jl:40-48; the evaluated output `(lower = -1.3826605224220527, upper = 1.259817149822839,)` prints a trailing comma that the expected output omits.]
using ArviZ
x = randn(2_000)
hdi(x; prob=0.83)

@@ -48,7 +49,7 @@

We can also calculate the HDI for a 3-dimensional array of samples:

```jldoctest; setup = :(using Random; Random.seed!(67))
```jldoctest hdi; setup = :(using Random; Random.seed!(67))
x = randn(1_000, 1, 1) .+ reshape(0:5:10, 1, 1, :)
pairs(hdi(x))

@@ -121,8 +122,8 @@
# output

Dataset with dimensions:
Dim{:hdi_bound} Categorical{Symbol} Symbol[:lower, :upper] ForwardOrdered,
Dim{:school} Categorical{String} String[Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
Dim{:school} Categorical{String} String[Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered,
Dim{:hdi_bound} Categorical{Symbol} Symbol[:lower, :upper] ForwardOrdered
and 3 layers:
:mu Float64 dims: Dim{:hdi_bound} (2)
:theta Float64 dims: Dim{:school}, Dim{:hdi_bound} (8×2)
27 changes: 7 additions & 20 deletions src/ArviZStats/loo_pit.jl
@@ -165,12 +165,8 @@
y_pred_name::Union{Symbol,Nothing}=nothing,
kwargs...,
)
_y_name = y_name === nothing ? _only_observed_data_key(idata) : y_name
_y_pred_name = y_pred_name === nothing ? _y_name : y_pred_name
haskey(idata, :posterior_predictive) ||
throw(ArgumentError("No `posterior_predictive` group"))
y = idata.observed_data[_y_name]
y_pred = _draw_chains_params_array(idata.posterior_predictive[_y_pred_name])
(_y_name, y), (_, _y_pred) = observations_and_predictions(idata, y_name, y_pred_name)
y_pred = _draw_chains_params_array(_y_pred)

[Codecov annotation: added lines loo_pit.jl#L168-L169 not covered by tests.]
pitvals = loo_pit(y, y_pred, log_weights; kwargs...)
return DimensionalData.rebuild(pitvals; name=Symbol("loo_pit_$(_y_name)"))
end
@@ -223,11 +219,13 @@
function loo_pit(
idata::InferenceObjects.InferenceData;
y_name::Union{Symbol,Nothing}=nothing,
y_pred_name::Union{Symbol,Nothing}=nothing,
log_likelihood_name::Union{Symbol,Nothing}=nothing,
reff=nothing,
kwargs...,
)
_y_name = y_name === nothing ? _only_observed_data_key(idata) : y_name
(_y_name, y), (_, _y_pred) = observations_and_predictions(idata, y_name, y_pred_name)
y_pred = _draw_chains_params_array(_y_pred)
if log_likelihood_name === nothing
if haskey(idata, :log_likelihood)
_log_like = log_likelihood(idata.log_likelihood, _y_name)
@@ -241,7 +239,8 @@
end
log_like = _draw_chains_params_array(_log_like)
psis_result = _psis_loo_setup(log_like, reff)
return loo_pit(idata, psis_result.log_weights; y_name=_y_name, kwargs...)
pitvals = loo_pit(y, y_pred, psis_result.log_weights; kwargs...)
return DimensionalData.rebuild(pitvals; name=Symbol("loo_pit_$(_y_name)"))
end

function _loo_pit(y::Number, y_pred, log_weights)
@@ -269,15 +268,3 @@
end
return pitvals
end

function _only_observed_data_key(idata::InferenceObjects.InferenceData)
haskey(idata, :observed_data) ||
throw(ArgumentError("No `observed_data` group in `idata`"))
ks = keys(idata.observed_data)
length(ks) == 1 || throw(
ArgumentError(
"More than one observed data variable: $(ks). `y_name` must be provided"
),
)
return first(ks)
end
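
The `InferenceData` method of `loo_pit` above now resolves variable names through `observations_and_predictions` (added to `utils.jl` below) rather than the removed `_only_observed_data_key` helper, and computes the PIT values directly instead of recursing. A minimal smoke-test sketch of that entry point; the `centered_eight` dataset and its `obs` variable name are illustrative assumptions, not part of this diff:

```julia
using ArviZ, ArviZExampleData

# With no keywords, loo_pit infers the observation name (assumed :obs here) from
# `observed_data`, finds the matching variable in `posterior_predictive`, and uses
# the corresponding log-likelihood variable for the PSIS weights.
idata = load_example_data("centered_eight")  # assumed example dataset
pit = loo_pit(idata)
# The result is rebuilt with the name :loo_pit_obs, per the `rebuild` call above.
```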
94 changes: 94 additions & 0 deletions src/ArviZStats/r2_score.jl
@@ -0,0 +1,94 @@
"""
r2_score(y_true::AbstractVector, y_pred::AbstractVecOrMat) -> (; r2, r2_std)

``R²`` for linear Bayesian regression models.[^GelmanGoodrich2019]

# Arguments

- `y_true`: Observed data of length `noutputs`
- `y_pred`: Predicted data with size `(ndraws[, nchains], noutputs)`

[^GelmanGoodrich2019]: Andrew Gelman, Ben Goodrich, Jonah Gabry & Aki Vehtari (2019)
R-squared for Bayesian Regression Models, The American Statistician,
73:3, 307-9,
DOI: [10.1080/00031305.2018.1549100](https://doi.org/10.1080/00031305.2018.1549100).
# Examples

```jldoctest

[CI annotation (GitHub Actions / Documentation): doctest failure at r2_score.jl:17-27; the evaluated output `(r2 = 0.683196996216511, r2_std = 0.036883777654323734,)` prints a trailing comma that the expected output omits.]
using ArviZ, ArviZExampleData
idata = load_example_data("regression1d")
y_true = idata.observed_data.y
y_pred = PermutedDimsArray(idata.posterior_predictive.y, (:draw, :chain, :y_dim_0))
r2_score(y_true, y_pred)

# output

(r2 = 0.683196996216511, r2_std = 0.036883777654323734)
```
"""
function r2_score(y_true, y_pred)
r_squared = r2_samples(y_true, y_pred)
return NamedTuple{(:r2, :r2_std)}(StatsBase.mean_and_std(r_squared; corrected=false))
end

"""
r2_score(idata::InferenceData; y_name, y_pred_name) -> (; r2, r2_std)

Compute ``R²`` from `idata`, automatically formatting the predictions to the correct shape.

# Keywords

- `y_name`: Name of observed data variable in `idata.observed_data`. If not provided, then
the only observed data variable is used.
- `y_pred_name`: Name of posterior predictive variable in `idata.posterior_predictive`.
If not provided, then `y_name` is used.

# Examples

```jldoctest

[CI annotation (GitHub Actions / Documentation): doctest failure at r2_score.jl:48-56; the evaluated output `(r2 = 0.998384805658226, r2_std = 0.00010062063385452256,)` prints a trailing comma that the expected output omits.]
using ArviZ, ArviZExampleData
idata = load_example_data("regression10d")
r2_score(idata)

# output

(r2 = 0.998384805658226, r2_std = 0.00010062063385452256)
```
"""
function r2_score(
idata::InferenceObjects.InferenceData;
y_name::Union{Symbol,Nothing}=nothing,
y_pred_name::Union{Symbol,Nothing}=nothing,
)
(_, y), (_, y_pred) = observations_and_predictions(idata, y_name, y_pred_name)
return r2_score(y, _draw_chains_params_array(y_pred))
end

"""
r2_samples(y_true::AbstractVector, y_pred::AbstractMatrix) -> AbstractVector

``R²`` samples for Bayesian regression models. Only valid for linear models.

See also [`r2_score`](@ref).

# Arguments

- `y_true`: Observed data of length `noutputs`
- `y_pred`: Predicted data with size `(ndraws[, nchains], noutputs)`
"""
function r2_samples(y_true::AbstractVector, y_pred::AbstractArray)
@assert ndims(y_pred) ∈ (2, 3)
corrected = false
dims = ndims(y_pred)

var_y_est = dropdims(Statistics.var(y_pred; corrected, dims); dims)
y_true_reshape = reshape(y_true, ntuple(one, ndims(y_pred) - 1)..., :)
var_residual = dropdims(Statistics.var(y_pred .- y_true_reshape; corrected, dims); dims)

# allocate storage for type-stability
T = typeof(first(var_y_est) / first(var_residual))
sample_axes = ntuple(Base.Fix1(axes, y_pred), ndims(y_pred) - 1)
r_squared = similar(y_pred, T, sample_axes)
r_squared .= var_y_est ./ (var_y_est .+ var_residual)
return r_squared
end
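
For reference, the vectorized code in `r2_samples` computes the per-draw Bayesian ``R²`` of Gelman et al. (2019): for each draw, the ratio of the predictive variance to the predictive-plus-residual variance, with uncorrected variances. A loop-based sketch of the same quantity, purely illustrative (the `r2_samples_naive` name and the `(ndraws, noutputs)` matrix shape are assumptions, not part of the diff):

```julia
using Statistics

# Per-draw Bayesian R²: R²_s = var(ŷ_s) / (var(ŷ_s) + var(ŷ_s - y)),
# using population variances to match `corrected = false` above.
function r2_samples_naive(y_true::AbstractVector, y_pred::AbstractMatrix)
    return map(axes(y_pred, 1)) do s
        ŷ = y_pred[s, :]                             # predictions for draw s
        var_fit = var(ŷ; corrected=false)            # variance of the fit
        var_res = var(ŷ .- y_true; corrected=false)  # variance of the residuals
        var_fit / (var_fit + var_res)
    end
end
```

`r2_score(y_true, y_pred)` then reports the mean and uncorrected standard deviation of these draws.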
118 changes: 118 additions & 0 deletions src/ArviZStats/utils.jl
@@ -48,6 +48,124 @@
return nothing
end

function _only_observed_data_key(idata::InferenceObjects.InferenceData; var_name=nothing)
haskey(idata, :observed_data) ||
throw(ArgumentError("Data must contain an `observed_data` group."))
ks = keys(idata.observed_data)
isempty(ks) && throw(ArgumentError("`observed_data` group must not be empty."))
if length(ks) > 1
msg = "More than one observed data variable: $(ks)."
if var_name !== nothing
msg = "$msg `$var_name` must be specified."
end
throw(ArgumentError(msg))
end
return first(ks)
end

# get name of group and group itself most likely to contain posterior predictive draws
function _post_pred_or_post_name_group(idata)
haskey(idata, :posterior_predictive) &&
return :posterior_predictive => idata.posterior_predictive
haskey(idata, :posterior) && return :posterior => idata.posterior
throw(ArgumentError("No `posterior_predictive` or `posterior` group"))
end

"""
observations_and_predictions(data::InferenceData[, y_name[, y_pred_name]])

Get arrays of observations and predictions for the specified variable in `data`.

If `y_name` and/or `y_pred_name` is not provided, then they are inferred from the data.
Generally this requires that either there is a single variable in `observed_data` or that
there is only one variable in `posterior` or `posterior_predictive` that has a matching name
in `observed_data`, optionally with the suffix `_pred`.

The return value has the structure `(y_name => y, y_pred_name => y_pred)`, where `y_name`
and `y_pred_name` are the actual names found.
"""
function observations_and_predictions end
function observations_and_predictions(
idata::InferenceObjects.InferenceData, y_name::Union{Symbol,Nothing}=nothing
)
return observations_and_predictions(idata, y_name, nothing)
end
function observations_and_predictions(
idata::InferenceObjects.InferenceData, y_name::Symbol, y_pred_name::Symbol
)
haskey(idata, :observed_data) ||
throw(ArgumentError("Data must contain `observed_data` group"))
y = idata.observed_data[y_name]
_, post_pred = _post_pred_or_post_name_group(idata)
y_pred = post_pred[y_pred_name]
return (y_name => y, y_pred_name => y_pred)
end
function observations_and_predictions(
idata::InferenceObjects.InferenceData, ::Nothing, y_pred_name::Symbol
)
y_name = _only_observed_data_key(idata; var_name=:y_name)
y = idata.observed_data[y_name]
_, post_pred = _post_pred_or_post_name_group(idata)
y_pred = post_pred[y_pred_name]
return (y_name => y, y_pred_name => y_pred)
end
function observations_and_predictions(
idata::InferenceObjects.InferenceData, y_name::Symbol, ::Nothing
)
haskey(idata, :observed_data) ||
throw(ArgumentError("Data must contain `observed_data` group"))
observed_data = idata.observed_data
y = observed_data[y_name]
post_pred_name, post_pred = _post_pred_or_post_name_group(idata)
y_pred_names = (y_name, Symbol("$(y_name)_pred"))
for y_pred_name in y_pred_names
if haskey(post_pred, y_pred_name)
y_pred = post_pred[y_pred_name]
return (y_name => y, y_pred_name => y_pred)
end
end
throw(

[Codecov annotation: added line utils.jl#L127 not covered by tests.]
ArgumentError(
"Could not find names $y_pred_names in group `$post_pred_name`. `y_pred_name` must be specified.",
),
)
end
function observations_and_predictions(
idata::InferenceObjects.InferenceData, ::Nothing, ::Nothing
)
haskey(idata, :observed_data) ||
throw(ArgumentError("Data must contain `observed_data` group"))
observed_data = idata.observed_data
obs_keys = keys(observed_data)
if length(obs_keys) == 1
y_name = first(obs_keys)
return observations_and_predictions(idata, y_name, nothing)
else
_, post_pred = _post_pred_or_post_name_group(idata)
var_name_pairs = filter(
!isnothing,
map(obs_keys) do k
for k_pred in (k, Symbol("$(k)_pred"))
haskey(post_pred, k_pred) && return (k, k_pred)
end
return nothing
end,
)
if length(var_name_pairs) == 1
y_name, y_pred_name = first(var_name_pairs)
y = observed_data[y_name]
y_pred = post_pred[y_pred_name]
return (y_name => y, y_pred_name => y_pred)
else
throw(
ArgumentError(
"No unique pair of variable names. `y_name` and/or `y_pred_name` must be specified.",
),
)
end
end
end
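
A hedged usage sketch of the name-resolution rules documented above, using the `regression1d` example data already exercised in the `r2_score` doctests (reaching the unexported helper via the `ArviZStats` module name, as the tests below do for `r2_samples`, is an assumption about how it would be called interactively):

```julia
using ArviZ, ArviZExampleData

idata = load_example_data("regression1d")
# No names given: :y is inferred from `observed_data` and a variable with the
# matching name :y is found in `posterior_predictive`.
(y_name, y), (y_pred_name, y_pred) = ArviZStats.observations_and_predictions(idata)
# y_name == :y, y_pred_name == :y; `y` and `y_pred` are the corresponding arrays.
```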

"""
smooth_data(y; dims=:, interp_method=CubicSpline, offset_frac=0.01)

4 changes: 2 additions & 2 deletions test/ArviZStats/loo_pit.jl
@@ -97,7 +97,7 @@ using StatsBase
posterior_predictive=Dataset((; y=y_pred)),
log_likelihood=Dataset((; y=log_like)),
)
@test_throws ArgumentError loo_pit(idata1; y_name=:z)
@test_throws Exception loo_pit(idata1; y_name=:z)
@test_throws Exception loo_pit(idata1; y_pred_name=:z)
@test_throws Exception loo_pit(idata1; log_likelihood_name=:z)
@test loo_pit(idata1) == pit_vals
@@ -114,8 +114,8 @@
@test_throws ArgumentError loo_pit(
idata2; y_name=:z, y_pred_name=:y_pred, log_likelihood_name=:log_like
)
@test_throws Exception loo_pit(idata2; y_name=:y, log_likelihood_name=:log_like)
@test_throws ArgumentError loo_pit(idata2; y_name=:y, y_pred_name=:y_pred)
@test loo_pit(idata2; y_name=:y, log_likelihood_name=:log_like) == pit_vals
@test loo_pit(
idata2; y_name=:y, y_pred_name=:y_pred, log_likelihood_name=:log_like
) == pit_vals
49 changes: 49 additions & 0 deletions test/ArviZStats/r2_score.jl
@@ -0,0 +1,49 @@
using ArviZ
using ArviZExampleData
using GLM
using Statistics
using Test

@testset "r2_score/r2_sample" begin
@testset "basic" begin
n = 100
@testset for T in (Float32, Float64),
sz in (300, (100, 3)),
σ in T.((2, 1, 0.5, 0.1))

x = range(T(0), T(1); length=n)
slope = T(2)
intercept = T(3)
y = @. slope * x + intercept + randn(T) * σ
x_reshape = length(sz) == 1 ? x' : reshape(x, 1, 1, :)
y_pred = slope .* x_reshape .+ intercept .+ randn(T, sz..., n) .* σ

r2_val = @inferred r2_score(y, y_pred)
@test r2_val isa NamedTuple{(:r2, :r2_std),NTuple{2,T}}
r2_draws = @inferred ArviZStats.r2_samples(y, y_pred)
@test r2_val.r2 ≈ mean(r2_draws)
@test r2_val.r2_std ≈ std(r2_draws; corrected=false)

# check rough consistency with GLM
res = lm(@formula(y ~ 1 + x), (; x=Float64.(x), y=Float64.(y)))
@test r2_val.r2 ≈ r2(res) rtol = 1
end
end

@testset "InferenceData inputs" begin
@testset for name in ("regression1d", "regression10d")
idata = load_example_data(name)
VERSION ≥ v"1.9" && @inferred r2_score(idata)
r2_val = r2_score(idata)
@test r2_val == r2_score(
idata.observed_data.y,
PermutedDimsArray(idata.posterior_predictive.y, (:draw, :chain, :y_dim_0)),
)
@test r2_val == r2_score(idata; y_name=:y)
@test r2_val == r2_score(idata; y_pred_name=:y)
@test r2_val == r2_score(idata; y_name=:y, y_pred_name=:y)
@test_throws Exception r2_score(idata; y_name=:z)
@test_throws Exception r2_score(idata; y_pred_name=:z)
end
end
end