Skip to content

Commit

Permalink
resolved github issue #74
Browse files Browse the repository at this point in the history
drop models via postprocessing that have zero prior probability
  • Loading branch information
merliseclyde committed Dec 5, 2023
1 parent cba7eee commit 580bf3e
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 1 deletion.
23 changes: 23 additions & 0 deletions R/bas_glm.R
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,29 @@ bas.glm <- function(formula, family = binomial(link = "logit"),

# drop null model if it is present
if (betaprior$family == "Jeffreys" & (min(result$size) == 1)) result <- .drop.null.bas(result)

# github issue #74. drop models with zero prior probability

if (any(result$priorprobs == 0)) {
drop.models = result$priorprobs != 0

result$mle = result$mle[drop.models]
result$mle.se = result$mle.se[drop.models]
result$mse = result$mse[drop.models]
result$which = result$which[drop.models]
result$freq = result$freq[drop.models]
result$shrinkage = result$shrinkage[drop.models]
result$R2 = result$R2[drop.models]
result$logmarg = result$logmarg[drop.models]
result$df = result$df[drop.models]
result$size = result$size[drop.models]
result$Q = result$Q[drop.models]
result$rank = result$rank[drop.models]
result$sampleprobs = result$sampleprobs[drop.models]
result$postprobs = result$postprobs[subset = drop.models]
result$priorprobs = result$priorprobs[subset = drop.models]
result$n.models = length(result$postprobs)
}

if (method == "MCMC") {
result$postprobs.MCMC <- result$freq / sum(result$freq)
Expand Down
21 changes: 21 additions & 0 deletions R/bas_lm.R
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,27 @@ bas.lm <- function(formula,
result$probne0.RN <- result$probne0
result$postprobs.RN <- result$postprobs
result$include.always <- keep

# github issue #74. drop models with zero prior probability

if (any(result$priorprobs == 0)) {
drop.models = result$priorprobs != 0

result$mle = result$mle[drop.models]
result$mle.se = result$mle.se[drop.models]
result$mse = result$mse[drop.models]
result$which = result$which[drop.models]
result$freq = result$freq[drop.models]
result$shrinkage = result$shrinkage[drop.models]
result$R2 = result$R2[drop.models]
result$logmarg = result$logmarg[drop.models]
result$size = result$size[drop.models]
result$rank = result$rank[drop.models]
result$sampleprobs = result$sampleprobs[drop.models]
result$postprobs = result$postprobs[subset = drop.models]
result$priorprobs = result$priorprobs[subset = drop.models]
result$n.models = length(result$postprobs)
}

if (method == "MCMC" || method == "MCMC_new") {
result$n.models <- result$n.Unique
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-model-priors.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ hald_tr_power <- bas.lm(Y ~ .,
data = Hald, prior = "g-prior",
modelprior = tr.power.prior(kappa=2, 2))
expect_equal(1, sum(hald_tr_power$postprobs))
expect_error(expect_equal(0, sum(hald_tr_power$postprobs <= 0.0)))
expect_no_error(expect_equal(0, sum(hald_tr_power$postprobs <= 0.0)))

})

Expand Down
55 changes: 55 additions & 0 deletions tests/testthat/test-predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,59 @@ test_that("se.fit with 1 variable", {
expect_no_error(predict(hald.gprior,
newdata=Hald, estimator="MPM", se.fit=TRUE))

})

# GitHub issue #70 and #74
test_that("bas.lm using truncated priors includes models with prior prob 0", {
data("bodyfat")
bas_mod <- bas.lm(Bodyfat ~.,data = bodyfat[1:14,], method = 'BAS', modelprior = tr.poisson(2, 3))

expect_no_error(predict(bas_mod,newdata = bodyfat[15:20,], se.fit = T, estimator = 'HPM'))
expect_no_error(predict(bas_mod,newdata = bodyfat[15:20,], se.fit = T, estimator = 'BPM'))
expect_no_error(predict(bas_mod,newdata = bodyfat[15:20,], se.fit = T, estimator = 'MPM'))
expect_no_error(predict(bas_mod,newdata = bodyfat[15:20,], se.fit = T, estimator = 'BMA'))

bas_mod <- bas.lm(Bodyfat ~.,data = bodyfat[1:14,], method = 'deterministic', modelprior = tr.poisson(2, 3))

expect_no_error(predict(bas_mod,newdata = bodyfat[15:20,], se.fit = T, estimator = 'HPM'))
expect_no_error(predict(bas_mod,newdata = bodyfat[15:20,], se.fit = T, estimator = 'BPM'))
expect_no_error(predict(bas_mod,newdata = bodyfat[15:20,], se.fit = T, estimator = 'MPM'))
expect_no_error(predict(bas_mod,newdata = bodyfat[15:20,], se.fit = T, estimator = 'BMA'))

bas_mod <- bas.lm(Bodyfat ~.,data = bodyfat[1:14,], method = 'MCMC', modelprior = tr.poisson(2, 3))

expect_no_error(predict(bas_mod,newdata = bodyfat[15:20,], se.fit = T, estimator = 'HPM'))
expect_no_error(predict(bas_mod,newdata = bodyfat[15:20,], se.fit = T, estimator = 'BPM'))
expect_no_error(predict(bas_mod,newdata = bodyfat[15:20,], se.fit = T, estimator = 'MPM'))
expect_no_error(predict(bas_mod,newdata = bodyfat[15:20,], se.fit = T, estimator = 'BMA'))

bas_mod <- bas.lm(Bodyfat ~.,data = bodyfat[1:14,], method = 'MCMC+BAS', modelprior = tr.poisson(2, 3))

expect_no_error(predict(bas_mod,newdata = bodyfat[15:20,], se.fit = T, estimator = 'HPM'))
expect_no_error(predict(bas_mod,newdata = bodyfat[15:20,], se.fit = T, estimator = 'BPM'))
expect_no_error(predict(bas_mod,newdata = bodyfat[15:20,], se.fit = T, estimator = 'MPM'))
expect_no_error(predict(bas_mod,newdata = bodyfat[15:20,], se.fit = T, estimator = 'BMA'))



})

# github issue #74
test_that("bas.glm using truncated priors includes models with prior prob 0", {
data("Pima.tr", package="MASS")
data("Pima.te", package="MASS")
bas_mod <- bas.glm(type ~ ., data = Pima.tr, subset = 1:5, method = 'BAS',
modelprior = tr.poisson(2,2),
betaprior = g.prior(g=as.numeric(nrow(Pima.tr))),
family=binomial())

expect_no_error(sum(bas_mod$postprobs == 1.0))

# github issue ??
# expect_no_error(predict(bas_mod,newdata = Pima.tr[15:20,], se.fit = T, estimator = 'HPM'))
# expect_no_error(predict(bas_mod,newdata = Pima.tr[15:20,], se.fit = T, estimator = 'BPM'))
# expect_no_error(predict(bas_mod,newdata = Pima.tr[15:20,], se.fit = T, estimator = 'MPM'))
# expect_no_error(predict(bas_mod,newdata = Pima.tr[15:20,], se.fit = T, estimator = 'BMA'))

#expect_null(plot(confint(pima_pred)))
})

0 comments on commit 580bf3e

Please sign in to comment.