From a4f04f4feec21749dbae44dce024828d60577f88 Mon Sep 17 00:00:00 2001 From: Alex Kotlar Date: Thu, 8 Feb 2024 02:54:42 +0000 Subject: [PATCH] Issue #393: Consistent snake_case in python library --- go/beanstalkd/beanstalkd.go | 2 +- perl/bin/bystro-server.pl | 24 +- .../python/bystro/ancestry/ancestry_types.py | 4 +- python/python/bystro/ancestry/inference.py | 2 +- python/python/bystro/ancestry/listener.py | 12 +- .../bystro/ancestry/tests/test_listener.py | 47 +++- python/python/bystro/api/cli.py | 10 +- python/python/bystro/api/tests/test_cli.py | 2 +- python/python/bystro/beanstalkd/messages.py | 18 +- .../bystro/beanstalkd/tests/test_messages.py | 154 ++++++++--- python/python/bystro/beanstalkd/worker.py | 6 +- .../bystro/proteomics/proteomics_listener.py | 4 +- .../tests/test_proteomics_listener.py | 20 +- python/python/bystro/prs/listener.py | 4 +- python/python/bystro/search/index/listener.py | 9 +- python/python/bystro/search/save/listener.py | 4 +- .../python/bystro/search/utils/annotation.py | 66 ++--- python/python/bystro/search/utils/messages.py | 12 +- .../search/utils/tests/test_messages.py | 246 ++++++++++++++++++ 19 files changed, 514 insertions(+), 132 deletions(-) create mode 100644 python/python/bystro/search/utils/tests/test_messages.py diff --git a/go/beanstalkd/beanstalkd.go b/go/beanstalkd/beanstalkd.go index d7dc40f62..12abdf281 100644 --- a/go/beanstalkd/beanstalkd.go +++ b/go/beanstalkd/beanstalkd.go @@ -18,7 +18,7 @@ type ProgressData struct { } type ProgressMessage struct { - SubmissionID string `json:"submissionID"` + SubmissionID string `json:"submissionId"` Data ProgressData `json:"data"` Event string `json:"event"` } diff --git a/perl/bin/bystro-server.pl b/perl/bin/bystro-server.pl index 93fdfa613..f588d5b0f 100755 --- a/perl/bin/bystro-server.pl +++ b/perl/bin/bystro-server.pl @@ -53,11 +53,11 @@ # The properties that we accept from the worker caller my %requiredForAll = ( - output_file_base => 'output_base_path', + output_file_base => 'outputBasePath', assembly => 'assembly', ); -my $requiredForType = { input_file => 'input_file_path' }; +my $requiredForType = { input_file => 'inputFilePath' }; say "Running Annotation queue server"; @@ -125,8 +125,8 @@ data => encode_json( { event => $STARTED, - submissionID => $jobDataHref->{submissionID}, - queueID => $job->id, + submissionId => $jobDataHref->{submissionId}, + queueId => $job->id, } ) } @@ -161,8 +161,8 @@ my $data = { event => $FAILED, reason => $err, - queueID => $job->id, - submissionID => $jobDataHref->{submissionID}, + queueId => $job->id, + submissionId => $jobDataHref->{submissionId}, }; $beanstalkEvents->put( { priority => 0, data => encode_json($data) } ); @@ -179,9 +179,9 @@ my $data = { event => $COMPLETED, - queueID => $job->id, - submissionID => $jobDataHref->{submissionID}, - results => { output_file_names => $outputFileNamesHashRef, } + queueId => $job->id, + submissionId => $jobDataHref->{submissionId}, + results => { outputFileNames => $outputFileNamesHashRef, } }; if ( defined $debug ) { @@ -213,7 +213,7 @@ sub coerceInputs { my %jobSpecificArgs; for my $key ( keys %requiredForAll ) { if ( !defined $jobDetailsHref->{ $requiredForAll{$key} } ) { - $err = "Missing required key: $key in job message"; + $err = "Missing required key: $requiredForAll{$key} in job message"; return ( $err, undef ); } @@ -243,8 +243,8 @@ sub coerceInputs { queue => $queueConfig->{events}, messageBase => { event => $PROGRESS, - queueID => $queueId, - submissionID => $jobDetailsHref->{submissionID}, + queueId => $queueId, + submissionId => $jobDetailsHref->{submissionId}, data => undef, } }, diff --git a/python/python/bystro/ancestry/ancestry_types.py b/python/python/bystro/ancestry/ancestry_types.py index 5bc29b4db..e6b4fca42 100644 --- a/python/python/bystro/ancestry/ancestry_types.py +++ b/python/python/bystro/ancestry/ancestry_types.py @@ -4,7 +4,7 @@ LOWER_UNIT_BOUND = 0.0 UPPER_UNIT_BOUND = 1.0 -class ProbabilityInterval(Struct): +class ProbabilityInterval(Struct, rename="camel"): """Represent an interval of probabilities.""" lower_bound: float @@ -105,7 +105,7 @@ def __post_init__(self): raise TypeError(f"probability must be between {LOWER_UNIT_BOUND} and {UPPER_UNIT_BOUND}") -class AncestryScoresOneSample(Struct, frozen=True): +class AncestryScoresOneSample(Struct, frozen=True, rename="camel"): """An ancestry result for a sample. Represents ancestry model output for an individual study diff --git a/python/python/bystro/ancestry/inference.py b/python/python/bystro/ancestry/inference.py index fb8a354a9..2d0eedfb0 100644 --- a/python/python/bystro/ancestry/inference.py +++ b/python/python/bystro/ancestry/inference.py @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) -class AncestryModel(Struct, frozen=True, forbid_unknown_fields=True): +class AncestryModel(Struct, frozen=True, forbid_unknown_fields=True, rename="camel"): """Bundle together PCA and RFC models for bookkeeping purposes.""" pca_loadings_df: pd.DataFrame diff --git a/python/python/bystro/ancestry/listener.py b/python/python/bystro/ancestry/listener.py index 104623164..a95af2b3b 100644 --- a/python/python/bystro/ancestry/listener.py +++ b/python/python/bystro/ancestry/listener.py @@ -45,13 +45,13 @@ def _get_model_from_s3() -> AncestryModel: return AncestryModel(pca_loadings_df, rfc) -class AncestryJobData(BaseMessage, frozen=True): +class AncestryJobData(BaseMessage, frozen=True, rename="camel"): """ The expected JSON message for the Ancestry job. Parameters ---------- - submissionID: str + submission_id: str The unique identifier for the job. dosage_matrix_path: str The path to the dosage matrix file. @@ -63,7 +63,7 @@ class AncestryJobData(BaseMessage, frozen=True): out_dir: str -class AncestryJobCompleteMessage(CompletedJobMessage, frozen=True, kw_only=True): +class AncestryJobCompleteMessage(CompletedJobMessage, frozen=True, kw_only=True, rename="camel"): """The returned JSON message expected by the API server""" result_path: str @@ -100,7 +100,7 @@ def handler_fn(publisher: ProgressPublisher, job_data: AncestryJobData) -> Ances def submit_msg_fn(ancestry_job_data: AncestryJobData) -> SubmittedJobMessage: """Acknowledge receipt of AncestryJobData.""" logger.debug("entering submit_msg_fn: %s", ancestry_job_data) - return SubmittedJobMessage(ancestry_job_data.submissionID) + return SubmittedJobMessage(ancestry_job_data.submission_id) def completed_msg_fn( @@ -116,7 +116,9 @@ def completed_msg_fn( with open(out_path, "wb") as f: f.write(json_data) - return AncestryJobCompleteMessage(submissionID=ancestry_job_data.submissionID, result_path=out_path) + return AncestryJobCompleteMessage( + submission_id=ancestry_job_data.submission_id, result_path=out_path + ) def main(ancestry_model: AncestryModel, queue_conf: QueueConf) -> None: diff --git a/python/python/bystro/ancestry/tests/test_listener.py b/python/python/bystro/ancestry/tests/test_listener.py index cbbba85f1..825d0efb5 100644 --- a/python/python/bystro/ancestry/tests/test_listener.py +++ b/python/python/bystro/ancestry/tests/test_listener.py @@ -1,3 +1,4 @@ +from msgspec import json import pyarrow.feather as feather # type: ignore from bystro.ancestry.listener import ( @@ -24,7 +25,7 @@ def test_submit_fn(): ancestry_job_data = AncestryJobData( - submissionID="my_submission_id2", + submission_id="my_submission_id2", dosage_matrix_path="some_dosage.feather", out_dir="/path/to/some/dir", ) @@ -39,12 +40,12 @@ def test_handler_fn_happy_path(tmpdir): feather.write_feather(FAKE_GENOTYPES_DOSAGE_MATRIX.to_table(), str(f1)) - progress_message = ProgressMessage(submissionID="my_submission_id") + progress_message = ProgressMessage(submission_id="my_submission_id") publisher = ProgressPublisher( host="127.0.0.1", port=1234, queue="my_queue", message=progress_message ) ancestry_job_data = AncestryJobData( - submissionID="my_submission_id2", dosage_matrix_path=f1, out_dir=str(tmpdir) + submission_id="my_submission_id2", dosage_matrix_path=f1, out_dir=str(tmpdir) ) ancestry_response = handler_fn(publisher, ancestry_job_data) @@ -61,7 +62,7 @@ def test_handler_fn_happy_path(tmpdir): def test_completion_fn(tmpdir): ancestry_job_data = AncestryJobData( - submissionID="my_submission_id2", dosage_matrix_path="some_dosage.feather", out_dir=str(tmpdir) + submission_id="my_submission_id2", dosage_matrix_path="some_dosage.feather", out_dir=str(tmpdir) ) ancestry_results, _ = _infer_ancestry() @@ -69,3 +70,41 @@ def test_completion_fn(tmpdir): completed_msg = completed_msg_fn(ancestry_job_data, ancestry_results) assert isinstance(completed_msg, AncestryJobCompleteMessage) + + +def test_completion_message(): + ancestry_job_data = AncestryJobCompleteMessage( + submission_id="my_submission_id2", result_path="some_dosage.feather" + ) + + serialized_values = json.encode(ancestry_job_data) + expected_value = { + "submissionId": "my_submission_id2", + "event": "completed", + "resultPath": "some_dosage.feather", + } + serialized_expected_value = json.encode(expected_value) + + assert serialized_values == serialized_expected_value + + deserialized_values = json.decode(serialized_expected_value, type=AncestryJobCompleteMessage) + assert deserialized_values == ancestry_job_data + + +def test_job_data_from_beanstalkd(): + ancestry_job_data = AncestryJobData( + submission_id="my_submission_id2", dosage_matrix_path="some_dosage.feather", out_dir="/foo" + ) + + serialized_values = json.encode(ancestry_job_data) + expected_value = { + "submissionId": "my_submission_id2", + "dosageMatrixPath": "some_dosage.feather", + "outDir": "/foo", + } + serialized_expected_value = json.encode(expected_value) + + assert serialized_values == serialized_expected_value + + deserialized_values = json.decode(serialized_expected_value, type=AncestryJobData) + assert deserialized_values == ancestry_job_data diff --git a/python/python/bystro/api/cli.py b/python/python/bystro/api/cli.py index 339c7f737..1f033d5c9 100644 --- a/python/python/bystro/api/cli.py +++ b/python/python/bystro/api/cli.py @@ -20,7 +20,7 @@ } -class SignupResponse(Struct): +class SignupResponse(Struct, rename="camel"): """ The response body for signing up for Bystro. @@ -33,7 +33,7 @@ class SignupResponse(Struct): access_token: str -class LoginResponse(Struct): +class LoginResponse(Struct, rename="camel"): """ The response body for logging in to Bystro. @@ -46,7 +46,7 @@ class LoginResponse(Struct): access_token: str -class CachedAuth(Struct): +class CachedAuth(Struct, rename="camel"): """ The authentication state. @@ -65,7 +65,7 @@ class CachedAuth(Struct): url: str -class JobBasicResponse(Struct): +class JobBasicResponse(Struct, rename="camel"): """ The basic job information, returned in job list commands @@ -84,7 +84,7 @@ class JobBasicResponse(Struct): createdAt: datetime.datetime -class UserProfile(Struct): +class UserProfile(Struct, rename="camel"): """ The response body for fetching the user profile. diff --git a/python/python/bystro/api/tests/test_cli.py b/python/python/bystro/api/tests/test_cli.py index 30512577b..07037f641 100644 --- a/python/python/bystro/api/tests/test_cli.py +++ b/python/python/bystro/api/tests/test_cli.py @@ -78,7 +78,7 @@ "_id": "64db4e68fb86b79cbda4f387", "type": "annotation", "submittedDate": "2023-08-15T10:07:36.027Z", - "queueID": "1538", + "queueId": "1538", "startedDate": "2023-08-15T10:07:37.045Z", }, "config": json.encode( diff --git a/python/python/bystro/beanstalkd/messages.py b/python/python/bystro/beanstalkd/messages.py index c280b5745..c88a495f2 100644 --- a/python/python/bystro/beanstalkd/messages.py +++ b/python/python/bystro/beanstalkd/messages.py @@ -6,6 +6,7 @@ SubmissionID = str | int BeanstalkJobID = int + class Event(str, Enum): """Beanstalkd Event""" @@ -14,27 +15,32 @@ class Event(str, Enum): STARTED = "started" COMPLETED = "completed" -class BaseMessage(Struct, frozen=True): - submissionID: SubmissionID + +class BaseMessage(Struct, frozen=True, rename="camel"): + submission_id: SubmissionID @classmethod def keys_with_types(cls) -> dict: return get_type_hints(cls) + class SubmittedJobMessage(BaseMessage, frozen=True): event: Event = Event.STARTED + class CompletedJobMessage(BaseMessage, frozen=True): event: Event = Event.COMPLETED + class FailedJobMessage(BaseMessage, frozen=True): reason: str event: Event = Event.FAILED -class InvalidJobMessage(Struct, frozen=True): + +class InvalidJobMessage(Struct, frozen=True, rename="camel"): # Invalid jobs that are invalid because the submission breaks serialization invariants - # will not have a submissionID as that ID is held in the serialized data - queueID: BeanstalkJobID + # will not have a submission_id as that ID is held in the serialized data + queue_id: BeanstalkJobID reason: str event: Event = Event.FAILED @@ -42,10 +48,12 @@ class InvalidJobMessage(Struct, frozen=True): def keys_with_types(cls) -> dict: return get_type_hints(cls) + class ProgressData(Struct): progress: int = 0 skipped: int = 0 + class ProgressMessage(BaseMessage, frozen=True): """Beanstalkd Message""" diff --git a/python/python/bystro/beanstalkd/tests/test_messages.py b/python/python/bystro/beanstalkd/tests/test_messages.py index 2dd737e33..a3563bfbe 100644 --- a/python/python/bystro/beanstalkd/tests/test_messages.py +++ b/python/python/bystro/beanstalkd/tests/test_messages.py @@ -1,7 +1,7 @@ import unittest from enum import Enum -from msgspec import Struct +from msgspec import Struct, json from msgspec.inspect import type_info, Field, IntType, NODEFAULT from bystro.beanstalkd.messages import ( @@ -45,13 +45,13 @@ def test_struct(self): t_immutable = ImmutableT(a=1, b="test") with self.assertRaisesRegex(AttributeError, "immutable type: 'ImmutableT'"): - t_immutable.a = 2 # type: ignore + t_immutable.a = 2 # type: ignore t_default = DefaultT(a=1) self.assertEqual(t_default.b, "test") with self.assertRaisesRegex(TypeError, "Missing required argument 'c'"): - InheritedT(a=5, b="test") # type: ignore + InheritedT(a=5, b="test") # type: ignore def test_event_enum(self): self.assertTrue(issubclass(Event, Enum)) @@ -63,91 +63,91 @@ def test_event_enum(self): def test_base_message(self): self.assertTrue(issubclass(BaseMessage, Struct)) - t = BaseMessage(submissionID="test") + t = BaseMessage(submission_id="test") types = BaseMessage.keys_with_types() - self.assertEqual(t.submissionID, "test") - self.assertEqual(set(types.keys()), set({"submissionID"})) - self.assertEqual(types["submissionID"], SubmissionID) + self.assertEqual(t.submission_id, "test") + self.assertEqual(set(types.keys()), set({"submission_id"})) + self.assertEqual(types["submission_id"], SubmissionID) with self.assertRaisesRegex(AttributeError, "immutable type: 'BaseMessage'"): - t.submissionID = "test2" # type: ignore + t.submission_id = "test2" # type: ignore def test_submitted_job_message(self): self.assertTrue(issubclass(SubmittedJobMessage, BaseMessage)) - t = SubmittedJobMessage(submissionID="test") + t = SubmittedJobMessage(submission_id="test") types = SubmittedJobMessage.keys_with_types() self.assertEqual(t.event, Event.STARTED) - self.assertEqual(set(types.keys()), set({"submissionID", "event"})) - self.assertEqual(types["submissionID"], SubmissionID) + self.assertEqual(set(types.keys()), set({"submission_id", "event"})) + self.assertEqual(types["submission_id"], SubmissionID) self.assertEqual(types["event"], Event) with self.assertRaisesRegex(AttributeError, "immutable type: 'SubmittedJobMessage'"): - t.submissionID = "test2" # type: ignore + t.submission_id = "test2" # type: ignore with self.assertRaisesRegex(AttributeError, "immutable type: 'SubmittedJobMessage'"): - t.event = Event.PROGRESS # type: ignore + t.event = Event.PROGRESS # type: ignore def test_completed_job_message(self): self.assertTrue(issubclass(CompletedJobMessage, BaseMessage)) - t = CompletedJobMessage(submissionID="test") + t = CompletedJobMessage(submission_id="test") types = CompletedJobMessage.keys_with_types() self.assertEqual(t.event, Event.COMPLETED) - self.assertEqual(set(types.keys()), set({"submissionID", "event"})) - self.assertEqual(types["submissionID"], SubmissionID) + self.assertEqual(set(types.keys()), set({"submission_id", "event"})) + self.assertEqual(types["submission_id"], SubmissionID) self.assertEqual(types["event"], Event) with self.assertRaisesRegex(AttributeError, "immutable type: 'CompletedJobMessage'"): - t.submissionID = "test2" # type: ignore + t.submission_id = "test2" # type: ignore with self.assertRaisesRegex(AttributeError, "immutable type: 'CompletedJobMessage'"): - t.event = Event.PROGRESS # type: ignore + t.event = Event.PROGRESS # type: ignore def test_failed_job_message(self): self.assertTrue(issubclass(FailedJobMessage, BaseMessage)) with self.assertRaisesRegex(TypeError, "Missing required argument 'reason'"): - t = FailedJobMessage(submissionID="test") # type: ignore + t = FailedJobMessage(submission_id="test") # type: ignore - t = FailedJobMessage(submissionID="test", reason="foo") + t = FailedJobMessage(submission_id="test", reason="foo") types = FailedJobMessage.keys_with_types() self.assertEqual(t.event, Event.FAILED) - self.assertEqual(set(types.keys()), set({"submissionID", "event", "reason"})) - self.assertEqual(types["submissionID"], SubmissionID) + self.assertEqual(set(types.keys()), set({"submission_id", "event", "reason"})) + self.assertEqual(types["submission_id"], SubmissionID) self.assertEqual(types["event"], Event) self.assertEqual(types["reason"], str) with self.assertRaisesRegex(AttributeError, "immutable type: 'FailedJobMessage'"): - t.event = "test2" # type: ignore + t.event = "test2" # type: ignore def test_invalid_job_message(self): self.assertTrue(issubclass(InvalidJobMessage, Struct)) self.assertTrue(not issubclass(InvalidJobMessage, BaseMessage)) with self.assertRaisesRegex(TypeError, "Missing required argument 'reason'"): - t = InvalidJobMessage(queueID="test") # type: ignore + t = InvalidJobMessage(queue_id="test") # type: ignore - t = InvalidJobMessage(queueID=1, reason="foo") + t = InvalidJobMessage(queue_id=1, reason="foo") types = InvalidJobMessage.keys_with_types() self.assertEqual(t.event, Event.FAILED) - self.assertEqual(set(types.keys()), set({"queueID", "event", "reason"})) - self.assertEqual(types["queueID"], BeanstalkJobID) + self.assertEqual(set(types.keys()), set({"queue_id", "event", "reason"})) + self.assertEqual(types["queue_id"], BeanstalkJobID) self.assertEqual(types["event"], Event) self.assertEqual(types["reason"], str) with self.assertRaisesRegex(AttributeError, "immutable type: 'InvalidJobMessage'"): - t.event = "test2" # type: ignore + t.event = "test2" # type: ignore def test_progress_data(self): self.assertTrue(issubclass(ProgressData, Struct)) - types = list(type_info(ProgressData).fields) # type: ignore + types = list(type_info(ProgressData).fields) # type: ignore expected_types = [ Field( @@ -179,20 +179,20 @@ def test_progress_data(self): def test_progress_message(self): self.assertTrue(issubclass(ProgressMessage, BaseMessage)) - with self.assertRaisesRegex(TypeError, "Missing required argument 'submissionID'"): - t = ProgressMessage() # type: ignore + with self.assertRaisesRegex(TypeError, "Missing required argument 'submission_id'"): + t = ProgressMessage() # type: ignore - t = ProgressMessage(submissionID="test") + t = ProgressMessage(submission_id="test") types = ProgressMessage.keys_with_types() self.assertEqual(t.event, Event.PROGRESS) - self.assertEqual(set(types.keys()), set({"submissionID", "event", "data"})) - self.assertEqual(types["submissionID"], SubmissionID) + self.assertEqual(set(types.keys()), set({"submission_id", "event", "data"})) + self.assertEqual(types["submission_id"], SubmissionID) self.assertEqual(types["event"], Event) self.assertEqual(types["data"], ProgressData) with self.assertRaisesRegex(AttributeError, "immutable type: 'ProgressMessage'"): - t.submissionID = "test2" # type: ignore + t.submission_id = "test2" # type: ignore t.data.progress = 1 t.data.skipped = 2 @@ -200,6 +200,90 @@ def test_progress_message(self): self.assertEqual(t.data.progress, 1) self.assertEqual(t.data.skipped, 2) + def test_job_started_message_camel_decamel(self): + msg = SubmittedJobMessage(submission_id="my_submission_id2") + + serialized_values = json.encode(msg) + expected_value = {"submissionId": "my_submission_id2", "event": "started"} + serialized_expected_value = json.encode(expected_value) + + self.assertEqual(serialized_values, serialized_expected_value) + + deserialized_values = json.decode(serialized_expected_value, type=SubmittedJobMessage) + + self.assertEqual(deserialized_values, msg) + + def test_job_completed_message_camel_decamel(self): + msg = CompletedJobMessage(submission_id="my_submission_id2") + + serialized_values = json.encode(msg) + expected_value = {"submissionId": "my_submission_id2", "event": "completed"} + serialized_expected_value = json.encode(expected_value) + + self.assertEqual(serialized_values, serialized_expected_value) + + deserialized_values = json.decode(serialized_expected_value, type=CompletedJobMessage) + + self.assertEqual(deserialized_values, msg) + + def test_job_failed_message_camel_decamel(self): + msg = FailedJobMessage(submission_id="my_submission_id2", reason="foo") + + serialized_values = json.encode(msg) + expected_value = {"submissionId": "my_submission_id2", "reason": "foo", "event": "failed"} + serialized_expected_value = json.encode(expected_value) + + self.assertEqual(serialized_values, serialized_expected_value) + + deserialized_values = json.decode(serialized_expected_value, type=FailedJobMessage) + + self.assertEqual(deserialized_values, msg) + + def test_invalid_job_message_camel_decamel(self): + msg = InvalidJobMessage(queue_id=1, reason="foo") + + serialized_values = json.encode(msg) + expected_value = {"queueId": 1, "reason": "foo", "event": "failed"} + serialized_expected_value = json.encode(expected_value) + + self.assertEqual(serialized_values, serialized_expected_value) + + deserialized_values = json.decode(serialized_expected_value, type=InvalidJobMessage) + + self.assertEqual(deserialized_values, msg) + + def test_progress_message_camel_decamel(self): + msg = ProgressMessage( + submission_id="my_submission_id2", data=ProgressData(progress=1, skipped=2) + ) + + serialized_values = json.encode(msg) + expected_value = { + "submissionId": "my_submission_id2", + "event": "progress", + "data": {"progress": 1, "skipped": 2}, + } + serialized_expected_value = json.encode(expected_value) + + self.assertEqual(serialized_values, serialized_expected_value) + + deserialized_values = json.decode(serialized_expected_value, type=ProgressMessage) + + self.assertEqual(deserialized_values, msg) + + def test_base_message_camel_decamel(self): + msg = BaseMessage(submission_id="my_submission_id2") + + serialized_values = json.encode(msg) + expected_value = {"submissionId": "my_submission_id2"} + serialized_expected_value = json.encode(expected_value) + + self.assertEqual(serialized_values, serialized_expected_value) + + deserialized_values = json.decode(serialized_expected_value, type=BaseMessage) + + self.assertEqual(deserialized_values, msg) + if __name__ == "__main__": unittest.main() diff --git a/python/python/bystro/beanstalkd/worker.py b/python/python/bystro/beanstalkd/worker.py index 335e97742..0f65ee544 100644 --- a/python/python/bystro/beanstalkd/worker.py +++ b/python/python/bystro/beanstalkd/worker.py @@ -60,8 +60,8 @@ def default_failed_msg_fn( ) -> FailedJobMessage | InvalidJobMessage: # noqa: E501 """Default failed message function""" if job_data is None: - return InvalidJobMessage(queueID=job_id, reason=str(err)) - return FailedJobMessage(submissionID=job_data.submissionID, reason=str(err)) + return InvalidJobMessage(queue_id=job_id, reason=str(err)) + return FailedJobMessage(submission_id=job_data.submission_id, reason=str(err)) def listen( @@ -145,7 +145,7 @@ def listen( host=client.host, port=client.port, queue=tube_conf["events"], - message=ProgressMessage(submissionID=job_data.submissionID), + message=ProgressMessage(submission_id=job_data.submission_id), ) client.put_job(json.encode(submit_msg_fn(job_data))) diff --git a/python/python/bystro/proteomics/proteomics_listener.py b/python/python/bystro/proteomics/proteomics_listener.py index 7b75ee2d0..7efe87fc0 100644 --- a/python/python/bystro/proteomics/proteomics_listener.py +++ b/python/python/bystro/proteomics/proteomics_listener.py @@ -34,7 +34,7 @@ class ProteomicsJobCompleteMessage(CompletedJobMessage, frozen=True, kw_only=Tru def submit_msg_fn(proteomics_job_data: ProteomicsJobData) -> SubmittedJobMessage: """Acknowledge receipt of ProteomicsJobData.""" logger.debug("entering submit_msg_fn: %s", proteomics_job_data) - return SubmittedJobMessage(proteomics_job_data.submissionID) + return SubmittedJobMessage(proteomics_job_data.submission_id) def handler_fn( @@ -60,7 +60,7 @@ def completed_msg_fn( ) raise ValueError(err_msg) return ProteomicsJobCompleteMessage( - submissionID=proteomics_job_data.submissionID, results=proteomics_response + submission_id=proteomics_job_data.submission_id, results=proteomics_response ) diff --git a/python/python/bystro/proteomics/tests/test_proteomics_listener.py b/python/python/bystro/proteomics/tests/test_proteomics_listener.py index 0928b16c8..a64dc0b24 100644 --- a/python/python/bystro/proteomics/tests/test_proteomics_listener.py +++ b/python/python/bystro/proteomics/tests/test_proteomics_listener.py @@ -27,20 +27,20 @@ def test_submit_msg_fn_happy_path(): proteomics_submission = ProteomicsSubmission("foo.tsv") proteomics_job_data = ProteomicsJobData( - submissionID="my_submission_id", proteomics_submission=proteomics_submission + submission_id="my_submission_id", proteomics_submission=proteomics_submission ) submitted_job_message = submit_msg_fn(proteomics_job_data) - assert proteomics_job_data.submissionID == submitted_job_message.submissionID + assert proteomics_job_data.submission_id == submitted_job_message.submission_id def test_handler_fn_happy_path(): - progress_message = ProgressMessage(submissionID="my_submission_id") + progress_message = ProgressMessage(submission_id="my_submission_id") publisher = ProgressPublisher( host="127.0.0.1", port=1234, queue="my_queue", message=progress_message ) proteomics_submission = ProteomicsSubmission("foo.tsv") proteomics_job_data = ProteomicsJobData( - submissionID="my_submission_id2", proteomics_submission=proteomics_submission + submission_id="my_submission_id2", proteomics_submission=proteomics_submission ) with patch(LOAD_FRAGPIPE_DATASET_PATCH_TARGET, return_value=FAKE_FRAGPIPE_DF) as _mock: proteomics_response = handler_fn(publisher, proteomics_job_data) @@ -48,37 +48,37 @@ def test_handler_fn_happy_path(): def test_completed_msg_fn_happy_path(): - progress_message = ProgressMessage(submissionID="my_submission_id") + progress_message = ProgressMessage(submission_id="my_submission_id") publisher = ProgressPublisher( host="127.0.0.1", port=1234, queue="my_queue", message=progress_message ) proteomics_submission = ProteomicsSubmission("foo.tsv") proteomics_job_data = ProteomicsJobData( - submissionID="my_submission_id", proteomics_submission=proteomics_submission + submission_id="my_submission_id", proteomics_submission=proteomics_submission ) with patch(LOAD_FRAGPIPE_DATASET_PATCH_TARGET, return_value=FAKE_FRAGPIPE_DF) as _mock: proteomics_response = handler_fn(publisher, proteomics_job_data) proteomics_job_complete_message = completed_msg_fn(proteomics_job_data, proteomics_response) - assert proteomics_job_complete_message.submissionID == proteomics_job_data.submissionID + assert proteomics_job_complete_message.submission_id == proteomics_job_data.submission_id assert proteomics_job_complete_message.results == proteomics_response def test_completed_msg_fn_filenames_dont_match(): - progress_message = ProgressMessage(submissionID="my_submission_id") + progress_message = ProgressMessage(submission_id="my_submission_id") publisher = ProgressPublisher( host="127.0.0.1", port=1234, queue="my_queue", message=progress_message ) proteomics_submission = ProteomicsSubmission("foo.tsv") proteomics_job_data = ProteomicsJobData( - submissionID="my_submission_id", proteomics_submission=proteomics_submission + submission_id="my_submission_id", proteomics_submission=proteomics_submission ) wrong_proteomics_submission = ProteomicsSubmission("wrong_file.tsv") wrong_proteomics_job_data = ProteomicsJobData( - submissionID="wrong_submission_id", proteomics_submission=wrong_proteomics_submission + submission_id="wrong_submission_id", proteomics_submission=wrong_proteomics_submission ) with patch(LOAD_FRAGPIPE_DATASET_PATCH_TARGET, return_value=FAKE_FRAGPIPE_DF) as _mock: diff --git a/python/python/bystro/prs/listener.py b/python/python/bystro/prs/listener.py index 163ecd379..9fb882f9b 100644 --- a/python/python/bystro/prs/listener.py +++ b/python/python/bystro/prs/listener.py @@ -16,11 +16,11 @@ def submit_msg_fn(job_data: PRSJobData): - return SubmittedJobMessage(submissionID=job_data.submissionID) + return SubmittedJobMessage(submission_id=job_data.submission_id) def completed_msg_fn(job_data: PRSJobData, results: PRSJobResult) -> PRSJobResultMessage: - return PRSJobResultMessage(submissionID=job_data.submissionID, results=results) + return PRSJobResultMessage(submission_id=job_data.submission_id, results=results) def main(): diff --git a/python/python/bystro/search/index/listener.py b/python/python/bystro/search/index/listener.py index a7d868aab..0639bbf2d 100644 --- a/python/python/bystro/search/index/listener.py +++ b/python/python/bystro/search/index/listener.py @@ -2,6 +2,7 @@ CLI tool to start search indexing server that listens to beanstalkd queue and indexes submitted data in Opensearch """ + import argparse import os import subprocess @@ -142,7 +143,7 @@ def handler_fn(_: ProgressPublisher, beanstalkd_job_data: IndexJobData) -> list[ header_fields = run_handler_with_config( index_name=beanstalkd_job_data.index_name, - submission_id=beanstalkd_job_data.submissionID, + submission_id=beanstalkd_job_data.submission_id, mapping_config=m_path, opensearch_config=search_conf, queue_config=queue_conf, @@ -152,7 +153,8 @@ def handler_fn(_: ProgressPublisher, beanstalkd_job_data: IndexJobData) -> list[ return header_fields def submit_msg_fn(job_data: IndexJobData): - return SubmittedJobMessage(job_data.submissionID) + print("jbo_data", job_data) + return SubmittedJobMessage(job_data.submission_id) def completed_msg_fn(job_data: IndexJobData, field_names: list[str]): mapping_config_path = get_config_file_path(conf_dir, job_data.assembly, ".mapping.y*ml") @@ -164,7 +166,8 @@ def completed_msg_fn(job_data: IndexJobData, field_names: list[str]): shutil.copyfile(mapping_config_path, map_config_out_path) return IndexJobCompleteMessage( - submissionID=job_data.submissionID, results=IndexJobResults(map_config_basename, field_names) + submission_id=job_data.submission_id, + results=IndexJobResults(map_config_basename, field_names), ) # noqa: E501 listen( diff --git a/python/python/bystro/search/save/listener.py b/python/python/bystro/search/save/listener.py index 13e9e5470..f95b0fbd9 100644 --- a/python/python/bystro/search/save/listener.py +++ b/python/python/bystro/search/save/listener.py @@ -56,11 +56,11 @@ def handler_fn(publisher: ProgressPublisher, job_data: SaveJobData): return go(job_data=job_data, search_conf=search_conf, publisher=publisher) def submit_msg_fn(job_data: SaveJobData): - return SubmittedJobMessage(submissionID=job_data.submissionID) + return SubmittedJobMessage(submission_id=job_data.submission_id) def completed_msg_fn(job_data: SaveJobData, results: AnnotationOutputs) -> SaveJobCompleteMessage: return SaveJobCompleteMessage( - submissionID=job_data.submissionID, results=SaveJobResults(results) + submission_id=job_data.submission_id, results=SaveJobResults(results) ) listen( diff --git a/python/python/bystro/search/utils/annotation.py b/python/python/bystro/search/utils/annotation.py index 9bae76b11..9d5a777b2 100644 --- a/python/python/bystro/search/utils/annotation.py +++ b/python/python/bystro/search/utils/annotation.py @@ -21,16 +21,16 @@ class StatisticsOutputExtensions(Struct, frozen=True, forbid_unknown_fields=True qc: str = "statistics.qc.tsv" -class StatisticsConfig(Struct, frozen=True, forbid_unknown_fields=True): - dbSNPnameField: str = "dbSNP.name" - siteTypeField: str = "refSeq.siteType" - exonicAlleleFunctionField: str = "refSeq.exonicAlleleFunction" - refField: str = "ref" - homozygotesField: str = "homozygotes" - heterozygotesField: str = "heterozygotes" - altField: str = "alt" - programPath: str = "bystro-stats" - outputExtensions: StatisticsOutputExtensions = StatisticsOutputExtensions() +class StatisticsConfig(Struct, frozen=True, forbid_unknown_fields=True, rename="camel"): + dbsnp_name_field: str = "dbSNP.name" + site_type_field: str = "refSeq.siteType" + exonic_allele_function_field: str = "refSeq.exonicAlleleFunction" + ref_field: str = "ref" + homozygotes_field: str = "homozygotes" + heterozygotes_field: str = "heterozygotes" + alt_field: str = "alt" + program_path: str = "bystro-stats" + output_extension: StatisticsOutputExtensions = StatisticsOutputExtensions() @staticmethod def from_dict(annotation_config: dict[str, Any]): @@ -69,7 +69,7 @@ class StatisticsOutputs(Struct, frozen=True, forbid_unknown_fields=True): qc: str -class AnnotationOutputs(Struct, frozen=True, forbid_unknown_fields=True): +class AnnotationOutputs(Struct, frozen=True, forbid_unknown_fields=True, rename="camel"): """ Paths to all possible Bystro annotation outputs @@ -78,7 +78,7 @@ class AnnotationOutputs(Struct, frozen=True, forbid_unknown_fields=True): Output directory annotation: str Basename of the annotation TSV file, in the output directory - sampleList: Optional[str] + sample_list: Optional[str] Basename of the sample list file, in the output directory log: str Basename of the log file, in the output directory @@ -86,7 +86,7 @@ class AnnotationOutputs(Struct, frozen=True, forbid_unknown_fields=True): Basename of the config file, in the output directory statistics: StatisticsOutputs Basenames of the statistics files, in the output directory - dosageMatrixOutPath: str + dosage_matrix_out_path: str Basename of the dosage matrix, in the output directory header: Optional[str] Basename of the header file, in the output directory @@ -95,11 +95,11 @@ class AnnotationOutputs(Struct, frozen=True, forbid_unknown_fields=True): """ annotation: str - sampleList: str + sample_list: str log: str config: str statistics: StatisticsOutputs - dosageMatrixOutPath: str + dosage_matrix_out_path: str header: str | None = None archived: str | None = None @@ -139,10 +139,10 @@ def from_path( return ( AnnotationOutputs( annotation=annotation, - sampleList=sample_list, + sample_list=sample_list, statistics=statistics_output_members, config=annotation_config_path, - dosageMatrixOutPath=dosage, + dosage_matrix_out_path=dosage, log=log, ), stats, @@ -198,18 +198,18 @@ def __init__( self._config = StatisticsConfig.from_dict(annotation_config) self._delimiters = DelimitersConfig.from_dict(annotation_config) - program_path = shutil.which(self._config.programPath) + program_path = shutil.which(self._config.program_path) if not program_path: raise ValueError( - f"Couldn't find statistics program {self._config.programPath}" + f"Couldn't find statistics program {self._config.program_path}" ) self.program_path = program_path self.json_output_path = ( - f"{output_base_path}.{self._config.outputExtensions.json}" + f"{output_base_path}.{self._config.output_extension.json}" ) - self.tsv_output_path = f"{output_base_path}.{self._config.outputExtensions.tsv}" - self.qc_output_path = f"{output_base_path}.{self._config.outputExtensions.qc}" + self.tsv_output_path = f"{output_base_path}.{self._config.output_extension.tsv}" + self.qc_output_path = f"{output_base_path}.{self._config.output_extension.qc}" @property def stdin_cli_stats_command(self) -> str: @@ -217,24 +217,24 @@ def stdin_cli_stats_command(self) -> str: field_delim = self._delimiters.field empty_field = self._delimiters.empty_field - het_field = self._config.heterozygotesField - hom_field = self._config.homozygotesField - site_type_field = self._config.siteTypeField - ea_fun_field = self._config.exonicAlleleFunctionField - ref_field = self._config.refField - alt_field = self._config.altField - dbSNP_field = self._config.dbSNPnameField + het_field = self._config.heterozygotes_field + hom_field = self._config.homozygotes_field + site_type_field = self._config.site_type_field + ea_fun_field = self._config.exonic_allele_function_field + ref_field = self._config.ref_field + alt_field = self._config.alt_field + dbsnp_field = self._config.dbsnp_name_field - statsProg = self.program_path + prog = self.program_path - dbSNPpart = f"-dbSnpNameColumn {dbSNP_field}" if dbSNP_field else "" + dbsnp_part = f"-dbSnpNameColumn {dbsnp_field}" if dbsnp_field else "" return ( - f"{statsProg} -outJsonPath {self.json_output_path} -outTabPath {self.tsv_output_path} " + f"{prog} -outJsonPath {self.json_output_path} -outTabPath {self.tsv_output_path} " f"-outQcTabPath {self.qc_output_path} -refColumn {ref_field} " f"-altColumn {alt_field} -homozygotesColumn {hom_field} " f"-heterozygotesColumn {het_field} -siteTypeColumn {site_type_field} " - f"{dbSNPpart} -emptyField {empty_field} " + f"{dbsnp_part} -emptyField {empty_field} " f"-exonicAlleleFunctionColumn {ea_fun_field} " f"-primaryDelimiter '{value_delim}' -fieldSeparator '{field_delim}'" ) diff --git a/python/python/bystro/search/utils/messages.py b/python/python/bystro/search/utils/messages.py index c293dff38..6435499d7 100644 --- a/python/python/bystro/search/utils/messages.py +++ b/python/python/bystro/search/utils/messages.py @@ -8,7 +8,7 @@ from bystro.search.save.binomial_maf import BinomialMafFilter -class IndexJobData(BaseMessage, frozen=True, forbid_unknown_fields=True): +class IndexJobData(BaseMessage, frozen=True, forbid_unknown_fields=True, kw_only=True, rename="camel"): """Data for Indexing jobs received from beanstalkd""" input_dir: str @@ -20,9 +20,9 @@ class IndexJobData(BaseMessage, frozen=True, forbid_unknown_fields=True): field_names: list[str] | None = None -class IndexJobResults(Struct, frozen=True): +class IndexJobResults(Struct, frozen=True, forbid_unknown_fields=True, rename="camel"): index_config_path: str - field_names: list + field_names: list[str] class IndexJobCompleteMessage(CompletedJobMessage, frozen=True, kw_only=True): @@ -32,7 +32,7 @@ class IndexJobCompleteMessage(CompletedJobMessage, frozen=True, kw_only=True): PipelineType = list[BinomialMafFilter | HWEFilter] | None -class SaveJobData(BaseMessage, frozen=True): +class SaveJobData(BaseMessage, frozen=True, forbid_unknown_fields=True, kw_only=True, rename="camel"): """Data for SaveFromQuery jobs received from beanstalkd""" assembly: str @@ -45,9 +45,9 @@ class SaveJobData(BaseMessage, frozen=True): pipeline: PipelineType = None -class SaveJobResults(Struct, frozen=True): +class SaveJobResults(Struct, frozen=True, rename="camel"): output_file_names: AnnotationOutputs -class SaveJobCompleteMessage(CompletedJobMessage, frozen=True, kw_only=True): +class SaveJobCompleteMessage(CompletedJobMessage, frozen=True, kw_only=True, rename="camel"): results: SaveJobResults diff --git a/python/python/bystro/search/utils/tests/test_messages.py b/python/python/bystro/search/utils/tests/test_messages.py new file mode 100644 index 000000000..4cd43fa20 --- /dev/null +++ b/python/python/bystro/search/utils/tests/test_messages.py @@ -0,0 +1,246 @@ +from msgspec import json + +from bystro.search.utils.annotation import AnnotationOutputs, StatisticsOutputs +from bystro.search.utils.messages import ( + IndexJobData, + IndexJobResults, + IndexJobCompleteMessage, + SaveJobData, + SaveJobResults, + SaveJobCompleteMessage, +) + + +def test_index_job_data_camel_decamel(): + job_data = IndexJobData( + submission_id="foo", + input_dir="input_dir", + out_dir="out_dir", + input_file_names=AnnotationOutputs( + annotation="annotation", + sample_list="sample_list", + log="log", + config="config", + statistics=StatisticsOutputs(json="json", tab="tab", qc="qc"), + dosage_matrix_out_path="dosage_matrix_out_path", + header="header", + archived=None, + ), + index_name="index_name", + assembly="assembly", + index_config_path="index_config_path", + field_names=["field1", "field2"], + ) + + serialized_values = json.encode(job_data) + expected_value = { + "submissionId": "foo", + "inputDir": "input_dir", + "outDir": "out_dir", + "inputFileNames": { + "annotation": "annotation", + "sampleList": "sample_list", + "log": "log", + "config": "config", + "statistics": { + "json": "json", + "tab": "tab", + "qc": "qc", + }, + "dosageMatrixOutPath": "dosage_matrix_out_path", + "header": "header", + "archived": None, + }, + "indexName": "index_name", + "assembly": "assembly", + "indexConfigPath": "index_config_path", + "fieldNames": ["field1", "field2"], + } + serialized_expected_value = json.encode(expected_value) + + assert serialized_values == serialized_expected_value + + deserialized_values = json.decode(serialized_expected_value, type=IndexJobData) + assert deserialized_values == job_data + +def test_index_job_results_camel_decamel(): + job_results = IndexJobResults( + index_config_path="index_config_path", + field_names=["field1", "field2"] + ) + + serialized_values = json.encode(job_results) + expected_value = { + "indexConfigPath": "index_config_path", + "fieldNames": ["field1", "field2"] + } + serialized_expected_value = json.encode(expected_value) + + assert serialized_values == serialized_expected_value + + deserialized_values = json.decode(serialized_expected_value, type=IndexJobResults) + assert deserialized_values == job_results + +def test_index_job_complete_message_camel_decamel(): + job_results = IndexJobResults( + index_config_path="index_config_path", + field_names=["field1", "field2"] + ) + completed_msg = IndexJobCompleteMessage( + submission_id="foo", + results=job_results + ) + + serialized_values = json.encode(completed_msg) + expected_value = { + "submissionId": "foo", + "event": "completed", + "results": { + "indexConfigPath": "index_config_path", + "fieldNames": ["field1", "field2"] + } + } + serialized_expected_value = json.encode(expected_value) + + assert serialized_values == serialized_expected_value + + deserialized_values = json.decode(serialized_expected_value, type=IndexJobCompleteMessage) + assert deserialized_values == completed_msg + +def test_save_job_data_camel_decamel(): + job_data = SaveJobData( + submission_id="submit1", + assembly="assembly", + query_body={"query": "body"}, + input_dir="input_dir", + input_file_names=AnnotationOutputs( + annotation="annotation", + sample_list="sample_list", + log="log", + config="config", + statistics=StatisticsOutputs(json="json", tab="tab", qc="qc"), + dosage_matrix_out_path="dosage_matrix_out_path", + header="header", + archived=None, + ), + index_name="index_name", + output_base_path="output_base_path", + field_names=["field1", "field2"], + pipeline=None, + ) + + serialized_values = json.encode(job_data) + expected_value = { + "submissionId": "submit1", + "assembly": "assembly", + "queryBody": {"query": "body"}, + "inputDir": "input_dir", + "inputFileNames": { + "annotation": "annotation", + "sampleList": "sample_list", + "log": "log", + "config": "config", + "statistics": { + "json": "json", + "tab": "tab", + "qc": "qc", + }, + "dosageMatrixOutPath": "dosage_matrix_out_path", + "header": "header", + "archived": None, + }, + "indexName": "index_name", + "outputBasePath": "output_base_path", + "fieldNames": ["field1", "field2"], + "pipeline": None, + } + serialized_expected_value = json.encode(expected_value) + + assert serialized_values == serialized_expected_value + + deserialized_values = json.decode(serialized_expected_value, type=SaveJobData) + assert deserialized_values == job_data + +def test_save_job_results_camel_decamel(): + job_results = SaveJobResults( + output_file_names=AnnotationOutputs( + annotation="annotation", + sample_list="sample_list", + log="log", + config="config", + statistics=StatisticsOutputs(json="json", tab="tab", qc="qc"), + dosage_matrix_out_path="dosage_matrix_out_path", + header="header", + archived=None, + ) + ) + + serialized_values = json.encode(job_results) + expected_value = { + "outputFileNames": { + "annotation": "annotation", + "sampleList": "sample_list", + "log": "log", + "config": "config", + "statistics": { + "json": "json", + "tab": "tab", + "qc": "qc", + }, + "dosageMatrixOutPath": "dosage_matrix_out_path", + "header": "header", + "archived": None, + } + } + serialized_expected_value = json.encode(expected_value) + + assert serialized_values == serialized_expected_value + + deserialized_values = json.decode(serialized_expected_value, type=SaveJobResults) + assert deserialized_values == job_results + +def test_save_job_complete_message_camel_decamel(): + job_results = SaveJobResults( + output_file_names=AnnotationOutputs( + annotation="annotation", + sample_list="sample_list", + log="log", + config="config", + statistics=StatisticsOutputs(json="json", tab="tab", qc="qc"), + dosage_matrix_out_path="dosage_matrix_out_path", + header="header", + archived=None, + ) + ) + completed_msg = SaveJobCompleteMessage( + submission_id="submit1", + results=job_results + ) + + serialized_values = json.encode(completed_msg) + expected_value = { + "submissionId": "submit1", + "event": "completed", + "results": { + "outputFileNames": { + "annotation": "annotation", + "sampleList": "sample_list", + "log": "log", + "config": "config", + "statistics": { + "json": "json", + "tab": "tab", + "qc": "qc", + }, + "dosageMatrixOutPath": "dosage_matrix_out_path", + "header": "header", + "archived": None, + } + } + } + serialized_expected_value = json.encode(expected_value) + + assert serialized_values == serialized_expected_value + + deserialized_values = json.decode(serialized_expected_value, type=SaveJobCompleteMessage) + assert deserialized_values == completed_msg \ No newline at end of file