Skip to content

Commit

Permalink
Issue bystrogenomics#393: Consistent snake_case in python library
Browse files Browse the repository at this point in the history
  • Loading branch information
akotlar committed Feb 8, 2024
1 parent cfbb09d commit 2337c4d
Show file tree
Hide file tree
Showing 21 changed files with 581 additions and 174 deletions.
2 changes: 1 addition & 1 deletion go/beanstalkd/beanstalkd.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type ProgressData struct {
}

type ProgressMessage struct {
SubmissionID string `json:"submissionID"`
SubmissionID string `json:"submissionId"`
Data ProgressData `json:"data"`
Event string `json:"event"`
}
Expand Down
24 changes: 12 additions & 12 deletions perl/bin/bystro-server.pl
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@

# The properties that we accept from the worker caller
my %requiredForAll = (
output_file_base => 'output_base_path',
output_file_base => 'outputBasePath',
assembly => 'assembly',
);

my $requiredForType = { input_file => 'input_file_path' };
my $requiredForType = { input_file => 'inputFilePath' };

say "Running Annotation queue server";

Expand Down Expand Up @@ -125,8 +125,8 @@
data => encode_json(
{
event => $STARTED,
submissionID => $jobDataHref->{submissionID},
queueID => $job->id,
submissionId => $jobDataHref->{submissionId},
queueId => $job->id,
}
)
}
Expand Down Expand Up @@ -161,8 +161,8 @@
my $data = {
event => $FAILED,
reason => $err,
queueID => $job->id,
submissionID => $jobDataHref->{submissionID},
queueId => $job->id,
submissionId => $jobDataHref->{submissionId},
};

$beanstalkEvents->put( { priority => 0, data => encode_json($data) } );
Expand All @@ -179,9 +179,9 @@

my $data = {
event => $COMPLETED,
queueID => $job->id,
submissionID => $jobDataHref->{submissionID},
results => { output_file_names => $outputFileNamesHashRef, }
queueId => $job->id,
submissionId => $jobDataHref->{submissionId},
results => { outputFileNames => $outputFileNamesHashRef, }
};

if ( defined $debug ) {
Expand Down Expand Up @@ -213,7 +213,7 @@ sub coerceInputs {
my %jobSpecificArgs;
for my $key ( keys %requiredForAll ) {
if ( !defined $jobDetailsHref->{ $requiredForAll{$key} } ) {
$err = "Missing required key: $key in job message";
$err = "Missing required key: $requiredForAll{$key} in job message";
return ( $err, undef );
}

Expand Down Expand Up @@ -243,8 +243,8 @@ sub coerceInputs {
queue => $queueConfig->{events},
messageBase => {
event => $PROGRESS,
queueID => $queueId,
submissionID => $jobDetailsHref->{submissionID},
queueId => $queueId,
submissionId => $jobDetailsHref->{submissionId},
data => undef,
}
},
Expand Down
4 changes: 2 additions & 2 deletions python/python/bystro/ancestry/ancestry_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
LOWER_UNIT_BOUND = 0.0
UPPER_UNIT_BOUND = 1.0

class ProbabilityInterval(Struct):
class ProbabilityInterval(Struct, rename="camel"):
"""Represent an interval of probabilities."""

lower_bound: float
Expand Down Expand Up @@ -105,7 +105,7 @@ def __post_init__(self):
raise TypeError(f"probability must be between {LOWER_UNIT_BOUND} and {UPPER_UNIT_BOUND}")


class AncestryScoresOneSample(Struct, frozen=True):
class AncestryScoresOneSample(Struct, frozen=True, rename="camel"):
"""An ancestry result for a sample.
Represents ancestry model output for an individual study
Expand Down
2 changes: 1 addition & 1 deletion python/python/bystro/ancestry/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
logger = logging.getLogger(__name__)


class AncestryModel(Struct, frozen=True, forbid_unknown_fields=True):
class AncestryModel(Struct, frozen=True, forbid_unknown_fields=True, rename="camel"):
"""Bundle together PCA and RFC models for bookkeeping purposes."""

pca_loadings_df: pd.DataFrame
Expand Down
12 changes: 7 additions & 5 deletions python/python/bystro/ancestry/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ def _get_model_from_s3() -> AncestryModel:
return AncestryModel(pca_loadings_df, rfc)


class AncestryJobData(BaseMessage, frozen=True):
class AncestryJobData(BaseMessage, frozen=True, rename="camel"):
"""
The expected JSON message for the Ancestry job.
Parameters
----------
submissionID: str
submission_id: str
The unique identifier for the job.
dosage_matrix_path: str
The path to the dosage matrix file.
Expand All @@ -63,7 +63,7 @@ class AncestryJobData(BaseMessage, frozen=True):
out_dir: str


class AncestryJobCompleteMessage(CompletedJobMessage, frozen=True, kw_only=True):
class AncestryJobCompleteMessage(CompletedJobMessage, frozen=True, kw_only=True, rename="camel"):
"""The returned JSON message expected by the API server"""

result_path: str
Expand Down Expand Up @@ -100,7 +100,7 @@ def handler_fn(publisher: ProgressPublisher, job_data: AncestryJobData) -> Ances
def submit_msg_fn(ancestry_job_data: AncestryJobData) -> SubmittedJobMessage:
"""Acknowledge receipt of AncestryJobData."""
logger.debug("entering submit_msg_fn: %s", ancestry_job_data)
return SubmittedJobMessage(ancestry_job_data.submissionID)
return SubmittedJobMessage(ancestry_job_data.submission_id)


def completed_msg_fn(
Expand All @@ -116,7 +116,9 @@ def completed_msg_fn(
with open(out_path, "wb") as f:
f.write(json_data)

return AncestryJobCompleteMessage(submissionID=ancestry_job_data.submissionID, result_path=out_path)
return AncestryJobCompleteMessage(
submission_id=ancestry_job_data.submission_id, result_path=out_path
)


def main(ancestry_model: AncestryModel, queue_conf: QueueConf) -> None:
Expand Down
47 changes: 43 additions & 4 deletions python/python/bystro/ancestry/tests/test_listener.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from msgspec import json
import pyarrow.feather as feather # type: ignore

from bystro.ancestry.listener import (
Expand All @@ -24,7 +25,7 @@

def test_submit_fn():
ancestry_job_data = AncestryJobData(
submissionID="my_submission_id2",
submission_id="my_submission_id2",
dosage_matrix_path="some_dosage.feather",
out_dir="/path/to/some/dir",
)
Expand All @@ -39,12 +40,12 @@ def test_handler_fn_happy_path(tmpdir):

feather.write_feather(FAKE_GENOTYPES_DOSAGE_MATRIX.to_table(), str(f1))

progress_message = ProgressMessage(submissionID="my_submission_id")
progress_message = ProgressMessage(submission_id="my_submission_id")
publisher = ProgressPublisher(
host="127.0.0.1", port=1234, queue="my_queue", message=progress_message
)
ancestry_job_data = AncestryJobData(
submissionID="my_submission_id2", dosage_matrix_path=f1, out_dir=str(tmpdir)
submission_id="my_submission_id2", dosage_matrix_path=f1, out_dir=str(tmpdir)
)
ancestry_response = handler_fn(publisher, ancestry_job_data)

Expand All @@ -61,11 +62,49 @@ def test_handler_fn_happy_path(tmpdir):

def test_completion_fn(tmpdir):
ancestry_job_data = AncestryJobData(
submissionID="my_submission_id2", dosage_matrix_path="some_dosage.feather", out_dir=str(tmpdir)
submission_id="my_submission_id2", dosage_matrix_path="some_dosage.feather", out_dir=str(tmpdir)
)

ancestry_results, _ = _infer_ancestry()

completed_msg = completed_msg_fn(ancestry_job_data, ancestry_results)

assert isinstance(completed_msg, AncestryJobCompleteMessage)


def test_completion_message():
    """Round-trip an AncestryJobCompleteMessage through msgspec JSON.

    Encoding the message must yield camelCase keys, and decoding that
    camelCase payload must rebuild a message equal to the original.
    """
    message = AncestryJobCompleteMessage(
        submission_id="my_submission_id2",
        result_path="some_dosage.feather",
    )

    # The check below compares encoded output directly, so the key order
    # of this dict is part of the expectation.
    camel_case_payload = json.encode(
        {
            "submissionId": "my_submission_id2",
            "event": "completed",
            "resultPath": "some_dosage.feather",
        }
    )

    assert json.encode(message) == camel_case_payload

    round_tripped = json.decode(camel_case_payload, type=AncestryJobCompleteMessage)
    assert round_tripped == message


def test_job_data_from_beanstalkd():
    """Round-trip AncestryJobData through msgspec JSON with camelCase keys.

    The snake_case attributes must encode to camelCase keys, and a
    camelCase payload must decode back into an equal AncestryJobData.
    """
    job_data = AncestryJobData(
        submission_id="my_submission_id2",
        dosage_matrix_path="some_dosage.feather",
        out_dir="/foo",
    )

    # Encoded output is compared directly, so this dict's key order matters.
    camel_case_payload = json.encode(
        {
            "submissionId": "my_submission_id2",
            "dosageMatrixPath": "some_dosage.feather",
            "outDir": "/foo",
        }
    )

    assert json.encode(job_data) == camel_case_payload

    round_tripped = json.decode(camel_case_payload, type=AncestryJobData)
    assert round_tripped == job_data
10 changes: 5 additions & 5 deletions python/python/bystro/api/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
}


class SignupResponse(Struct):
class SignupResponse(Struct, rename="camel"):
"""
The response body for signing up for Bystro.
Expand All @@ -33,7 +33,7 @@ class SignupResponse(Struct):
access_token: str


class LoginResponse(Struct):
class LoginResponse(Struct, rename="camel"):
"""
The response body for logging in to Bystro.
Expand All @@ -46,7 +46,7 @@ class LoginResponse(Struct):
access_token: str


class CachedAuth(Struct):
class CachedAuth(Struct, rename="camel"):
"""
The authentication state.
Expand All @@ -65,7 +65,7 @@ class CachedAuth(Struct):
url: str


class JobBasicResponse(Struct):
class JobBasicResponse(Struct, rename="camel"):
"""
The basic job information, returned in job list commands
Expand All @@ -84,7 +84,7 @@ class JobBasicResponse(Struct):
createdAt: datetime.datetime


class UserProfile(Struct):
class UserProfile(Struct, rename="camel"):
"""
The response body for fetching the user profile.
Expand Down
2 changes: 1 addition & 1 deletion python/python/bystro/api/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
"_id": "64db4e68fb86b79cbda4f387",
"type": "annotation",
"submittedDate": "2023-08-15T10:07:36.027Z",
"queueID": "1538",
"queueId": "1538",
"startedDate": "2023-08-15T10:07:37.045Z",
},
"config": json.encode(
Expand Down
18 changes: 13 additions & 5 deletions python/python/bystro/beanstalkd/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
SubmissionID = str | int
BeanstalkJobID = int


class Event(str, Enum):
"""Beanstalkd Event"""

Expand All @@ -14,38 +15,45 @@ class Event(str, Enum):
STARTED = "started"
COMPLETED = "completed"

class BaseMessage(Struct, frozen=True):
submissionID: SubmissionID

class BaseMessage(Struct, frozen=True, rename="camel"):
submission_id: SubmissionID

@classmethod
def keys_with_types(cls) -> dict:
return get_type_hints(cls)


class SubmittedJobMessage(BaseMessage, frozen=True):
    """BaseMessage variant whose ``event`` defaults to ``Event.STARTED``."""

    event: Event = Event.STARTED


class CompletedJobMessage(BaseMessage, frozen=True):
    """BaseMessage variant whose ``event`` defaults to ``Event.COMPLETED``."""

    event: Event = Event.COMPLETED


class FailedJobMessage(BaseMessage, frozen=True):
    """BaseMessage variant with a required ``reason``; ``event`` defaults to ``Event.FAILED``."""

    reason: str
    event: Event = Event.FAILED

class InvalidJobMessage(Struct, frozen=True):

class InvalidJobMessage(Struct, frozen=True, rename="camel"):
# Jobs that are invalid because the submission breaks serialization invariants
# will not have a submissionID as that ID is held in the serialized data
queueID: BeanstalkJobID
# will not have a submission_id as that ID is held in the serialized data
queue_id: BeanstalkJobID
reason: str
event: Event = Event.FAILED

@classmethod
def keys_with_types(cls) -> dict:
return get_type_hints(cls)


class ProgressData(Struct):
    """Two integer counters, ``progress`` and ``skipped``, each defaulting to 0."""

    progress: int = 0
    skipped: int = 0


class ProgressMessage(BaseMessage, frozen=True):
"""Beanstalkd Message"""

Expand Down
Loading

0 comments on commit 2337c4d

Please sign in to comment.