From ee05745f7339f097622c5d7331c5879128d15deb Mon Sep 17 00:00:00 2001 From: Daniel Falbel Date: Wed, 30 Aug 2023 11:54:15 -0300 Subject: [PATCH] compare to official llama2 --- tests/testthat/test-llama.R | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/testthat/test-llama.R b/tests/testthat/test-llama.R index a33cf1c..c1982ad 100644 --- a/tests/testthat/test-llama.R +++ b/tests/testthat/test-llama.R @@ -1,6 +1,6 @@ test_that("Can create a llama model", { skip_on_ci() # this is too big for the github runners. - model <- llama_from_pretrained("huggyllama/llama-7b") + model <- llama_from_pretrained("meta-llama/Llama-2-7b-chat-hf") model$to(dtype=torch_float32()) model$eval() with_no_grad({ @@ -8,6 +8,7 @@ test_that("Can create a llama model", { }) out <- pred[1, 1, 1:5] - reference <- c(-12.7782, -28.6373, 0.9082, -6.1501, -4.3769) - expect_equal(as.numeric(out), reference, tolerance = 1e-5) + + reference <- c(0.2226, 0.0299, 0.2729, -0.7919, 1.6164) + expect_equal(as.numeric(out), reference, tolerance = 1e-4) })