feat: Orqsdk 1042 support Pydantic v1 in addition to V2 (#373)
# The problem

In #372 we bumped pydantic to V2. As raised by @jamesclark-Zapata, we
have conflicts with Amazon Braket, and it would be nice to let the
user's pip do the final resolving.

# This PR's solution

Maintains pydantic V2-style syntax throughout our code. We keep as
much of the compatibility layer as possible in
[src/orquestra/sdk/_base/_storage/orqdantic.py](https://github.com/zapatacomputing/orquestra-workflow-sdk/blob/275ab19a8bebb5d70cc16698206f403dd012cbaa/src/orquestra/sdk/_base/_storage/orqdantic.py).
This file provides the following:
- Class `BaseModel` - used throughout our code in place of
`pydantic.BaseModel`.
  - When V2 is installed this is a subclass of `pydantic.BaseModel`
  providing the `__setstate__` method [^1].
  - When V1 is installed this is a subclass of
  `pydantic.generics.GenericModel` that provides V2-style methods that
  serve as wrappers for the equivalent V1 methods.
- Class `TypeAdapter` - used throughout our code in place of
`pydantic.TypeAdapter`.
  - When V2 is installed this operates identically to
  `pydantic.TypeAdapter`.
  - When V1 is installed this class acts as a translator between the
  V1-specific `parse_X_as` methods and the V2 TypeAdapter-style syntax we
  use in our code.
- Function `field_validator` - used throughout our code in place of
`pydantic.field_validator`.
  - When V2 is installed this operates identically to
  `pydantic.field_validator`.
  - When V1 is installed this operates as a wrapper for
  `pydantic.validator` and _tries_ to translate V2-style kwargs. Not every
  V2 kwarg has a V1 analogue, so this is likely to cause problems if we
  add more validators.
- `GpuResouceType` - The type used to define the `gpu` field of the
`Resources` model. For V1 this is a simple `Optional[str]`. For V2, due
to changes in how validators function, we need to annotate the type with
a validator.
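
As a rough illustration of the pattern described in the list above (not the contents of `orqdantic.py`, which this page does not show), the sketch below picks a code path from the installed major version and translates a V2-style `mode` kwarg into the V1 `pre` flag. The helper names `major_version`, `installed_major` and `v1_validator_kwargs` are hypothetical; only `importlib.metadata` and the documented `mode`/`pre` keywords are real.

```python
from importlib import metadata
from typing import Dict, Optional


def major_version(version: str) -> int:
    """Return the leading integer of a version string such as '1.10.8'."""
    return int(version.split(".")[0])


def installed_major(dist_name: str = "pydantic") -> Optional[int]:
    """Major version of an installed distribution, or None if it is absent."""
    try:
        return major_version(metadata.version(dist_name))
    except metadata.PackageNotFoundError:
        return None


def v1_validator_kwargs(mode: str = "after") -> Dict[str, bool]:
    """Translate a V2-style `mode` into kwargs for V1's `validator`.

    Only "before" and "after" map onto the V1 `pre` flag. The other V2
    modes ("wrap", "plain") have no direct V1 counterpart, so this
    hypothetical helper refuses them instead of guessing.
    """
    if mode == "before":
        return {"pre": True}
    if mode == "after":
        return {"pre": False}
    raise NotImplementedError(f"mode={mode!r} has no direct V1 equivalent")
```

A shim built this way would call `installed_major()` once at import time and define its V1 or V2 flavour of each class accordingly; failing loudly on untranslatable kwargs is one way to surface the "no perfect analogue" problem early.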

[^1]: This was introduced in #372 to allow us to handle depickling
objects pickled under pydantic V1.
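
The depickling idea in the footnote can be sketched in plain Python. This is not the SDK's `__setstate__`; `Record` is a hypothetical stand-in rather than a pydantic model, and the key mapping is an assumption about how a V1-era pickled state differs from the V2 layout.

```python
from typing import Any, Dict

# Assumed renames between a V1-era pickled state and the V2 layout.
# Illustrative only; the PR's actual handling is not shown on this page.
_LEGACY_KEYS = {
    "__fields_set__": "__pydantic_fields_set__",
    "__private_attribute_values__": "__pydantic_private__",
}


def upgrade_state(state: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of `state` with legacy keys renamed to the new layout."""
    return {_LEGACY_KEYS.get(key, key): value for key, value in state.items()}


class Record:
    """Plain-Python stand-in for a model class (hypothetical)."""

    def __setstate__(self, state: Dict[str, Any]) -> None:
        # Normalise old-style state first, then restore as usual.
        state = upgrade_state(state)
        self.__dict__.update(state.get("__dict__", {}))
        self._fields_set = set(state.get("__pydantic_fields_set__", ()))


# Unpickling creates the instance without calling __init__ and then hands
# the stored state to __setstate__; the two steps are reproduced by hand here.
legacy_state = {"__dict__": {"name": "wf-1"}, "__fields_set__": {"name"}}
record = Record.__new__(Record)
record.__setstate__(legacy_state)
```

Because the translation happens inside `__setstate__`, callers keep using `pickle.loads` unchanged and old payloads are upgraded on load.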

# Checklist

_Check that this PR satisfies the following items:_

- [ ] Tests have been added for new features/changed behavior (if no new
features have been added, check the box).
- [x] The [changelog file](CHANGELOG.md) has been updated with a
user-readable description of the changes (if the change isn't visible to
the user in any way, check the box).
- [x] The PR's title is prefixed with
`<feat/fix/chore/imp[rovement]/int[ernal]/docs>[!]:`
- [x] The PR is linked to a JIRA ticket (if there's no suitable ticket,
check the box).

[ORQSDK-1042]

[ORQSDK-1042]:
https://zapatacomputing.atlassian.net/browse/ORQSDK-1042?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ

---------

Co-authored-by: Alexander Juda <[email protected]>
Co-authored-by: James Clark <[email protected]>
3 people authored Mar 26, 2024
1 parent 5525102 commit c04317f
Showing 14 changed files with 240 additions and 135 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -17,7 +17,7 @@
💅 *Improvements*

* Tracebacks in `orq` are made more compact to help with copy and pasting when an issue happens.
* Bumped Pydantic version to `>=2.5.0`
* Added support for Pydantic V2 in addition to the previously supported `>=1.10.8`.
* Removed bunch of upper-bound constrains from SDK requirements to prevent dependency-hell

🥷 *Internal*
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -27,7 +27,7 @@ requires-python = ">= 3.8"
# the same API.
dependencies = [
# Schema definition
"pydantic>=2.5",
"pydantic>1.10.7",
# Pickling library
"cloudpickle==2.2.1",
# We should keep `dill` to make sure old workflows can be unpickled.
29 changes: 14 additions & 15 deletions src/orquestra/sdk/_base/_driver/_client.py
@@ -31,6 +31,7 @@
WorkspaceId,
)

from ..._base._storage import TypeAdapter
from .._regex import VERSION_REGEX
from . import _exceptions, _models

@@ -905,7 +906,7 @@ def get_workflow_run_artifact(

return cast(
WorkflowResult,
pydantic.TypeAdapter(WorkflowResult).validate_python(resp.json()),
TypeAdapter(WorkflowResult).validate_python(resp.json()),
)

# --- Workflow Run Results ---
@@ -1014,8 +1015,9 @@ def get_workflow_run_result(
# Try an older response
return cast(
WorkflowResult,
pydantic.TypeAdapter(WorkflowResult).validate_python(json_response),
TypeAdapter(WorkflowResult).validate_python(json_response),
)

except pydantic.ValidationError:
# If we fail, try parsing each part of a list separately
return ComputeEngineWorkflowResult.model_validate(json_response)
@@ -1081,9 +1083,7 @@ def get_workflow_run_logs(
if len(section_str) < 1:
continue

events = pydantic.TypeAdapter(_models.WorkflowLogSection).validate_json(
section_str
)
events = TypeAdapter(_models.WorkflowLogSection).validate_json(section_str)

for event in events:
messages.append(event.message)
@@ -1153,9 +1153,7 @@ def get_task_run_logs(
if len(section_str) < 1:
continue

events = pydantic.TypeAdapter(_models.TaskLogSection).validate_json(
section_str
)
events = TypeAdapter(_models.TaskLogSection).validate_json(section_str)

for event in events:
messages.append(event.message)
@@ -1222,7 +1220,8 @@ def get_system_logs(self, wf_run_id: _models.WorkflowRunID) -> List[_models.SysL
for section_str in decoded.split("\n"):
if len(section_str) < 1:
continue
events = pydantic.TypeAdapter(_models.SysSection).validate_json(section_str)

events = TypeAdapter(_models.SysSection).validate_json(section_str)

for event in events:
messages.append(event.message)
@@ -1254,9 +1253,9 @@ def list_workspaces(self):
):
raise

parsed_response = pydantic.TypeAdapter(
_models.ListWorkspacesResponse
).validate_python(resp.json())
parsed_response = TypeAdapter(_models.ListWorkspacesResponse).validate_python(
resp.json()
)

return parsed_response

@@ -1293,9 +1292,9 @@ def list_projects(self, workspace_id: WorkspaceId):
):
raise

parsed_response = pydantic.TypeAdapter(
_models.ListProjectResponse
).validate_python(resp.json())
parsed_response = TypeAdapter(_models.ListProjectResponse).validate_python(
resp.json()
)

return parsed_response

64 changes: 32 additions & 32 deletions src/orquestra/sdk/_base/_driver/_models.py
@@ -31,7 +31,7 @@
WorkspaceId,
)

from ..._base._storage import OrquestraBaseModel
from ..._base._storage import BaseModel

WorkflowDefID = str
WorkflowRunID = str
@@ -47,7 +47,7 @@
MetaT = TypeVar("MetaT")


class Pagination(OrquestraBaseModel):
class Pagination(BaseModel):
"""
Implements:
https://github.com/zapatacomputing/workflow-driver/blob/259481b9240547bccf4fa40df4e92bf6c617a25f/openapi/src/schemas/MetaSuccessPaginated.yaml.
@@ -56,14 +56,14 @@ class Pagination(OrquestraBaseModel):
nextPageToken: str


class Response(OrquestraBaseModel, Generic[DataT, MetaT]):
class Response(BaseModel, Generic[DataT, MetaT]):
"""A generic to help with the structure of driver responses."""

data: DataT
meta: Optional[MetaT] = None


class MetaEmpty(OrquestraBaseModel):
class MetaEmpty(BaseModel):
pass


@@ -81,7 +81,7 @@ class ErrorCode(IntEnum):
WORKFLOW_DEF_NOT_FOUND = 6


class Error(OrquestraBaseModel):
class Error(BaseModel):
"""
Implements:
https://github.com/zapatacomputing/workflow-driver/blob/2b3534/openapi/src/schemas/Error.yaml.
@@ -95,7 +95,7 @@ class Error(OrquestraBaseModel):
# --- Workflow Definitions ---


class CreateWorkflowDefResponse(OrquestraBaseModel):
class CreateWorkflowDefResponse(BaseModel):
"""
Implements:
https://github.com/zapatacomputing/workflow-driver/blob/2b3534/openapi/src/responses/CreateWorkflowDefinitionResponse.yaml.
@@ -104,7 +104,7 @@ class CreateWorkflowDefResponse(OrquestraBaseModel):
id: WorkflowDefID


class GetWorkflowDefResponse(OrquestraBaseModel):
class GetWorkflowDefResponse(BaseModel):
"""
Implements:
https://github.com/zapatacomputing/workflow-driver/blob/cb61512e9f3da24addd933c7259aa4584ab04e4f/openapi/src/schemas/WorkflowDefinition.yaml.
@@ -119,7 +119,7 @@ class GetWorkflowDefResponse(OrquestraBaseModel):
sdkVersion: str


class ListWorkflowDefsRequest(OrquestraBaseModel):
class ListWorkflowDefsRequest(BaseModel):
"""
Implements:
https://github.com/zapatacomputing/workflow-driver/blob/cdb667ef6d1053876250daff27e19fb50374c0d4/openapi/src/resources/workflow-definitions.yaml#L8.
@@ -129,7 +129,7 @@ class ListWorkflowDefsRequest(OrquestraBaseModel):
pageToken: Optional[str] = None


class CreateWorkflowDefsRequest(OrquestraBaseModel):
class CreateWorkflowDefsRequest(BaseModel):
"""
Implements:
https://github.com/zapatacomputing/workflow-driver/blob/dc8a2a37d92324f099afefc048f6486a5061850f/openapi/src/resources/workflow-definitions.yaml#L39.
@@ -163,7 +163,7 @@ def _missing_(cls, value):
return cls.UNKNOWN


class RunStatusResponse(OrquestraBaseModel):
class RunStatusResponse(BaseModel):
"""
Implements:
https://github.com/zapatacomputing/workflow-driver/blob/34eba4253b56266772795a8a59d6ec7edf88c65a/openapi/src/schemas/RunStatus.yaml#L1.
@@ -181,7 +181,7 @@ def to_ir(self) -> RunStatus:
)


class TaskRunResponse(OrquestraBaseModel):
class TaskRunResponse(BaseModel):
"""
Implements:
https://github.com/zapatacomputing/workflow-driver/blob/34eba4253b56266772795a8a59d6ec7edf88c65a/openapi/src/schemas/WorkflowRun.yaml#L17.
@@ -207,7 +207,7 @@ def to_ir(self) -> TaskRun:
)


class MinimalWorkflowRunResponse(OrquestraBaseModel):
class MinimalWorkflowRunResponse(BaseModel):
"""
Implements:
https://github.com/zapatacomputing/workflow-driver/blob/34eba4253b56266772795a8a59d6ec7edf88c65a/openapi/src/schemas/WorkflowRun.yaml#L1.
@@ -223,7 +223,7 @@ def to_ir(self, workflow_def: WorkflowDef) -> WorkflowRunMinimal:
)


class WorkflowRunSummaryResponse(OrquestraBaseModel):
class WorkflowRunSummaryResponse(BaseModel):
"""Contains all of the information needed to give a basic overview of the workflow.
Implements:
@@ -269,7 +269,7 @@ def to_ir(self, workflow_def: WorkflowDef) -> WorkflowRun:
)


class Resources(OrquestraBaseModel):
class Resources(BaseModel):
"""
Implements:
https://github.com/zapatacomputing/workflow-driver/blob/580c8d8835b1cccd085ea716c514038e85eb28d7/openapi/src/schemas/Resources.yaml.
@@ -288,7 +288,7 @@ class Resources(OrquestraBaseModel):
gpu: Optional[str] = pydantic.Field(pattern="^[01]+$")


class HeadNodeResources(OrquestraBaseModel):
class HeadNodeResources(BaseModel):
"""
Implements:
https://github.com/zapatacomputing/workflow-driver/blob/ac1e97ea00fc3526c93187a1da02170bff45b74f/openapi/src/schemas/HeadNodeResources.yaml.
@@ -302,7 +302,7 @@ class HeadNodeResources(OrquestraBaseModel):
)


class CreateWorkflowRunRequest(OrquestraBaseModel):
class CreateWorkflowRunRequest(BaseModel):
"""
Implements:
https://github.com/zapatacomputing/workflow-driver/blob/ac1e97ea00fc3526c93187a1da02170bff45b74f/openapi/src/schemas/CreateWorkflowRunRequest.yaml.
@@ -314,15 +314,15 @@ class CreateWorkflowRunRequest(OrquestraBaseModel):
headNodeResources: Optional[HeadNodeResources] = None


class CreateWorkflowRunResponse(OrquestraBaseModel):
class CreateWorkflowRunResponse(BaseModel):
"""Implements:
https://github.com/zapatacomputing/workflow-driver/blob/2e999a76019e8f8de8082409daddf7789dc2f430/pkg/server/server.go#L376.
""" # noqa: D205, D212

id: WorkflowRunID


class ListWorkflowRunsRequest(OrquestraBaseModel):
class ListWorkflowRunsRequest(BaseModel):
"""
Implements:
https://github.com/zapatacomputing/workflow-driver/blob/c52013c0f4df066159fc32ad38d489b3eaff5850/openapi/src/resources/workflow-runs.yaml#L14.
@@ -342,7 +342,7 @@ class ListWorkflowRunsRequest(OrquestraBaseModel):
ListWorkflowRunSummariesResponse = List[WorkflowRunSummaryResponse]


class GetWorkflowRunResponse(OrquestraBaseModel):
class GetWorkflowRunResponse(BaseModel):
"""
Implements:
https://github.com/zapatacomputing/workflow-driver/blob/34eba4253b56266772795a8a59d6ec7edf88c65a/openapi/src/resources/workflow-run.yaml#L17.
@@ -351,7 +351,7 @@ class GetWorkflowRunResponse(OrquestraBaseModel):
data: WorkflowRunResponse


class TerminateWorkflowRunRequest(OrquestraBaseModel):
class TerminateWorkflowRunRequest(BaseModel):
"""
Implements:
https://github.com/zapatacomputing/workflow-driver/blob/873437f8157226c451220306a6ce90c80e8c8f9e/openapi/src/resources/workflow-run-terminate.yaml#L12.
@@ -363,7 +363,7 @@ class TerminateWorkflowRunRequest(OrquestraBaseModel):
# --- Workflow Artifacts ---


class GetWorkflowRunArtifactsRequest(OrquestraBaseModel):
class GetWorkflowRunArtifactsRequest(BaseModel):
"""
Implements:
https://github.com/zapatacomputing/workflow-driver/blob/34eba4253b56266772795a8a59d6ec7edf88c65a/openapi/src/resources/artifacts.yaml#L10.
@@ -377,7 +377,7 @@ class GetWorkflowRunArtifactsRequest(OrquestraBaseModel):
# --- Workflow Results ---


class GetWorkflowRunResultsRequest(OrquestraBaseModel):
class GetWorkflowRunResultsRequest(BaseModel):
"""
Implements:
https://github.com/zapatacomputing/workflow-driver/blob/34eba4253b56266772795a8a59d6ec7edf88c65a/openapi/src/resources/run-results.yaml#L10.
@@ -392,7 +392,7 @@ class GetWorkflowRunResultsRequest(OrquestraBaseModel):
# --- Logs ---


class GetWorkflowRunLogsRequest(OrquestraBaseModel):
class GetWorkflowRunLogsRequest(BaseModel):
"""
Implements:
https://github.com/zapatacomputing/workflow-driver/blob/34eba4253b56266772795a8a59d6ec7edf88c65a/openapi/src/resources/workflow-run-logs.yaml.
@@ -401,7 +401,7 @@ class GetWorkflowRunLogsRequest(OrquestraBaseModel):
workflowRunId: WorkflowRunID


class GetTaskRunLogsRequest(OrquestraBaseModel):
class GetTaskRunLogsRequest(BaseModel):
"""
Implements:
https://github.com/zapatacomputing/workflow-driver/blob/c7685a579eca1f9cb3eb27e2a8c2a9757a3cd021/openapi/src/resources/task-run-logs.yaml.
@@ -411,7 +411,7 @@ class GetTaskRunLogsRequest(OrquestraBaseModel):
taskInvocationId: TaskInvocationID


class CommonResourceMeta(OrquestraBaseModel):
class CommonResourceMeta(BaseModel):
type: str
displayName: str
description: str
@@ -424,7 +424,7 @@ class CommonResourceMeta(OrquestraBaseModel):
status: str


class ResourceIdentifier(OrquestraBaseModel):
class ResourceIdentifier(BaseModel):
tenantId: str
resourceGroupId: str
id: str
@@ -468,7 +468,7 @@ class ProjectDetail(CommonResourceMeta, ResourceIdentifier):
RayFilename = NewType("RayFilename", str)


class WorkflowLogMessage(OrquestraBaseModel):
class WorkflowLogMessage(BaseModel):
"""Represents a single line indexed by the server side log service.
Based on:
@@ -516,7 +516,7 @@ class WorkflowLogEvent(NamedTuple):
WorkflowLogSection = List[WorkflowLogEvent]


class TaskLogMessage(OrquestraBaseModel):
class TaskLogMessage(BaseModel):
"""Represents a single line indexed by the server side log service.
Based on:
@@ -579,7 +579,7 @@ def _missing_(cls, *args, **kwargs):
return cls.UNKNOWN


class K8sEventLog(OrquestraBaseModel):
class K8sEventLog(BaseModel):
"""A system-level log line produced by a K8S event."""

tag: str
@@ -592,7 +592,7 @@ class K8sEventLog(OrquestraBaseModel):
source_type: Literal[SystemLogSourceType.K8S_EVENT] = SystemLogSourceType.K8S_EVENT


class RayHeadNodeEventLog(OrquestraBaseModel):
class RayHeadNodeEventLog(BaseModel):
"""A system-level log line produced by a Ray head node event."""

tag: str
@@ -604,7 +604,7 @@ class RayHeadNodeEventLog(OrquestraBaseModel):
] = SystemLogSourceType.RAY_HEAD_NODE


class RayWorkerNodeEventLog(OrquestraBaseModel):
class RayWorkerNodeEventLog(BaseModel):
"""A system-level log line produced by a Ray worker node event."""

tag: str
@@ -616,7 +616,7 @@ class RayWorkerNodeEventLog(OrquestraBaseModel):
] = SystemLogSourceType.RAY_WORKER_NODE


class UnknownEventLog(OrquestraBaseModel):
class UnknownEventLog(BaseModel):
"""Fallback option - the event type is unknown, so display the message as a str."""

tag: str