Skip to content

Commit

Permalink
Merge pull request #171 from wilhelm-lab/patch/retry_mechanism
Browse files Browse the repository at this point in the history
Patch/retry mechanism
  • Loading branch information
picciama authored Dec 21, 2023
2 parents 60ab7ad + 3300f37 commit b6c81cd
Show file tree
Hide file tree
Showing 12 changed files with 76 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .cookietemple.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ full_name: Victor Giurcoiu
email: [email protected]
project_name: oktoberfest
project_short_description: Public repo oktoberfest
version: 0.5.2
version: 0.5.3
license: MIT
4 changes: 2 additions & 2 deletions .github/release-drafter.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name-template: "0.5.2 🌈" # <<COOKIETEMPLE_FORCE_BUMP>>
tag-template: 0.5.2 # <<COOKIETEMPLE_FORCE_BUMP>>
name-template: "0.5.3 🌈" # <<COOKIETEMPLE_FORCE_BUMP>>
tag-template: 0.5.3 # <<COOKIETEMPLE_FORCE_BUMP>>
exclude-labels:
- "skip-changelog"

Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/sync_project.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ jobs:

- uses: oleksiyrudenko/[email protected]
with:
name: "victorgiurcoiu"
email: "victor.giurcoiu@tum.de"
actor: "victorgiurcoiu"
name: "Mario Picciani"
email: "mario.picciani@tum.de"
actor: "picciama"
token: "${{ secrets.CT_SYNC_TOKEN}}"

- name: Sync project
Expand Down
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2022, Victor Giurcoiu
Copyright (c) 2023, Wilhelmlab at Technical University of Munich

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
2 changes: 1 addition & 1 deletion cookietemple.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.5.2
current_version = 0.5.3

[bumpversion_files_whitelisted]
init_file = oktoberfest/__init__.py
Expand Down
4 changes: 2 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@
# the built documents.
#
# The short X.Y version.
version = "0.5.2"
version = "0.5.3"
# The full version, including alpha/beta/rc tags.
release = "0.5.2"
release = "0.5.3"

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
Expand Down
2 changes: 1 addition & 1 deletion oktoberfest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
__author__ = """The Oktoberfest development team (Wilhelmlab at Technical University of Munich)"""
__copyright__ = f"Copyright {datetime.now():%Y}, Wilhelmlab at Technical University of Munich"
__license__ = "MIT"
__version__ = "0.5.2"
__version__ = "0.5.3"

import logging.handlers
import sys
Expand Down
25 changes: 13 additions & 12 deletions oktoberfest/predict/koina.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def __async_callback(
self,
infer_results: Dict[int, Union[InferResult, InferenceServerException]],
request_id: int,
result: InferResult,
result: Optional[InferResult],
error: Optional[InferenceServerException],
):
"""
Expand Down Expand Up @@ -408,7 +408,13 @@ def __async_predict_batch(
batch_outputs = self.__get_batch_outputs(self.model_outputs.keys())
batch_inputs = self.__get_batch_inputs(data)

for _ in range(retries):
for i in range(retries):
if i > 0: # need to yield first, before doing sth, but only after first time
yield
if isinstance(infer_results.get(request_id), InferResult):
break
del infer_results[request_id] # avoid race condition in case inference is slower than tqdm loop

self.client.async_infer(
model_name=self.model_name,
request_id=str(request_id),
Expand All @@ -417,9 +423,6 @@ def __async_predict_batch(
outputs=batch_outputs,
client_timeout=timeout,
)
yield
if isinstance(infer_results.get(request_id), InferResult):
break

def predict(
self,
Expand Down Expand Up @@ -492,23 +495,21 @@ def __predict_async(
n_tasks = i + 1
with tqdm(total=n_tasks, desc="Getting predictions", disable=disable_progress_bar) as pbar:
unfinished_tasks = [i for i in range(n_tasks)]
while pbar.n != n_tasks:
while pbar.n < n_tasks:
time.sleep(0.2)
new_unfinished_tasks = []
for j in unfinished_tasks:
result = infer_results.get(j)
if result is None:
new_unfinished_tasks.append(j)
continue
if isinstance(result, InferenceServerException):
elif isinstance(result, InferResult):
pbar.n += 1
else: # unexpected result / exception -> try again
try:
new_unfinished_tasks.append(j)
next(tasks[j])
new_unfinished_tasks.append(j)
except StopIteration:
pbar.n += 1
continue
if isinstance(result, InferResult):
pbar.n += 1

unfinished_tasks = new_unfinished_tasks
pbar.refresh()
Expand Down
7 changes: 7 additions & 0 deletions oktoberfest/preprocessing/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ def list_spectra(input_dir: Union[str, Path], file_format: str) -> List[Path]:
:param file_format: Format of spectra files that match the file extension (case-insensitive), can be "mzML", "RAW" or "pkl".
:raises NotADirectoryError: if the specified input directory does not exist
:raises ValueError: if the specified file format is not supported
:raises AssertionError: if no files in the provided input directory match the provided file format
:return: A list of paths to all spectra files found in the given directory
"""
if isinstance(input_dir, str):
Expand All @@ -264,6 +265,12 @@ def list_spectra(input_dir: Union[str, Path], file_format: str) -> List[Path]:
else:
raise NotADirectoryError(f"{input_dir} does not exist.")

if not raw_files:
raise AssertionError(
f"There are no spectra files with the extension {file_format.lower()} in the provided input_dir {input_dir}. "
"Please check."
)

return raw_files


Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
[tool.poetry]
name = "oktoberfest"
version = "0.5.2" # <<COOKIETEMPLE_FORCE_BUMP>>
version = "0.5.3" # <<COOKIETEMPLE_FORCE_BUMP>>
description = "Public repo oktoberfest"
authors = ["Victor Giurcoiu <[email protected]>"]
authors = ["Wilhelmlab at Technical University of Munich"]
license = "MIT"
readme = "README.rst"
homepage = "https://github.com/wilhelm-lab/oktoberfest"
Expand Down
28 changes: 28 additions & 0 deletions tests/unit_tests/test_pp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import unittest
from pathlib import Path

from oktoberfest import pp


class TestProcessing(unittest.TestCase):
"""Test class for preprocessing functions."""

def test_list_spectra(self):
"""Test listing of spectra with expected user input."""
spectra_path = Path(__file__).parent
spectra_file = spectra_path / "test.mzml"
spectra_file.open("w").close()
self.assertEqual([spectra_path / "test.mzml"], pp.list_spectra(spectra_path, file_format="mzml"))
spectra_file.unlink()

def test_list_spectra_with_empty_string_folder(self):
"""Test listing spectra in a string folder without matching files."""
self.assertRaises(AssertionError, pp.list_spectra, str(Path(__file__).parent), "raw")

def test_list_spectra_with_wrong_folder(self):
"""Test listing spectra in a folder that does not exist."""
self.assertRaises(NotADirectoryError, pp.list_spectra, Path(__file__).parent / "noexist", "raw")

def test_list_spectra_with_wrong_format(self):
"""Test listing spectra with a format that isn't allowed."""
self.assertRaises(ValueError, pp.list_spectra, Path(__file__).parent, "mzm")
15 changes: 15 additions & 0 deletions tests/unit_tests/test_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,18 @@ def test_prosit_tmt(self):
expected_df["PREDICTED_IRT"] = expected_df["PREDICTED_IRT"].astype(library.spectra_data["PREDICTED_IRT"].dtype)

pd.testing.assert_frame_equal(library.spectra_data, expected_df)

def test_failing_koina(self):
"""Test koina with input data that does not fit to the model to trigger exception handling."""
library = Spectra.from_csv(Path(__file__).parent / "data" / "predictions" / "library_input.csv")
input_data = library.spectra_data

self.assertRaises(
Exception,
predict,
input_data,
model_name="Prosit_2020_intensity_HCD",
server_url="koina.proteomicsdb.org:443",
ssl=True,
targets=["intensities", "annotation"],
)

0 comments on commit b6c81cd

Please sign in to comment.