Skip to content

Commit

Permalink
try disabling GPU's everywhere
Browse files Browse the repository at this point in the history
  • Loading branch information
dfalbel committed Feb 29, 2024
1 parent 3000b97 commit 1eaf796
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions tests/testthat/test-luz.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@ set.seed(1)

gc() # to see memory available on CI

acc <- luz::accelerator(cpu = TRUE)

luz_fit <- torch::nn_linear %>%
luz::setup(loss = torch::nnf_mse_loss, optimizer = torch::optim_sgd) %>%
luz::set_hparams(in_features = ncol(x_train), out_features = 1) %>%
luz::set_opt_hparams(lr = 0.01) %>%
luz::fit(list(x_train, y_train), verbose = FALSE, dataloader_options = list(batch_size = 5))
luz::fit(list(x_train, y_train), verbose = FALSE, dataloader_options = list(batch_size = 5),
accelerator = acc)

v <- vetiver_model(
luz_fit,
Expand All @@ -29,8 +32,8 @@ test_that("can print a `vetiver`ed luz model", {
})

test_that("can predict a `vetiver`ed luz model", {
v_preds <- predict(v, x_test)$cpu()
l_preds <- predict(luz_fit, x_test)$cpu()
v_preds <- predict(v, x_test, accelerator = acc)$cpu()
l_preds <- predict(luz_fit, x_test, accelerator = acc)$cpu()

expect_equal(as.array(v_preds), as.array(l_preds))
})
Expand Down

0 comments on commit 1eaf796

Please sign in to comment.