Merge pull request #101 from beeldengeluid/81-prov
Provenance
jblom authored Oct 1, 2024
2 parents cefbf9d + 15a0d12 commit a4e44a2
Showing 8 changed files with 239 additions and 35 deletions.
101 changes: 91 additions & 10 deletions asr.py
@@ -1,8 +1,22 @@
 import logging
 import os
+import time
+import tomli

-from base_util import get_asset_info, asr_output_dir
-from config import s3_endpoint_url, s3_bucket, s3_folder_in_bucket, model_base_dir
+from base_util import get_asset_info, asr_output_dir, save_provenance
+from config import (
+    s3_endpoint_url,
+    s3_bucket,
+    s3_folder_in_bucket,
+    model_base_dir,
+    w_word_timestamps,
+    w_device,
+    w_model,
+    w_beam_size,
+    w_best_of,
+    w_vad,
+)
 from download import download_uri
 from whisper import run_asr, WHISPER_JSON_FILE
 from s3_util import S3Store
@@ -13,48 +27,115 @@
 os.environ["HF_HOME"] = model_base_dir  # change dir where model is downloaded


+def _get_project_meta():
+    with open("pyproject.toml", mode="rb") as pyproject:
+        return tomli.load(pyproject)["tool"]["poetry"]
+
+
+pkg_meta = _get_project_meta()
+version = str(pkg_meta["version"])
+
+
 def run(input_uri: str, output_uri: str, model=None) -> bool:
     logger.info(f"Processing {input_uri} (save to --> {output_uri})")
+    start_time = time.time()
+    prov_steps = []  # track provenance
     # 1. download input
     result = download_uri(input_uri)
     logger.info(result)
     if not result:
         logger.error("Could not obtain input, quitting...")
         return False

+    prov_steps.append(result.provenance)

     input_path = result.file_path
     asset_id, extension = get_asset_info(input_path)
     output_path = asr_output_dir(input_path)

     # 2. check if the input file is suitable for processing any further
-    transcoded_file_path = try_transcode(input_path, asset_id, extension)
-    if not transcoded_file_path:
+    transcode_output = try_transcode(input_path, asset_id, extension)
+    if not transcode_output:
         logger.error("The transcode failed to yield a valid file to continue with")
         return False
     else:
-        input_path = transcoded_file_path
+        input_path = transcode_output.transcoded_file_path
+        prov_steps.append(transcode_output.provenance)

     # 3. run ASR
     if not asr_already_done(output_path):
         logger.info("No Whisper transcript found")
-        run_asr(input_path, output_path, model)
+        whisper_prov = run_asr(input_path, output_path, model)
+        if whisper_prov:
+            prov_steps.append(whisper_prov)
     else:
         logger.info(f"Whisper transcript already present in {output_path}")
+        provenance = {
+            "activity_name": "Whisper transcript already exists",
+            "activity_description": "",
+            "processing_time_ms": "",
+            "start_time_unix": "",
+            "parameters": [],
+            "software_version": "",
+            "input_data": "",
+            "output_data": "",
+            "steps": [],
+        }
+        prov_steps.append(provenance)

     # 4. generate JSON transcript
     if not daan_transcript_already_done(output_path):
         logger.info("No DAAN transcript found")
-        success = generate_daan_transcript(output_path)
-        if not success:
+        daan_prov = generate_daan_transcript(output_path)
+        if daan_prov:
+            prov_steps.append(daan_prov)
+        else:
             logger.warning("Could not generate DAAN transcript")
     else:
         logger.info(f"DAAN transcript already present in {output_path}")
+        provenance = {
+            "activity_name": "DAAN transcript already exists",
+            "activity_description": "",
+            "processing_time_ms": "",
+            "start_time_unix": "",
+            "parameters": [],
+            "software_version": "",
+            "input_data": "",
+            "output_data": "",
+            "steps": [],
+        }
+        prov_steps.append(provenance)

+    end_time = (time.time() - start_time) * 1000
+    final_prov = {
+        "activity_name": "Whisper ASR Worker",
+        "activity_description": "Worker that gets a video/audio file as input and outputs JSON transcripts in various formats",
+        "processing_time_ms": end_time,
+        "start_time_unix": start_time,
+        "parameters": {
+            "word_timestamps": w_word_timestamps,
+            "device": w_device,
+            "vad": w_vad,
+            "model": w_model,
+            "beam_size": w_beam_size,
+            "best_of": w_best_of,
+        },
+        "software_version": version,
+        "input_data": input_uri,
+        "output_data": output_uri if output_uri else output_path,
+        "steps": prov_steps,
+    }
+
+    prov_success = save_provenance(final_prov, output_path)
+    if not prov_success:
+        logger.warning("Could not save the provenance")
+
     # 5. transfer output
     if output_uri:
         transfer_asr_output(output_path, asset_id)
     else:
         logger.info("No output_uri specified, so all is done")

     return True


@@ -90,14 +171,14 @@ def transfer_asr_output(output_path: str, asset_id: str) -> bool:


 # check if there is a whisper-transcript.json
-def asr_already_done(output_dir):
+def asr_already_done(output_dir) -> bool:
     whisper_transcript = os.path.join(output_dir, WHISPER_JSON_FILE)
     logger.info(f"Checking existence of {whisper_transcript}")
     return os.path.exists(os.path.join(output_dir, WHISPER_JSON_FILE))


 # check if there is a daan-es-transcript.json
-def daan_transcript_already_done(output_dir):
+def daan_transcript_already_done(output_dir) -> bool:
     daan_transcript = os.path.join(output_dir, DAAN_JSON_FILE)
     logger.info(f"Checking existence of {daan_transcript}")
     return os.path.exists(os.path.join(output_dir, DAAN_JSON_FILE))
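
Note: the two "transcript already exists" branches in run() above build the same mostly-empty provenance dict inline. A small helper along these lines could keep those steps uniform (a sketch only; the helper name is hypothetical and not part of this commit):

# Hypothetical helper (not in this commit): builds the mostly-empty
# provenance step appended when an output already exists.
def skipped_step_provenance(activity_name: str) -> dict:
    return {
        "activity_name": activity_name,
        "activity_description": "",
        "processing_time_ms": "",
        "start_time_unix": "",
        "parameters": [],
        "software_version": "",
        "input_data": "",
        "output_data": "",
        "steps": [],
    }

# e.g. in run():
#     prov_steps.append(skipped_step_provenance("Whisper transcript already exists"))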
18 changes: 18 additions & 0 deletions base_util.py
@@ -1,11 +1,13 @@
 import logging
 import os
 import subprocess
+import json
 from typing import Tuple
 from config import data_base_dir


 LOG_FORMAT = "%(asctime)s|%(levelname)s|%(process)d|%(module)s|%(funcName)s|%(lineno)d|%(message)s"
+PROVENANCE_JSON_FILE = "provenance.json"
 logger = logging.getLogger(__name__)


@@ -55,3 +57,19 @@ def run_shell_command(cmd: str) -> bool:
     except Exception:
         logger.exception("Exception")
         return False


+def save_provenance(provenance: dict, asr_output_dir: str) -> bool:
+    logger.info(f"Saving provenance to: {asr_output_dir}")
+    try:
+        # write provenance.json
+        with open(
+            os.path.join(asr_output_dir, PROVENANCE_JSON_FILE), "w+", encoding="utf-8"
+        ) as f:
+            logger.info(provenance)
+            json.dump(provenance, f, ensure_ascii=False, indent=4)
+    except EnvironmentError as e:  # OSError or IOError...
+        logger.exception(os.strerror(e.errno))
+        return False
+
+    return True
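
A minimal usage sketch of the new save_provenance (directory and values are illustrative only, assuming the output directory already exists):

# Illustrative only: writes <dir>/provenance.json, returns False on I/O errors.
final_prov = {
    "activity_name": "Whisper ASR Worker",
    "processing_time_ms": 1234.5,
    "steps": [],
}
if not save_provenance(final_prov, "/data/output/whisper-test"):
    logger.warning("Could not save the provenance")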
22 changes: 18 additions & 4 deletions daan_transcript.py
@@ -1,6 +1,7 @@
 import json
 import logging
 import os
+import time
 from typing import TypedDict, List, Optional
 from whisper import WHISPER_JSON_FILE

@@ -19,12 +20,13 @@ class ParsedResult(TypedDict):


 # asr_output_dir e.g /data/output/whisper-test/
-def generate_daan_transcript(asr_output_dir: str) -> bool:
+def generate_daan_transcript(asr_output_dir: str) -> Optional[dict]:
     logger.info(f"Generating transcript from: {asr_output_dir}")
+    start_time = time.time()
     whisper_transcript = load_whisper_transcript(asr_output_dir)
     if not whisper_transcript:
         logger.error("No whisper_transcript.json found")
-        return False
+        return None

     transcript = parse_whisper_transcript(whisper_transcript)

@@ -37,9 +39,21 @@ def generate_daan_transcript(asr_output_dir: str) -> bool:
         json.dump(transcript, f, ensure_ascii=False, indent=4)
     except EnvironmentError as e:  # OSError or IOError...
         logger.exception(os.strerror(e.errno))
-        return False
+        return None

-    return True
+    end_time = (time.time() - start_time) * 1000
+    provenance = {
+        "activity_name": "Whisper transcript -> DAAN transcript",
+        "activity_description": "Converts the output of Whisper to the DAAN index format",
+        "processing_time_ms": end_time,
+        "start_time_unix": start_time,
+        "parameters": [],
+        "software_version": "",
+        "input_data": whisper_transcript,
+        "output_data": transcript,
+        "steps": [],
+    }
+    return provenance


def load_whisper_transcript(asr_output_dir: str) -> Optional[dict]:
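
Since generate_daan_transcript now returns Optional[dict] instead of bool, the provenance dict doubles as the success flag; callers are expected to test the return value for None, as the updated asr.py does:

# Caller-side pattern after this change (mirrors run() in asr.py above).
daan_prov = generate_daan_transcript(output_path)
if daan_prov:
    prov_steps.append(daan_prov)
else:
    logger.warning("Could not generate DAAN transcript")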
32 changes: 29 additions & 3 deletions download.py
@@ -17,7 +17,8 @@
 @dataclass
 class DownloadResult:
     file_path: str  # target_file_path, # TODO harmonize with dane-download-worker
-    mime_type: str  # download_data.get("mime_type", "unknown"),
+    mime_type: str
+    provenance: dict
     download_time: float = -1  # time (ms) taken to receive data after request
     content_length: int = -1  # download_data.get("content_length", -1),
@@ -53,8 +54,19 @@ def http_download(url: str) -> Optional[DownloadResult]:
         file.write(response.content)
         file.close()
         download_time = (time.time() - start_time) * 1000  # time in ms
+        provenance = {
+            "activity_name": "Input download",
+            "activity_description": "Downloads the input file from INPUT_URI",
+            "processing_time_ms": download_time,
+            "start_time_unix": start_time,
+            "parameters": [],
+            "software_version": "",
+            "input_data": url,
+            "output_data": input_file,
+            "steps": [],
+        }
         return DownloadResult(
-            input_file, mime_type, download_time  # TODO add content_length
+            input_file, mime_type, provenance, download_time  # TODO add content_length
         )


@@ -89,9 +101,23 @@ def s3_download(s3_uri: str) -> Optional[DownloadResult]:
         if not success:
             logger.error("Failed to download input data from S3")
             return None
+
         download_time = int((time.time() - start_time) * 1000)  # time in ms
     else:
         download_time = -1  # Report back?

+    provenance = {
+        "activity_name": "Input download",
+        "activity_description": "Downloads the input file from INPUT_URI",
+        "processing_time_ms": download_time,
+        "start_time_unix": start_time,
+        "parameters": [],
+        "software_version": "",
+        "input_data": s3_uri,
+        "output_data": input_file,
+        "steps": [],
+    }
+
     return DownloadResult(
-        input_file, mime_type, download_time  # TODO add content_length
+        input_file, mime_type, provenance, download_time  # TODO add content_length
     )
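
http_download and s3_download now build an identical "Input download" provenance dict. A shared builder could remove the duplication (a sketch only; this helper is hypothetical and not part of the commit):

# Hypothetical helper (not in this commit): one provenance builder
# shared by http_download and s3_download.
def download_provenance(
    start_time: float, download_time: float, source: str, input_file: str
) -> dict:
    return {
        "activity_name": "Input download",
        "activity_description": "Downloads the input file from INPUT_URI",
        "processing_time_ms": download_time,
        "start_time_unix": start_time,
        "parameters": [],
        "software_version": "",
        "input_data": source,
        "output_data": input_file,
        "steps": [],
    }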
33 changes: 22 additions & 11 deletions poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -14,6 +14,7 @@ validators = "^0.33.0"
 boto3 = "^1.35.10"
 fastapi = "^0.115.0"
 uvicorn = "^0.30.6"
+tomli = "^2.0.1"

 [tool.poetry.group.dev.dependencies]
 moto = "^5.0.13"
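
tomli, added here for the new _get_project_meta() in asr.py, is the third-party backport of the tomllib parser that entered the standard library in Python 3.11. If the worker ever targets 3.11+ only, a common pattern is to prefer the stdlib module (a sketch, not what this commit does):

# Sketch only: use the stdlib TOML parser where available,
# falling back to the tomli backport on Python < 3.11.
try:
    import tomllib  # standard library since Python 3.11
except ModuleNotFoundError:
    import tomli as tomllib  # API-compatible backport

with open("pyproject.toml", mode="rb") as f:
    version = tomllib.load(f)["tool"]["poetry"]["version"]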