Skip to content

Commit

Permalink
Add TypedDict
Browse files Browse the repository at this point in the history
  • Loading branch information
yihong1120 authored Jul 8, 2024
1 parent ee2176f commit 4e96329
Showing 1 changed file with 21 additions and 22 deletions.
43 changes: 21 additions & 22 deletions src/model_fetcher.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
from __future__ import annotations

from pathlib import Path

from typing import TypedDict, Dict
import requests

class ModelInfo(TypedDict):
model_name: str
url: str

def download_model(model_name, url):
def download_model(model_info: ModelInfo) -> None:
"""
Download a model file if it doesn't already exist.
Args:
model_name (str): The name of the model file.
url (str): The URL of the model file.
model_info (ModelInfo): The model information including name and URL.
"""
model_name = model_info['model_name']
url = model_info['url']

# Define the local directory to store the model files
LOCAL_MODEL_DIRECTORY = Path('models/pt/')

Expand All @@ -23,9 +28,7 @@ def download_model(model_name, url):

# Check if the model already exists
if local_file_path.exists():
print(
f"'{model_name}' exists. Skipping download.",
)
print(f"'{model_name}' exists. Skipping download.")
return

# Send an HTTP GET request to fetch the model file
Expand All @@ -35,28 +38,24 @@ def download_model(model_name, url):
with open(local_file_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
print(
f"'{model_name}' saved to '{local_file_path}'.",
)
print(f"'{model_name}' saved to '{local_file_path}'.")
else:
print(
f"Error downloading '{model_name}': {response.status_code}",
)
print(f"Error downloading '{model_name}': {response.status_code}")


def main():
def main() -> None:
# Define the URLs for the model files
MODEL_URLS = {
'best_yolov8l.pt':
'http://changdar-server.mooo.com:28000/models/best_yolov8l.pt',
'best_yolov8x.pt':
'http://changdar-server.mooo.com:28000/models/best_yolov8x.pt',
MODEL_URLS: Dict[str, str] = {
'best_yolov8l.pt': 'http://changdar-server.mooo.com:28000/models/best_yolov8l.pt',
'best_yolov8x.pt': 'http://changdar-server.mooo.com:28000/models/best_yolov8x.pt',
}

# Iterate over all models and download them if they don't already exist
for model_name, url in MODEL_URLS.items():
download_model(model_name, url)

model_info: ModelInfo = {
'model_name': model_name,
'url': url,
}
download_model(model_info)

if __name__ == '__main__':
main()

0 comments on commit 4e96329

Please sign in to comment.