diff --git a/src/model_fetcher.py b/src/model_fetcher.py index d3cba28..86493e1 100644 --- a/src/model_fetcher.py +++ b/src/model_fetcher.py @@ -1,18 +1,23 @@ from __future__ import annotations from pathlib import Path - +from typing import TypedDict, Dict import requests +class ModelInfo(TypedDict): + model_name: str + url: str -def download_model(model_name, url): +def download_model(model_info: ModelInfo) -> None: """ Download a model file if it doesn't already exist. Args: - model_name (str): The name of the model file. - url (str): The URL of the model file. + model_info (ModelInfo): The model information including name and URL. """ + model_name = model_info['model_name'] + url = model_info['url'] + # Define the local directory to store the model files LOCAL_MODEL_DIRECTORY = Path('models/pt/') @@ -23,9 +28,7 @@ def download_model(model_name, url): # Check if the model already exists if local_file_path.exists(): - print( - f"'{model_name}' exists. Skipping download.", - ) + print(f"'{model_name}' exists. Skipping download.") return # Send an HTTP GET request to fetch the model file @@ -35,28 +38,24 @@ def download_model(model_name, url): with open(local_file_path, 'wb') as f: for chunk in response.iter_content(chunk_size=8192): f.write(chunk) - print( - f"'{model_name}' saved to '{local_file_path}'.", - ) + print(f"'{model_name}' saved to '{local_file_path}'.") else: - print( - f"Error downloading '{model_name}': {response.status_code}", - ) + print(f"Error downloading '{model_name}': {response.status_code}") - -def main(): +def main() -> None: # Define the URLs for the model files - MODEL_URLS = { - 'best_yolov8l.pt': - 'http://changdar-server.mooo.com:28000/models/best_yolov8l.pt', - 'best_yolov8x.pt': - 'http://changdar-server.mooo.com:28000/models/best_yolov8x.pt', + MODEL_URLS: Dict[str, str] = { + 'best_yolov8l.pt': 'http://changdar-server.mooo.com:28000/models/best_yolov8l.pt', + 'best_yolov8x.pt': 'http://changdar-server.mooo.com:28000/models/best_yolov8x.pt', } # Iterate over all models and download them if they don't already exist for model_name, url in MODEL_URLS.items(): - download_model(model_name, url) - + model_info: ModelInfo = { + 'model_name': model_name, + 'url': url, + } + download_model(model_info) if __name__ == '__main__': main()