-
-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Julia implementation of
r2_score
(#285)
* Add Julia implementation of r2_score * Fix example * Add method for InferenceData argument * Move function to utils * Update error messages * Add smarter heuristic for getting observations and matching predictions * Use smarter heuristic in loo_pit * Use smarter heuristic in r2_score * Rename variable * Remove unused variables * Don't refer to unexported function docstring * Add citation * Use different regression model in example * Update docstrings * Remove duplicated method * Add citation * Add GLM as test dependency * Test r2_score * Fixes for older Julia versions * Fix filename * Fix doctests
- Loading branch information
Showing
10 changed files
with
423 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
""" | ||
r2_score(y_true::AbstractVector, y_pred::AbstractVecOrMat) -> (; r2, r2_std) | ||
``R²`` for linear Bayesian regression models.[^GelmanGoodrich2019] | ||
# Arguments | ||
- `y_true`: Observed data of length `noutputs` | ||
- `y_pred`: Predicted data with size `(ndraws[, nchains], noutputs)` | ||
[^GelmanGoodrich2019]: Andrew Gelman, Ben Goodrich, Jonah Gabry & Aki Vehtari (2019) | ||
R-squared for Bayesian Regression Models, The American Statistician, | ||
73:3, 307-9, | ||
DOI: [10.1080/00031305.2018.1549100](https://doi.org/10.1080/00031305.2018.1549100). | ||
# Examples | ||
```jldoctest | ||
using ArviZ, ArviZExampleData | ||
idata = load_example_data("regression1d") | ||
y_true = idata.observed_data.y | ||
y_pred = PermutedDimsArray(idata.posterior_predictive.y, (:draw, :chain, :y_dim_0)) | ||
r2_score(y_true, y_pred) | ||
# output | ||
(r2 = 0.683196996216511, r2_std = 0.036883777654323734) | ||
``` | ||
""" | ||
function r2_score(y_true, y_pred) | ||
r_squared = r2_samples(y_true, y_pred) | ||
return NamedTuple{(:r2, :r2_std)}(StatsBase.mean_and_std(r_squared; corrected=false)) | ||
end | ||
|
||
""" | ||
r2_score(idata::InferenceData; y_name, y_pred_name) -> (; r2, r2_std) | ||
Compute ``R²`` from `idata`, automatically formatting the predictions to the correct shape. | ||
# Keywords | ||
- `y_name`: Name of observed data variable in `idata.observed_data`. If not provided, then | ||
the only observed data variable is used. | ||
- `y_pred_name`: Name of posterior predictive variable in `idata.posterior_predictive`. | ||
If not provided, then `y_name` is used. | ||
# Examples | ||
```jldoctest | ||
using ArviZ, ArviZExampleData | ||
idata = load_example_data("regression10d") | ||
r2_score(idata) | ||
# output | ||
(r2 = 0.998384805658226, r2_std = 0.00010062063385452256) | ||
``` | ||
""" | ||
function r2_score( | ||
idata::InferenceObjects.InferenceData; | ||
y_name::Union{Symbol,Nothing}=nothing, | ||
y_pred_name::Union{Symbol,Nothing}=nothing, | ||
) | ||
(_, y), (_, y_pred) = observations_and_predictions(idata, y_name, y_pred_name) | ||
return r2_score(y, _draw_chains_params_array(y_pred)) | ||
end | ||
|
||
""" | ||
r2_samples(y_true::AbstractVector, y_pred::AbstractMatrix) -> AbstractVector | ||
``R²`` samples for Bayesian regression models. Only valid for linear models. | ||
See also [`r2_score`](@ref). | ||
# Arguments | ||
- `y_true`: Observed data of length `noutputs` | ||
- `y_pred`: Predicted data with size `(ndraws[, nchains], noutputs)` | ||
""" | ||
function r2_samples(y_true::AbstractVector, y_pred::AbstractArray) | ||
@assert ndims(y_pred) ∈ (2, 3) | ||
corrected = false | ||
dims = ndims(y_pred) | ||
|
||
var_y_est = dropdims(Statistics.var(y_pred; corrected, dims); dims) | ||
y_true_reshape = reshape(y_true, ntuple(one, ndims(y_pred) - 1)..., :) | ||
var_residual = dropdims(Statistics.var(y_pred .- y_true_reshape; corrected, dims); dims) | ||
|
||
# allocate storage for type-stability | ||
T = typeof(first(var_y_est) / first(var_residual)) | ||
sample_axes = ntuple(Base.Fix1(axes, y_pred), ndims(y_pred) - 1) | ||
r_squared = similar(y_pred, T, sample_axes) | ||
r_squared .= var_y_est ./ (var_y_est .+ var_residual) | ||
return r_squared | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
using ArviZ | ||
using ArviZExampleData | ||
using GLM | ||
using Statistics | ||
using Test | ||
|
||
@testset "r2_score/r2_sample" begin | ||
@testset "basic" begin | ||
n = 100 | ||
@testset for T in (Float32, Float64), | ||
sz in (300, (100, 3)), | ||
σ in T.((2, 1, 0.5, 0.1)) | ||
|
||
x = range(T(0), T(1); length=n) | ||
slope = T(2) | ||
intercept = T(3) | ||
y = @. slope * x + intercept + randn(T) * σ | ||
x_reshape = length(sz) == 1 ? x' : reshape(x, 1, 1, :) | ||
y_pred = slope .* x_reshape .+ intercept .+ randn(T, sz..., n) .* σ | ||
|
||
r2_val = @inferred r2_score(y, y_pred) | ||
@test r2_val isa NamedTuple{(:r2, :r2_std),NTuple{2,T}} | ||
r2_draws = @inferred ArviZStats.r2_samples(y, y_pred) | ||
@test r2_val.r2 ≈ mean(r2_draws) | ||
@test r2_val.r2_std ≈ std(r2_draws; corrected=false) | ||
|
||
# check rough consistency with GLM | ||
res = lm(@formula(y ~ 1 + x), (; x=Float64.(x), y=Float64.(y))) | ||
@test r2_val.r2 ≈ r2(res) rtol = 1 | ||
end | ||
end | ||
|
||
@testset "InferenceData inputs" begin | ||
@testset for name in ("regression1d", "regression10d") | ||
idata = load_example_data(name) | ||
VERSION ≥ v"1.9" && @inferred r2_score(idata) | ||
r2_val = r2_score(idata) | ||
@test r2_val == r2_score( | ||
idata.observed_data.y, | ||
PermutedDimsArray(idata.posterior_predictive.y, (:draw, :chain, :y_dim_0)), | ||
) | ||
@test r2_val == r2_score(idata; y_name=:y) | ||
@test r2_val == r2_score(idata; y_pred_name=:y) | ||
@test r2_val == r2_score(idata; y_name=:y, y_pred_name=:y) | ||
@test_throws Exception r2_score(idata; y_name=:z) | ||
@test_throws Exception r2_score(idata; y_pred_name=:z) | ||
end | ||
end | ||
end |
Oops, something went wrong.