Skip to content

Commit

Permalink
Fix file downloading
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelbenayoun committed Jul 22, 2024
1 parent ba79b3f commit c110ac9
Showing 1 changed file with 54 additions and 21 deletions.
75 changes: 54 additions & 21 deletions optimum/neuron/utils/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,33 +67,66 @@
}


def list_filenames_in_github_repo_directory(
    github_repo_directory_url: str, only_files: bool = False, only_directories: bool = False
) -> List[str]:
    """
    Lists the content of a repository directory on GitHub by scraping the HTML of its web page.

    Args:
        github_repo_directory_url (`str`):
            The URL of the GitHub page displaying the directory.
        only_files (`bool`, defaults to `False`):
            Whether to keep only the entries labelled as files.
        only_directories (`bool`, defaults to `False`):
            Whether to keep only the entries labelled as directories.

    Returns:
        `List[str]`: The unique entry names, in the order in which they first appear on the page. The list is empty
        if the "folders-and-files" table cannot be found in the page.

    Raises:
        `ValueError`: If both `only_files` and `only_directories` are `True`, or if the page could not be fetched
        (HTTP status code different from 200).
    """
    if only_files and only_directories:
        raise ValueError("Either `only_files` or `only_directories` can be set to True.")

    response = requests.get(github_repo_directory_url)

    if response.status_code != 200:
        raise ValueError(f"Could not fetch the content of the page: {github_repo_directory_url}.")

    # Here we use regex instead of beautiful soup to not rely on yet another library.
    table_regex = r"\<table aria-labelledby=\"folders-and-files\".*\<\/table\>"
    filename_column_regex = r"\<div class=\"react-directory-filename-cell\".*?\<\/div>"

    # Entry names may contain hyphens (e.g. `text-classification`). Without `-` in the character class such entries
    # were silently dropped in the filtered modes and truncated at the first hyphen in the unfiltered mode.
    name_regex = r"[\w\.\-]+"
    if only_files:
        filename_regex = rf"\<a .* aria-label=\"({name_regex}), \(File\)\""
    elif only_directories:
        filename_regex = rf"\<a .* aria-label=\"({name_regex}), \(Directory\)\""
    else:
        filename_regex = rf"\<a .* aria-label=\"({name_regex})"

    filenames = []

    table_match = re.search(table_regex, response.text)
    if table_match is not None:
        table_content = table_match.group(0)
        # Each filename cell is expected to hold a link whose `aria-label` starts with the entry name.
        for column in re.finditer(filename_column_regex, table_content):
            match = re.search(filename_regex, column.group(0))
            if match:
                filenames.append(match.group(1))

    # Remove duplicates while preserving the order of first appearance, so that the output is deterministic.
    return list(dict.fromkeys(filenames))


def download_example_script_from_github(task_name: str, target_directory: Path, revision: str = "main") -> Path:
    """
    Downloads the example script associated to `task_name` from the GitHub repo, along with the other files located
    in the same example folder.

    Args:
        task_name (`str`):
            The task name, used as a key in `_TASK_TO_EXAMPLE_SCRIPT` to get the name of the script.
        target_directory (`Path`):
            The local directory in which the downloaded files are written.
        revision (`str`, defaults to `"main"`):
            The revision of the repo to download from.

    Returns:
        `Path`: The local path of the downloaded example script.

    Raises:
        `FileNotFoundError`: If the example script could not be downloaded from any of the example folders.
    """
    # TODO: test that every existing task can be downloaded.
    was_saved = False
    script_name = f"{_TASK_TO_EXAMPLE_SCRIPT[task_name]}.py"
    example_script_path = target_directory
    for folder in _GH_REPO_EXAMPLE_FOLDERS:
        # NOTE(review): the same base URL is used both to list the folder (HTML scraping) and to download the files.
        # Confirm that `_BASE_RAW_FILES_PATH_IN_GH_REPO` points to a page that
        # `list_filenames_in_github_repo_directory` can parse.
        url_folder = f"{_BASE_RAW_FILES_PATH_IN_GH_REPO}/{revision}/examples/{folder}"
        filenames_for_example = list_filenames_in_github_repo_directory(url_folder, only_files=True)
        if script_name not in filenames_for_example:
            continue

        # The folder contains the script: download every file of the folder next to it.
        for filename in filenames_for_example:
            r = requests.get(f"{url_folder}/{filename}")
            if r.status_code != 200:
                # Best effort: a file that cannot be fetched is skipped.
                continue
            local_path = target_directory / filename
            with open(local_path, "w") as fp:
                fp.write(r.text)
            if filename == script_name:
                was_saved = True
                example_script_path = local_path
        if was_saved:
            break
    if not was_saved:
        raise FileNotFoundError(f"Could not find an example script for the task {task_name} on the GitHub repo")

    return example_script_path


Expand Down

0 comments on commit c110ac9

Please sign in to comment.