Skip to content

Commit

Permalink
centralized model download in new function in whisper.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jblom committed Oct 1, 2024
1 parent 0281720 commit 8144798
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 29 deletions.
1 change: 1 addition & 0 deletions model_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@


# Makes sure the model is available locally; if not, download it from S3, and if that fails, download it from Huggingface.
# FIXME should also check if the correct w_model type is available locally!
def check_model_availability() -> bool:
logger = logging.getLogger(__name__)
if os.path.exists(model_base_dir + "/model.bin"):
Expand Down
36 changes: 25 additions & 11 deletions whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Optional

import faster_whisper
from faster_whisper import WhisperModel
from config import (
model_base_dir,
w_beam_size,
Expand All @@ -24,21 +25,34 @@
logger = logging.getLogger(__name__)


# loads the whisper model
# FIXME does not check if the specific model_type is available locally!
# Load a (faster-)whisper model for the given device, preferring a local copy.
# FIXME does not check if the specific model_type is available locally!
def load_model(model_base_dir: str, model_type: str, device: str) -> WhisperModel:
    """Return a WhisperModel, loaded from model_base_dir when available locally,
    otherwise letting faster-whisper download model_type from HuggingFace."""
    logger.info(f"Loading Whisper model {model_type} for device: {device}")

    # point HuggingFace's cache/home at the dir the model is downloaded to
    os.environ["HF_HOME"] = model_base_dir

    # prefer the locally available model; otherwise pass the model type name
    # (e.g. "large-v2"), which makes Whisper download it from HuggingFace
    if check_model_availability():
        model_location = model_base_dir
    else:
        model_location = model_type

    # float16 only works on GPU; float32 or int8 are recommended for CPU
    compute_type = "float16" if device == "cuda" else "float32"

    whisper_model = WhisperModel(
        model_location,
        device=device,
        compute_type=compute_type,
    )
    logger.info(f"Model loaded from location: {model_location}")
    return whisper_model


def run_asr(input_path, output_dir, model=None) -> Optional[dict]:
logger.info(f"Starting ASR on {input_path}")
start_time = time.time()
if not model:
logger.info(f"Device used: {w_device}")
# checking if model needs to be downloaded from HF or not
model_location = model_base_dir if check_model_availability() else w_model
model = faster_whisper.WhisperModel(
model_location,
device=w_device,
compute_type=( # float16 only works on GPU, float32 or int8 are recommended for CPU
"float16" if w_device == "cuda" else "float32"
),
)
logger.info("Model loaded, now getting segments")
logger.info("Model not passed as param, need to obtain it first")
model = load_model(model_base_dir, w_model, w_device)
logger.info("Processing segments")
segments, _ = model.transcribe(
input_path,
vad_filter=w_vad,
Expand Down
21 changes: 3 additions & 18 deletions whisper_api.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,24 @@
import logging
import os
from typing import Optional
from uuid import uuid4
from fastapi import BackgroundTasks, FastAPI, HTTPException, status, Response
from asr import run
from whisper import load_model
from enum import Enum
from pydantic import BaseModel
from config import (
model_base_dir,
w_device,
w_model,
)
import faster_whisper
from model_download import check_model_availability

logger = logging.getLogger(__name__)
api = FastAPI()

logger.info(f"Loading model on device {w_device}")


# change hugging face home dir where model is downloaded
os.environ["HF_HOME"] = model_base_dir

# checking if model needs to be downloaded from HF or not
model_location = model_base_dir if check_model_availability() else w_model

model = faster_whisper.WhisperModel(
model_location,
device=w_device,
compute_type=( # float16 only works on GPU, float32 or int8 are recommended for CPU
"float16" if w_device == "cuda" else "float32"
),
)
logger.info("Model loaded!")
# load the model in memory on API startup
model = load_model(model_base_dir, w_model, w_device)


class Status(Enum):
Expand Down

0 comments on commit 8144798

Please sign in to comment.