From f5f4b6857065cc5a7048124e7670c451588dc33a Mon Sep 17 00:00:00 2001 From: Daniel Falbel Date: Fri, 18 Aug 2023 12:08:33 -0300 Subject: [PATCH 1/4] add support for authentication using a token --- R/hub_download.R | 21 ++++++++++++++++++++- R/hub_info.R | 6 +++--- tests/testthat/helper-skips.R | 9 +++++++++ tests/testthat/test-hub_download.R | 22 ++++++++++++++++++++++ tests/testthat/test-hub_info.R | 7 +++++++ tests/testthat/test-hub_snapshot.R | 10 ++++++++++ 6 files changed, 71 insertions(+), 4 deletions(-) create mode 100644 tests/testthat/helper-skips.R diff --git a/R/hub_download.R b/R/hub_download.R index d5d974c..b67420f 100644 --- a/R/hub_download.R +++ b/R/hub_download.R @@ -149,6 +149,7 @@ hub_download <- function(repo_id, filename, ..., revision = "main", repo_type = TRUE } handle <- curl::new_handle(noprogress = FALSE, progressfunction = progress) + curl::handle_setheaders(handle, .list = hub_headers()) curl::curl_download(url, tmp, handle = handle, quiet = FALSE) cli::cli_progress_done(id = bar_id) }, error = function(err) { @@ -184,12 +185,30 @@ repo_folder_name <- function(repo_id, repo_type = "model") { glue::glue("{repo_type}s{REPO_ID_SEPARATOR()}{repo_id}") } +hub_headers <- function() { + headers <- c("user-agent" = "hfhub/0.0.1") + + token <- Sys.getenv("HUGGING_FACE_HUB_TOKEN", unset = "") + if (!nzchar(token)) + token <- Sys.getenv("HUGGINGFACE_HUB_TOKEN", unset = "") + + if (nzchar(token)) { + headers["authorization"] <- paste0("Bearer ", token) + } + + headers +} + #' @importFrom rlang %||% get_file_metadata <- function(url) { + + headers <- hub_headers() + headers["Accept-Encoding"] <- "identity" + req <- reqst(httr::HEAD, url = url, httr::config(followlocation = FALSE), - httr::add_headers("Accept-Encoding" = "identity", "user-agent" = "hfhub/0.0.1"), + httr::add_headers(.headers = headers), follow_relative_redirects = TRUE ) list( diff --git a/R/hub_info.R b/R/hub_info.R index 4fcab6a..439878c 100644 --- a/R/hub_info.R +++ 
b/R/hub_info.R @@ -19,12 +19,12 @@ hub_repo_info <- function(repo_id, ..., repo_type = NULL, revision = NULL, files params$blobs <- TRUE } + headers <- hub_headers() + results <- httr::GET( path, query = params, - httr::add_headers( - "user-agent" = "hfhub/0.0.1" - ) + httr::add_headers(.headers = headers) ) httr::content(results) diff --git a/tests/testthat/helper-skips.R b/tests/testthat/helper-skips.R new file mode 100644 index 0000000..1cef51f --- /dev/null +++ b/tests/testthat/helper-skips.R @@ -0,0 +1,9 @@ +skip_if_no_token <- function() { + token <- Sys.getenv("HUGGINGFACE_HUB_TOKEN", "") + if (token == "") { + token <- Sys.getenv("HUGGING_FACE_HUB_TOKEN", "") + } + + if (token == "") + skip("No auth token set.") +} diff --git a/tests/testthat/test-hub_download.R b/tests/testthat/test-hub_download.R index 2f22f27..0486bb6 100644 --- a/tests/testthat/test-hub_download.R +++ b/tests/testthat/test-hub_download.R @@ -27,3 +27,25 @@ test_that("hub_download", { }) expect_equal(list.files(tmp), "models--gpt2") }) + +test_that("can download from private repo", { + + skip_if_no_token() + + expect_error(regexp = NA, { + hub_download( + repo_id = "dfalbel/test-hfhub", + filename = ".gitattributes", + force_download = TRUE + ) + }) + + expect_error(regexp = NA, { + hub_download( + repo_id = "dfalbel/test-hfhub", + filename = "hello.safetensors", + force_download = TRUE + ) + }) + +}) diff --git a/tests/testthat/test-hub_info.R b/tests/testthat/test-hub_info.R index 8d393c1..c727093 100644 --- a/tests/testthat/test-hub_info.R +++ b/tests/testthat/test-hub_info.R @@ -5,3 +5,10 @@ test_that("dataset info", { expect_equal(info$author, "dfalbel") expect_true(length(info$siblings) >= 13) }) + +test_that("can get info for private repositories", { + skip_if_no_token() + + info <- hub_dataset_info("dfalbel/test-hfhub-dataset") + expect_equal(info$author, "dfalbel") +}) diff --git a/tests/testthat/test-hub_snapshot.R b/tests/testthat/test-hub_snapshot.R index 2a817ba..b08b1be 
100644 --- a/tests/testthat/test-hub_snapshot.R +++ b/tests/testthat/test-hub_snapshot.R @@ -10,3 +10,13 @@ test_that("snapshot", { expect_true(length(fs::dir_ls(p)) >= 4) }) + +test_that("can snapshot private repositories", { + + skip_if_no_token() + + expect_error(regexp=NA, { + hub_snapshot("dfalbel/test-hfhub", repo_type = "model", force_download = TRUE) + }) + +}) From d07f99a60ed44a367c087a3af2315db6bafc92a1 Mon Sep 17 00:00:00 2001 From: Daniel Falbel Date: Fri, 18 Aug 2023 12:11:41 -0300 Subject: [PATCH 2/4] set env token --- .github/workflows/R-CMD-check.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml index 2e4e293..fa3a283 100644 --- a/.github/workflows/R-CMD-check.yaml +++ b/.github/workflows/R-CMD-check.yaml @@ -27,6 +27,7 @@ jobs: env: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} R_KEEP_PKG_SOURCE: yes + HUGGINGFACE_HUB_TOKEN: ${{ secrets.HUGGINGFACE_HUB_TOKEN }} steps: - uses: actions/checkout@v3 From fc1897d89ee0063ef97ad9436391a52f86b762cb Mon Sep 17 00:00:00 2001 From: Daniel Falbel Date: Fri, 18 Aug 2023 14:51:09 -0300 Subject: [PATCH 3/4] document env var. --- README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.md b/README.md index 5235c0f..398bfab 100644 --- a/README.md +++ b/README.md @@ -32,3 +32,9 @@ library(hfhub) path <- hub_download("gpt2", "config.json") str(jsonlite::fromJSON(path)) ``` + +## Authentication + +Set the `HUGGING_FACE_HUB_TOKEN` environment variable (`HUGGINGFACE_HUB_TOKEN` is also accepted) +to an access token created on your [token settings page](https://huggingface.co/settings/tokens). +Requests are then authenticated, so you can download files from private repositories on the Hugging Face Hub. From 23e33ba7afffe148611704f4eff6ac2f38e5b2fb Mon Sep 17 00:00:00 2001 From: Daniel Falbel Date: Fri, 18 Aug 2023 14:52:13 -0300 Subject: [PATCH 4/4] NEWS bullet. 
--- NEWS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/NEWS.md b/NEWS.md index cf41c25..cb8a8ab 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,3 +2,4 @@ * Added a `NEWS.md` file to track changes to the package. * Added `hub_snapshot` to alllow downloading an entire repository at once (#2). +* Added support for authentication using `HUGGING_FACE_HUB_TOKEN`. (#5)