Skip to content

Commit

Permalink
Merge pull request #5 from mlverse/auth
Browse files Browse the repository at this point in the history
Auth
  • Loading branch information
dfalbel authored Aug 18, 2023
2 parents c9cebdc + 23e33ba commit 0130a60
Show file tree
Hide file tree
Showing 9 changed files with 79 additions and 4 deletions.
1 change: 1 addition & 0 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ jobs:
env:
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
R_KEEP_PKG_SOURCE: yes
HUGGINGFACE_HUB_TOKEN: ${{ secrets.HUGGINGFACE_HUB_TOKEN }}

steps:
- uses: actions/checkout@v3
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@

* Added a `NEWS.md` file to track changes to the package.
* Added `hub_snapshot` to alllow downloading an entire repository at once (#2).
* Added support for authentication using `HUGGING_FACE_HUB_TOKEN`. (#5)
21 changes: 20 additions & 1 deletion R/hub_download.R
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ hub_download <- function(repo_id, filename, ..., revision = "main", repo_type =
TRUE
}
handle <- curl::new_handle(noprogress = FALSE, progressfunction = progress)
curl::handle_setheaders(handle, .list = hub_headers())
curl::curl_download(url, tmp, handle = handle, quiet = FALSE)
cli::cli_progress_done(id = bar_id)
}, error = function(err) {
Expand Down Expand Up @@ -184,12 +185,30 @@ repo_folder_name <- function(repo_id, repo_type = "model") {
glue::glue("{repo_type}s{REPO_ID_SEPARATOR()}{repo_id}")
}

hub_headers <- function() {
headers <- c("user-agent" = "hfhub/0.0.1")

token <- Sys.getenv("HUGGING_FACE_HUB_TOKEN", unset = "")
if (!nzchar(token))
token <- Sys.getenv("HUGGINGFACE_HUB_TOKEN", unset = "")

if (nzchar(token)) {
headers["authorization"] <- paste0("Bearer ", token)
}

headers
}

#' @importFrom rlang %||%
get_file_metadata <- function(url) {

headers <- hub_headers()
headers["Accept-Encoding"] <- "identity"

req <- reqst(httr::HEAD,
url = url,
httr::config(followlocation = FALSE),
httr::add_headers("Accept-Encoding" = "identity", "user-agent" = "hfhub/0.0.1"),
httr::add_headers(.headers = headers),
follow_relative_redirects = TRUE
)
list(
Expand Down
6 changes: 3 additions & 3 deletions R/hub_info.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ hub_repo_info <- function(repo_id, ..., repo_type = NULL, revision = NULL, files
params$blobs <- TRUE
}

headers <- hub_headers()

results <- httr::GET(
path,
query = params,
httr::add_headers(
"user-agent" = "hfhub/0.0.1"
)
httr::add_headers(.headers = headers)
)

httr::content(results)
Expand Down
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,9 @@ library(hfhub)
path <- hub_download("gpt2", "config.json")
str(jsonlite::fromJSON(path))
```

## Authentication

You can set the `HUGGING_FACE_HUB_TOKEN` environment variable with the value
of a token obtained [here](https://huggingface.co/settings/tokens). This will
allow you to download private files from Hugging Face Hub.
9 changes: 9 additions & 0 deletions tests/testthat/helper-skips.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
skip_if_no_token <- function() {
token <- Sys.getenv("HUGGINGFACE_HUB_TOKEN", "")
if (token == "") {
token <- Sys.getenv("HUGGING_FACE_HUB_TOKEN", "")
}

if (token == "")
skip("No auth token set.")
}
22 changes: 22 additions & 0 deletions tests/testthat/test-hub_download.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,25 @@ test_that("hub_download", {
})
expect_equal(list.files(tmp), "models--gpt2")
})

test_that("can download from private repo", {

skip_if_no_token()

expect_error(regexp = NA, {
hub_download(
repo_id = "dfalbel/test-hfhub",
filename = ".gitattributes",
force_download = TRUE
)
})

expect_error(regexp = NA, {
hub_download(
repo_id = "dfalbel/test-hfhub",
filename = "hello.safetensors",
force_download = TRUE
)
})

})
7 changes: 7 additions & 0 deletions tests/testthat/test-hub_info.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,10 @@ test_that("dataset info", {
expect_equal(info$author, "dfalbel")
expect_true(length(info$siblings) >= 13)
})

test_that("can get ifo for private repositories", {
skip_if_no_token()

info <- hub_dataset_info("dfalbel/test-hfhub-dataset")
expect_equal(info$author, "dfalbel")
})
10 changes: 10 additions & 0 deletions tests/testthat/test-hub_snapshot.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,13 @@ test_that("snapshot", {

expect_true(length(fs::dir_ls(p)) >= 4)
})

test_that("can snapshot private repositories", {

skip_if_no_token()

expect_error(regexp=NA, {
hub_snapshot("dfalbel/test-hfhub", repo_type = "model", force_download = TRUE)
})

})

0 comments on commit 0130a60

Please sign in to comment.