diff --git a/CHANGELOG.md b/CHANGELOG.md index 8aca10ce5..21a934f6f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,7 +17,7 @@ 💅 *Improvements* * Tracebacks in `orq` are made more compact to help with copy and pasting when an issue happens. -* Bumped Pydantic version to `>=2.5.0` +* Added support for Pydantic V2 in addition to the previously supported `>=1.10.8`. * Removed bunch of upper-bound constrains from SDK requirements to prevent dependency-hell 🥷 *Internal* diff --git a/pyproject.toml b/pyproject.toml index fba49a333..33eee1daf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ requires-python = ">= 3.8" # the same API. dependencies = [ # Schema definition - "pydantic>=2.5", + "pydantic>1.10.7", # Pickling library "cloudpickle==2.2.1", # We should keep `dill` to make sure old workflows can be unpickled. diff --git a/src/orquestra/sdk/_base/_driver/_client.py b/src/orquestra/sdk/_base/_driver/_client.py index b75314f56..753b513a7 100644 --- a/src/orquestra/sdk/_base/_driver/_client.py +++ b/src/orquestra/sdk/_base/_driver/_client.py @@ -31,6 +31,7 @@ WorkspaceId, ) +from ..._base._storage import TypeAdapter from .._regex import VERSION_REGEX from . import _exceptions, _models @@ -905,7 +906,7 @@ def get_workflow_run_artifact( return cast( WorkflowResult, - pydantic.TypeAdapter(WorkflowResult).validate_python(resp.json()), + TypeAdapter(WorkflowResult).validate_python(resp.json()), ) # --- Workflow Run Results --- @@ -1014,8 +1015,9 @@ def get_workflow_run_result( # Try an older response return cast( WorkflowResult, - pydantic.TypeAdapter(WorkflowResult).validate_python(json_response), + TypeAdapter(WorkflowResult).validate_python(json_response), ) + except pydantic.ValidationError: # If we fail, try parsing each part of a list separately return ComputeEngineWorkflowResult.model_validate(json_response) @@ -1081,9 +1083,7 @@ def get_workflow_run_logs( if len(section_str) < 1: continue - events = pydantic.TypeAdapter(_models.WorkflowLogSection).validate_json( - section_str - ) + events = TypeAdapter(_models.WorkflowLogSection).validate_json(section_str) for event in events: messages.append(event.message) @@ -1153,9 +1153,7 @@ def get_task_run_logs( if len(section_str) < 1: continue - events = pydantic.TypeAdapter(_models.TaskLogSection).validate_json( - section_str - ) + events = TypeAdapter(_models.TaskLogSection).validate_json(section_str) for event in events: messages.append(event.message) @@ -1222,7 +1220,8 @@ def get_system_logs(self, wf_run_id: _models.WorkflowRunID) -> List[_models.SysL for section_str in decoded.split("\n"): if len(section_str) < 1: continue - events = pydantic.TypeAdapter(_models.SysSection).validate_json(section_str) + + events = TypeAdapter(_models.SysSection).validate_json(section_str) for event in events: messages.append(event.message) @@ -1254,9 +1253,9 @@ def list_workspaces(self): ): raise - parsed_response = pydantic.TypeAdapter( - _models.ListWorkspacesResponse - ).validate_python(resp.json()) + parsed_response = TypeAdapter(_models.ListWorkspacesResponse).validate_python( + resp.json() + ) return parsed_response @@ -1293,9 +1292,9 @@ def list_projects(self, workspace_id: WorkspaceId): ): raise - parsed_response = pydantic.TypeAdapter( - _models.ListProjectResponse - ).validate_python(resp.json()) + parsed_response = TypeAdapter(_models.ListProjectResponse).validate_python( + resp.json() + ) return parsed_response diff --git a/src/orquestra/sdk/_base/_driver/_models.py b/src/orquestra/sdk/_base/_driver/_models.py index 90cc9500d..0f73a8663 100644 --- a/src/orquestra/sdk/_base/_driver/_models.py +++ b/src/orquestra/sdk/_base/_driver/_models.py @@ -31,7 +31,7 @@ WorkspaceId, ) -from ..._base._storage import OrquestraBaseModel +from ..._base._storage import BaseModel WorkflowDefID = str WorkflowRunID = str @@ -47,7 +47,7 @@ MetaT = TypeVar("MetaT") -class Pagination(OrquestraBaseModel): +class Pagination(BaseModel): """ Implements: https://github.com/zapatacomputing/workflow-driver/blob/259481b9240547bccf4fa40df4e92bf6c617a25f/openapi/src/schemas/MetaSuccessPaginated.yaml. @@ -56,14 +56,14 @@ class Pagination(OrquestraBaseModel): nextPageToken: str -class Response(OrquestraBaseModel, Generic[DataT, MetaT]): +class Response(BaseModel, Generic[DataT, MetaT]): """A generic to help with the structure of driver responses.""" data: DataT meta: Optional[MetaT] = None -class MetaEmpty(OrquestraBaseModel): +class MetaEmpty(BaseModel): pass @@ -81,7 +81,7 @@ class ErrorCode(IntEnum): WORKFLOW_DEF_NOT_FOUND = 6 -class Error(OrquestraBaseModel): +class Error(BaseModel): """ Implements: https://github.com/zapatacomputing/workflow-driver/blob/2b3534/openapi/src/schemas/Error.yaml. @@ -95,7 +95,7 @@ class Error(OrquestraBaseModel): # --- Workflow Definitions --- -class CreateWorkflowDefResponse(OrquestraBaseModel): +class CreateWorkflowDefResponse(BaseModel): """ Implements: https://github.com/zapatacomputing/workflow-driver/blob/2b3534/openapi/src/responses/CreateWorkflowDefinitionResponse.yaml. @@ -104,7 +104,7 @@ class CreateWorkflowDefResponse(OrquestraBaseModel): id: WorkflowDefID -class GetWorkflowDefResponse(OrquestraBaseModel): +class GetWorkflowDefResponse(BaseModel): """ Implements: https://github.com/zapatacomputing/workflow-driver/blob/cb61512e9f3da24addd933c7259aa4584ab04e4f/openapi/src/schemas/WorkflowDefinition.yaml. @@ -119,7 +119,7 @@ class GetWorkflowDefResponse(OrquestraBaseModel): sdkVersion: str -class ListWorkflowDefsRequest(OrquestraBaseModel): +class ListWorkflowDefsRequest(BaseModel): """ Implements: https://github.com/zapatacomputing/workflow-driver/blob/cdb667ef6d1053876250daff27e19fb50374c0d4/openapi/src/resources/workflow-definitions.yaml#L8. @@ -129,7 +129,7 @@ class ListWorkflowDefsRequest(OrquestraBaseModel): pageToken: Optional[str] = None -class CreateWorkflowDefsRequest(OrquestraBaseModel): +class CreateWorkflowDefsRequest(BaseModel): """ Implements: https://github.com/zapatacomputing/workflow-driver/blob/dc8a2a37d92324f099afefc048f6486a5061850f/openapi/src/resources/workflow-definitions.yaml#L39. @@ -163,7 +163,7 @@ def _missing_(cls, value): return cls.UNKNOWN -class RunStatusResponse(OrquestraBaseModel): +class RunStatusResponse(BaseModel): """ Implements: https://github.com/zapatacomputing/workflow-driver/blob/34eba4253b56266772795a8a59d6ec7edf88c65a/openapi/src/schemas/RunStatus.yaml#L1. @@ -181,7 +181,7 @@ def to_ir(self) -> RunStatus: ) -class TaskRunResponse(OrquestraBaseModel): +class TaskRunResponse(BaseModel): """ Implements: https://github.com/zapatacomputing/workflow-driver/blob/34eba4253b56266772795a8a59d6ec7edf88c65a/openapi/src/schemas/WorkflowRun.yaml#L17. @@ -207,7 +207,7 @@ def to_ir(self) -> TaskRun: ) -class MinimalWorkflowRunResponse(OrquestraBaseModel): +class MinimalWorkflowRunResponse(BaseModel): """ Implements: https://github.com/zapatacomputing/workflow-driver/blob/34eba4253b56266772795a8a59d6ec7edf88c65a/openapi/src/schemas/WorkflowRun.yaml#L1. @@ -223,7 +223,7 @@ def to_ir(self, workflow_def: WorkflowDef) -> WorkflowRunMinimal: ) -class WorkflowRunSummaryResponse(OrquestraBaseModel): +class WorkflowRunSummaryResponse(BaseModel): """Contains all of the information needed to give a basic overview of the workflow. Implements: @@ -269,7 +269,7 @@ def to_ir(self, workflow_def: WorkflowDef) -> WorkflowRun: ) -class Resources(OrquestraBaseModel): +class Resources(BaseModel): """ Implements: https://github.com/zapatacomputing/workflow-driver/blob/580c8d8835b1cccd085ea716c514038e85eb28d7/openapi/src/schemas/Resources.yaml. @@ -288,7 +288,7 @@ class Resources(OrquestraBaseModel): gpu: Optional[str] = pydantic.Field(pattern="^[01]+$") -class HeadNodeResources(OrquestraBaseModel): +class HeadNodeResources(BaseModel): """ Implements: https://github.com/zapatacomputing/workflow-driver/blob/ac1e97ea00fc3526c93187a1da02170bff45b74f/openapi/src/schemas/HeadNodeResources.yaml. @@ -302,7 +302,7 @@ class HeadNodeResources(OrquestraBaseModel): ) -class CreateWorkflowRunRequest(OrquestraBaseModel): +class CreateWorkflowRunRequest(BaseModel): """ Implements: https://github.com/zapatacomputing/workflow-driver/blob/ac1e97ea00fc3526c93187a1da02170bff45b74f/openapi/src/schemas/CreateWorkflowRunRequest.yaml. @@ -314,7 +314,7 @@ class CreateWorkflowRunRequest(OrquestraBaseModel): headNodeResources: Optional[HeadNodeResources] = None -class CreateWorkflowRunResponse(OrquestraBaseModel): +class CreateWorkflowRunResponse(BaseModel): """Implements: https://github.com/zapatacomputing/workflow-driver/blob/2e999a76019e8f8de8082409daddf7789dc2f430/pkg/server/server.go#L376. """ # noqa: D205, D212 @@ -322,7 +322,7 @@ class CreateWorkflowRunResponse(OrquestraBaseModel): id: WorkflowRunID -class ListWorkflowRunsRequest(OrquestraBaseModel): +class ListWorkflowRunsRequest(BaseModel): """ Implements: https://github.com/zapatacomputing/workflow-driver/blob/c52013c0f4df066159fc32ad38d489b3eaff5850/openapi/src/resources/workflow-runs.yaml#L14. @@ -342,7 +342,7 @@ class ListWorkflowRunsRequest(OrquestraBaseModel): ListWorkflowRunSummariesResponse = List[WorkflowRunSummaryResponse] -class GetWorkflowRunResponse(OrquestraBaseModel): +class GetWorkflowRunResponse(BaseModel): """ Implements: https://github.com/zapatacomputing/workflow-driver/blob/34eba4253b56266772795a8a59d6ec7edf88c65a/openapi/src/resources/workflow-run.yaml#L17. @@ -351,7 +351,7 @@ class GetWorkflowRunResponse(OrquestraBaseModel): data: WorkflowRunResponse -class TerminateWorkflowRunRequest(OrquestraBaseModel): +class TerminateWorkflowRunRequest(BaseModel): """ Implements: https://github.com/zapatacomputing/workflow-driver/blob/873437f8157226c451220306a6ce90c80e8c8f9e/openapi/src/resources/workflow-run-terminate.yaml#L12. @@ -363,7 +363,7 @@ class TerminateWorkflowRunRequest(OrquestraBaseModel): # --- Workflow Artifacts --- -class GetWorkflowRunArtifactsRequest(OrquestraBaseModel): +class GetWorkflowRunArtifactsRequest(BaseModel): """ Implements: https://github.com/zapatacomputing/workflow-driver/blob/34eba4253b56266772795a8a59d6ec7edf88c65a/openapi/src/resources/artifacts.yaml#L10. @@ -377,7 +377,7 @@ class GetWorkflowRunArtifactsRequest(OrquestraBaseModel): # --- Workflow Results --- -class GetWorkflowRunResultsRequest(OrquestraBaseModel): +class GetWorkflowRunResultsRequest(BaseModel): """ Implements: https://github.com/zapatacomputing/workflow-driver/blob/34eba4253b56266772795a8a59d6ec7edf88c65a/openapi/src/resources/run-results.yaml#L10. @@ -392,7 +392,7 @@ class GetWorkflowRunResultsRequest(OrquestraBaseModel): # --- Logs --- -class GetWorkflowRunLogsRequest(OrquestraBaseModel): +class GetWorkflowRunLogsRequest(BaseModel): """ Implements: https://github.com/zapatacomputing/workflow-driver/blob/34eba4253b56266772795a8a59d6ec7edf88c65a/openapi/src/resources/workflow-run-logs.yaml. @@ -401,7 +401,7 @@ class GetWorkflowRunLogsRequest(OrquestraBaseModel): workflowRunId: WorkflowRunID -class GetTaskRunLogsRequest(OrquestraBaseModel): +class GetTaskRunLogsRequest(BaseModel): """ Implements: https://github.com/zapatacomputing/workflow-driver/blob/c7685a579eca1f9cb3eb27e2a8c2a9757a3cd021/openapi/src/resources/task-run-logs.yaml. @@ -411,7 +411,7 @@ class GetTaskRunLogsRequest(OrquestraBaseModel): taskInvocationId: TaskInvocationID -class CommonResourceMeta(OrquestraBaseModel): +class CommonResourceMeta(BaseModel): type: str displayName: str description: str @@ -424,7 +424,7 @@ class CommonResourceMeta(OrquestraBaseModel): status: str -class ResourceIdentifier(OrquestraBaseModel): +class ResourceIdentifier(BaseModel): tenantId: str resourceGroupId: str id: str @@ -468,7 +468,7 @@ class ProjectDetail(CommonResourceMeta, ResourceIdentifier): RayFilename = NewType("RayFilename", str) -class WorkflowLogMessage(OrquestraBaseModel): +class WorkflowLogMessage(BaseModel): """Represents a single line indexed by the server side log service. Based on: @@ -516,7 +516,7 @@ class WorkflowLogEvent(NamedTuple): WorkflowLogSection = List[WorkflowLogEvent] -class TaskLogMessage(OrquestraBaseModel): +class TaskLogMessage(BaseModel): """Represents a single line indexed by the server side log service. Based on: @@ -579,7 +579,7 @@ def _missing_(cls, *args, **kwargs): return cls.UNKNOWN -class K8sEventLog(OrquestraBaseModel): +class K8sEventLog(BaseModel): """A system-level log line produced by a K8S event.""" tag: str @@ -592,7 +592,7 @@ class K8sEventLog(OrquestraBaseModel): source_type: Literal[SystemLogSourceType.K8S_EVENT] = SystemLogSourceType.K8S_EVENT -class RayHeadNodeEventLog(OrquestraBaseModel): +class RayHeadNodeEventLog(BaseModel): """A system-level log line produced by a Ray head node event.""" tag: str @@ -604,7 +604,7 @@ class RayHeadNodeEventLog(OrquestraBaseModel): ] = SystemLogSourceType.RAY_HEAD_NODE -class RayWorkerNodeEventLog(OrquestraBaseModel): +class RayWorkerNodeEventLog(BaseModel): """A system-level log line produced by a Ray head node event.""" tag: str @@ -616,7 +616,7 @@ class RayWorkerNodeEventLog(OrquestraBaseModel): ] = SystemLogSourceType.RAY_WORKER_NODE -class UnknownEventLog(OrquestraBaseModel): +class UnknownEventLog(BaseModel): """Fallback option - the event type is unknown, so display the message as a str.""" tag: str diff --git a/src/orquestra/sdk/_base/_storage/__init__.py b/src/orquestra/sdk/_base/_storage/__init__.py index 8a513ec33..7646c5e6b 100644 --- a/src/orquestra/sdk/_base/_storage/__init__.py +++ b/src/orquestra/sdk/_base/_storage/__init__.py @@ -2,8 +2,12 @@ # © Copyright 2024 Zapata Computing Inc. ################################################################################ -from ._basemodel import OrquestraBaseModel + +from .orqdantic import BaseModel, GpuResourceType, TypeAdapter, field_validator __all__ = [ - "OrquestraBaseModel", + "BaseModel", + "TypeAdapter", + "field_validator", + "GpuResourceType", ] diff --git a/src/orquestra/sdk/_base/_storage/_basemodel.py b/src/orquestra/sdk/_base/_storage/_basemodel.py deleted file mode 100644 index 503382904..000000000 --- a/src/orquestra/sdk/_base/_storage/_basemodel.py +++ /dev/null @@ -1,28 +0,0 @@ -################################################################################ -# © Copyright 2024 Zapata Computing Inc. -################################################################################ - -from typing import Any - -from pydantic.main import BaseModel - - -# TODO (ORQSDK-1025): remove the model base class -class OrquestraBaseModel(BaseModel): - """The pydantic BaseModel changed between V1 and V2. - - As a result, workflow outputs generated prior to the V2 upgrade may not be - depickled correctly. The culpret is a change in behaviour of `__setstate__`. - - This class adds a new `__setstate__` that wraps the V2 BaseModel `__setstate__` and - adds the missing behaviour back in. - """ - - def __setstate__(self, state: dict[Any, Any]) -> None: - state.setdefault("__pydantic_extra__", {}) - state.setdefault("__pydantic_private__", {}) - - if "__pydantic_fields_set__" not in state: - state["__pydantic_fields_set__"] = state.get("__fields_set__") - - super().__setstate__(state) diff --git a/src/orquestra/sdk/_base/_storage/orqdantic.py b/src/orquestra/sdk/_base/_storage/orqdantic.py new file mode 100644 index 000000000..1b1da3f9e --- /dev/null +++ b/src/orquestra/sdk/_base/_storage/orqdantic.py @@ -0,0 +1,129 @@ +################################################################################ +# © Copyright 2024 Zapata Computing Inc. +################################################################################ + +"""Compatibility layer for pydantic v1 / v2 compatibility.""" + +from copy import deepcopy +from typing import TYPE_CHECKING, Any, Dict, Optional + +import pydantic +from typing_extensions import Annotated + +PYDANTICV1 = pydantic.version.VERSION.startswith("1.") + +if PYDANTICV1 and not TYPE_CHECKING: + from pydantic.generics import GenericModel + + class BaseModel(GenericModel): + @classmethod + def model_validate(cls, *args, **kwargs): + return super(GenericModel, cls).parse_obj(*args, **kwargs) + + @classmethod + def model_validate_json(cls, *args, **kwargs): + return super(GenericModel, cls).parse_raw(*args, **kwargs) + + def model_dump(self, *args, **kwargs): + return super().dict(*args, **kwargs) + + def model_dump_json(self, *args, **kwargs): + return super().json(*args, **kwargs) + + @classmethod + def model_json_schema(cls, *args, **kwargs): + return super(GenericModel, cls).schema_json(*args, **kwargs) + + def model_copy(self, *args, **kwargs): + return super().copy(*args, **kwargs) + +else: + # TODO (ORQSDK-1025): remove the model base class and replace it with an alias to + # BaseModel + class BaseModel(pydantic.BaseModel): # type: ignore[no-redef] + """The pydantic BaseModel changed between V1 and V2. + + As a result, workflow outputs generated prior to the V2 upgrade may not be + depickled correctly. The culpret is a change in behaviour of `__setstate__`. + + This class adds a new `__setstate__` that wraps the V2 BaseModel `__setstate__` + and adds the missing behaviour back in. + """ + + def __setstate__(self, state: Dict[Any, Any]) -> None: + state.setdefault("__pydantic_extra__", {}) + state.setdefault("__pydantic_private__", {}) + + if "__pydantic_fields_set__" not in state: + state["__pydantic_fields_set__"] = state.get("__fields_set__") + + super().__setstate__(state) + + +class TypeAdapter: + """Accessor for Pydantic parsing. + + If Pydantic V2 is installed, this class is a simple wrapper for + `pydantic.TypeAdapter`. + + If Pydantic V1 is installed, this class acts as a translator between the V1-specific + `parse_X_as` methods and the V2 TypeAdapter style syntax we use in our code. + """ + + def __init__(self, model, *args, **kwargs): + if PYDANTICV1: + self._model = model + else: + self._typeadapter = pydantic.TypeAdapter(model, *args, **kwargs) + + def validate_python(self, value, *args, **kwargs): + if PYDANTICV1: + return pydantic.parse_obj_as(self._model, value) + else: + return self._typeadapter.validate_python(value, *args, **kwargs) + + def validate_json(self, value, *args, **kwargs): + if PYDANTICV1: + return pydantic.parse_raw_as( + self._model, value + ) # type: ignore[reportCallIssue,operator] + else: + return self._typeadapter.validate_json(value, *args, **kwargs) + + +def field_validator(*fields, **kwargs): + """Wrapper for pydantic field validators. + + If Pydantic V2 is installed, this operates as a simple wrapper for + `pydantic.field_validator`. + + If Pydantic V1 is installed, this operates as a wrapper for `pydantic.validator` and + _tries_ to translate V2-style kwargs. There are not perfect analogues, so this is + likely to cause problems if we add more validators. + """ + if PYDANTICV1: + + def translate_kwargs(kwargs: dict) -> dict: + _kwargs = deepcopy(kwargs) + if "mode" in _kwargs: + if _kwargs["mode"] == "before": + _kwargs["pre"] = True + elif _kwargs["mode"] == "after": + _kwargs["always"] = True + _kwargs.pop("mode") + return _kwargs + + return pydantic.validator(*fields, **translate_kwargs(kwargs)) + else: + return pydantic.field_validator(*fields, **kwargs) + + +if TYPE_CHECKING: + GpuResourceType = Optional[str] +else: + if PYDANTICV1: + GpuResourceType = Optional[str] + else: + GpuResourceType = Optional[ + Annotated[str, pydantic.BeforeValidator(lambda x: str(x))] + ] diff --git a/src/orquestra/sdk/_base/serde.py b/src/orquestra/sdk/_base/serde.py index 16e7eb8ca..6ce8857f9 100644 --- a/src/orquestra/sdk/_base/serde.py +++ b/src/orquestra/sdk/_base/serde.py @@ -10,10 +10,11 @@ from pathlib import Path import cloudpickle # type: ignore -import pydantic from orquestra.sdk.schema import ir, responses +from .._base._storage import TypeAdapter + CHUNK_SIZE = 40_000 ENCODING = "base64" PICKLE_PROTOCOL = 4 @@ -158,17 +159,16 @@ def result_from_artifact( def value_from_result_dict(result_dict: t.Mapping) -> t.Any: result = t.cast( responses.WorkflowResult, - pydantic.TypeAdapter(responses.WorkflowResult).validate_python(result_dict), + TypeAdapter(responses.WorkflowResult).validate_python(result_dict), ) + return deserialize(result) def deserialize_constant(node: ir.ConstantNode): - return deserialize( - pydantic.TypeAdapter(responses.WorkflowResult).validate_python( - node.model_dump() - ) - ) + constant = TypeAdapter(responses.WorkflowResult).validate_python(node.model_dump()) + + return deserialize(constant) def stringify_package_spec(package: ir.PackageSpec) -> str: diff --git a/src/orquestra/sdk/_ray/_wf_metadata.py b/src/orquestra/sdk/_ray/_wf_metadata.py index 44aa8dfea..f7ed29bd8 100644 --- a/src/orquestra/sdk/_ray/_wf_metadata.py +++ b/src/orquestra/sdk/_ray/_wf_metadata.py @@ -5,11 +5,11 @@ import json import typing as t -from .._base._storage import OrquestraBaseModel +from .._base._storage import BaseModel from ..schema import ir, workflow_run -class WfUserMetadata(OrquestraBaseModel): +class WfUserMetadata(BaseModel): """Information about a workflow run we store as a Ray metadata dict. Pydantic helps us check that the thing we read from Ray is indeed a dictionary we @@ -20,7 +20,7 @@ class WfUserMetadata(OrquestraBaseModel): workflow_def: ir.WorkflowDef -class InvUserMetadata(OrquestraBaseModel): +class InvUserMetadata(BaseModel): """Information about a task invocation we store as a Ray metadata dict. Pydantic helps us check that the thing we read from Ray is indeed a dictionary we diff --git a/src/orquestra/sdk/schema/configs.py b/src/orquestra/sdk/schema/configs.py index 08cbd3230..305d9a808 100644 --- a/src/orquestra/sdk/schema/configs.py +++ b/src/orquestra/sdk/schema/configs.py @@ -4,7 +4,7 @@ from enum import Enum from typing import Any, Dict, Literal -from .._base._storage import OrquestraBaseModel +from .._base._storage import BaseModel CONFIG_FILE_CURRENT_VERSION = "0.0.2" @@ -24,7 +24,7 @@ def __format__(self, format_spec: str) -> str: RemoteRuntime = Literal[RuntimeName.CE_REMOTE] -class RuntimeConfiguration(OrquestraBaseModel): +class RuntimeConfiguration(BaseModel): config_name: ConfigName runtime_name: RuntimeName runtime_options: Dict[str, Any] = {} @@ -40,7 +40,7 @@ def __str__(self): return outstr -class RuntimeConfigurationFile(OrquestraBaseModel): +class RuntimeConfigurationFile(BaseModel): """This schema is for the storage of "Runtime configurations". The major version number should be bumped when: diff --git a/src/orquestra/sdk/schema/ir.py b/src/orquestra/sdk/schema/ir.py index f943d8619..434ba309a 100644 --- a/src/orquestra/sdk/schema/ir.py +++ b/src/orquestra/sdk/schema/ir.py @@ -12,15 +12,15 @@ import warnings import pydantic -from pydantic import BeforeValidator +from typing_extensions import Annotated -from .._base._storage import OrquestraBaseModel +from .._base._storage import BaseModel, GpuResourceType, field_validator ImportId = str SecretNodeId = str -class SecretNode(OrquestraBaseModel): +class SecretNode(BaseModel): """A reference to a secret stored in an external secret/config service.""" # Workflow-scope unique ID used to refer from task invocations @@ -39,7 +39,7 @@ class SecretNode(OrquestraBaseModel): workspace_id: t.Optional[str] = None -class GitURL(OrquestraBaseModel): +class GitURL(BaseModel): original_url: str protocol: str user: t.Optional[str] = None @@ -50,7 +50,7 @@ class GitURL(OrquestraBaseModel): query: t.Optional[str] = None -class GitImport(OrquestraBaseModel): +class GitImport(BaseModel): id: ImportId repo_url: GitURL git_ref: str @@ -58,7 +58,7 @@ class GitImport(OrquestraBaseModel): # we need this in the JSON to know which class to use when deserializing type: t.Literal["GIT_IMPORT"] = "GIT_IMPORT" - @pydantic.field_validator("repo_url", mode="before") + @field_validator("repo_url", mode="before") def _backwards_compatible_repo_url(cls, v): """Allows older models with a string URL to be imported.""" # Prevent circular imports @@ -70,7 +70,7 @@ def _backwards_compatible_repo_url(cls, v): return parse_git_url(v) -class LocalImport(OrquestraBaseModel): +class LocalImport(BaseModel): """Used to specify that the source code is only available locally. (e.g. not committed to any git repo). @@ -82,12 +82,12 @@ class LocalImport(OrquestraBaseModel): type: t.Literal["LOCAL_IMPORT"] = "LOCAL_IMPORT" -class InlineImport(OrquestraBaseModel): +class InlineImport(BaseModel): id: ImportId type: t.Literal["INLINE_IMPORT"] = "INLINE_IMPORT" -class PackageSpec(OrquestraBaseModel): +class PackageSpec(BaseModel): # noqa E501 """Representation of single package import. @@ -103,7 +103,7 @@ class PackageSpec(OrquestraBaseModel): environment_markers: str -class PythonImports(OrquestraBaseModel): +class PythonImports(BaseModel): """List of imports for given task.""" id: ImportId @@ -123,7 +123,7 @@ class PythonImports(OrquestraBaseModel): TaskDefId = str -class ModuleFunctionRef(OrquestraBaseModel): +class ModuleFunctionRef(BaseModel): # Required to dereference function for execution. module: str function_name: str @@ -136,7 +136,7 @@ class ModuleFunctionRef(OrquestraBaseModel): type: t.Literal["MODULE_FUNCTION_REF"] = "MODULE_FUNCTION_REF" -class FileFunctionRef(OrquestraBaseModel): +class FileFunctionRef(BaseModel): # Required to dereference function for execution. file_path: str function_name: str @@ -149,7 +149,7 @@ class FileFunctionRef(OrquestraBaseModel): type: t.Literal["FILE_FUNCTION_REF"] = "FILE_FUNCTION_REF" -class InlineFunctionRef(OrquestraBaseModel): +class InlineFunctionRef(BaseModel): function_name: str # Required to dereference function for execution. The function object is serialized # using `dill`, base64-encoded, and chunked to workaround JSON string length limits. @@ -162,18 +162,19 @@ class InlineFunctionRef(OrquestraBaseModel): FunctionRef = t.Union[ModuleFunctionRef, FileFunctionRef, InlineFunctionRef] -class Resources(OrquestraBaseModel): +class Resources(BaseModel): cpu: t.Optional[str] = None memory: t.Optional[str] = None disk: t.Optional[str] = None - gpu: t.Optional[t.Annotated[str, BeforeValidator(lambda x: str(x))]] = None + gpu: GpuResourceType = None + # nodes should be a positive integer representing the number of nodes assigned # to a workflow. If None, the runtime will choose. # This only applies to workflows and not tasks. nodes: t.Optional[int] = None -class DataAggregation(OrquestraBaseModel): +class DataAggregation(BaseModel): run: t.Optional[bool] = None resources: t.Optional[Resources] = None @@ -190,14 +191,14 @@ class ParameterKind(str, enum.Enum): VAR_KEYWORD = "VAR_KEYWORD" -class TaskParameter(OrquestraBaseModel): +class TaskParameter(BaseModel): name: ParameterName kind: ParameterKind # If we need more metadata related to parameters, like type hints or default values, # it should be added here. -class TaskOutputMetadata(OrquestraBaseModel): +class TaskOutputMetadata(BaseModel): """Information about the data shape returned by a task function.""" # If yes, it's possible to unpack the output in the workflow like: @@ -211,7 +212,7 @@ class TaskOutputMetadata(OrquestraBaseModel): n_outputs: int -class TaskDef(OrquestraBaseModel): +class TaskDef(BaseModel): # workflow-unique ID used to refer from task invocations id: TaskDefId @@ -269,7 +270,7 @@ class ArtifactFormat(str, enum.Enum): ConstantNodeId = str -class ArtifactNode(OrquestraBaseModel): +class ArtifactNode(BaseModel): # Workflow-scope unique ID used to refer from task invocations. If the task has # multiple outputs they will have distinct `id`s. id: ArtifactNodeId @@ -294,7 +295,7 @@ class ArtifactNode(OrquestraBaseModel): artifact_index: t.Optional[int] = None -class ConstantNodeJSON(OrquestraBaseModel): +class ConstantNodeJSON(BaseModel): """Piece of data that already exists at workflow submission time. The value is directly embedded in the workflow. To support arbitrary data shapes we @@ -313,7 +314,7 @@ class ConstantNodeJSON(OrquestraBaseModel): value_preview: pydantic.constr(max_length=12) # type: ignore -class ConstantNodePickle(OrquestraBaseModel): +class ConstantNodePickle(BaseModel): """Piece of data that already exists at workflow submission time. The value is directly embedded in the workflow. To support arbitrary data shapes we @@ -341,7 +342,7 @@ class ConstantNodePickle(OrquestraBaseModel): ArgumentId = t.Union[ArtifactNodeId, ConstantNodeId, SecretNodeId] -class TaskInvocation(OrquestraBaseModel): +class TaskInvocation(BaseModel): id: TaskInvocationId # What task should be executed. @@ -367,7 +368,7 @@ class TaskInvocation(OrquestraBaseModel): WorkflowDefName = str -class Version(OrquestraBaseModel): +class Version(BaseModel): original: str major: int minor: int @@ -375,12 +376,12 @@ class Version(OrquestraBaseModel): is_prerelease: bool -class WorkflowMetadata(OrquestraBaseModel): +class WorkflowMetadata(BaseModel): sdk_version: Version python_version: Version -class WorkflowDef(OrquestraBaseModel): +class WorkflowDef(BaseModel): """The main data structure for intermediate workflow representation. The structure is as flat as possible with relation based on "id"s, e.g. a single @@ -413,7 +414,7 @@ class WorkflowDef(OrquestraBaseModel): data_aggregation: t.Optional[DataAggregation] = None # Metadata defaults to None to allow older JSON to be loaded - metadata: t.Annotated[ + metadata: Annotated[ t.Optional[WorkflowMetadata], pydantic.Field(validate_default=True) ] = None @@ -421,7 +422,7 @@ class WorkflowDef(OrquestraBaseModel): # If none, the runtime will decide. resources: t.Optional[Resources] = None - @pydantic.field_validator("metadata", mode="after") + @field_validator("metadata", mode="after") def sdk_version_up_to_date(cls, v: t.Optional[WorkflowMetadata]): # Workaround for circular imports from orquestra.sdk import exceptions diff --git a/src/orquestra/sdk/schema/responses.py b/src/orquestra/sdk/schema/responses.py index 2581f0ec0..b76bcc84d 100644 --- a/src/orquestra/sdk/schema/responses.py +++ b/src/orquestra/sdk/schema/responses.py @@ -13,7 +13,7 @@ from pydantic import Field from typing_extensions import Annotated -from .._base._storage import OrquestraBaseModel +from .._base._storage import BaseModel from .ir import ArtifactFormat @@ -43,19 +43,19 @@ class ResponseStatusCode(enum.Enum): USER_CANCELLED = 15 -class ResponseMetadata(OrquestraBaseModel): +class ResponseMetadata(BaseModel): success: bool code: ResponseStatusCode message: str -class JSONResult(OrquestraBaseModel): +class JSONResult(BaseModel): # Output value dumped to a flat JSON string. value: str serialization_format: t.Literal[ArtifactFormat.JSON] = ArtifactFormat.JSON -class PickleResult(OrquestraBaseModel): +class PickleResult(BaseModel): # Output value dumped to a pickle byte string, encoded as base64, and split into # chunks. Chunking is required because some JSON parsers have limitation on max # string field length. @@ -70,12 +70,12 @@ class PickleResult(OrquestraBaseModel): ] -class ComputeEngineWorkflowResult(OrquestraBaseModel): +class ComputeEngineWorkflowResult(BaseModel): results: t.Tuple[WorkflowResult, ...] type: t.Literal["ComputeEngineWorkflowResult"] = "ComputeEngineWorkflowResult" -class ServiceResponse(OrquestraBaseModel): +class ServiceResponse(BaseModel): name: str is_running: bool info: t.Optional[str] diff --git a/src/orquestra/sdk/schema/workflow_run.py b/src/orquestra/sdk/schema/workflow_run.py index a1982c224..8f0848c2f 100644 --- a/src/orquestra/sdk/schema/workflow_run.py +++ b/src/orquestra/sdk/schema/workflow_run.py @@ -14,7 +14,7 @@ from orquestra.sdk._base._dates import Instant from orquestra.sdk.schema.ir import TaskInvocationId, WorkflowDef -from .._base._storage import OrquestraBaseModel +from .._base._storage import BaseModel WorkflowRunId = str TaskRunId = str @@ -55,20 +55,20 @@ def is_completed(self) -> bool: ) -class RunStatus(OrquestraBaseModel): +class RunStatus(BaseModel): state: State start_time: t.Optional[Instant] end_time: t.Optional[Instant] -class TaskRun(OrquestraBaseModel): +class TaskRun(BaseModel): id: TaskRunId invocation_id: TaskInvocationId status: RunStatus message: t.Optional[str] = None -class WorkflowRunOnlyID(OrquestraBaseModel): +class WorkflowRunOnlyID(BaseModel): """A WorkflowRun that only contains the ID.""" id: WorkflowRunId diff --git a/src/orquestra/sdk/secrets/_models.py b/src/orquestra/sdk/secrets/_models.py index e129e5870..976c053e2 100644 --- a/src/orquestra/sdk/secrets/_models.py +++ b/src/orquestra/sdk/secrets/_models.py @@ -7,7 +7,7 @@ """ from typing import Optional -from .._base._storage import OrquestraBaseModel +from .._base._storage import BaseModel SecretName = str SecretValue = str @@ -15,7 +15,7 @@ WorkspaceId = str -class SecretNameObj(OrquestraBaseModel): +class SecretNameObj(BaseModel): """ Model for: https://github.com/zapatacomputing/config-service/blob/3f275a52149fb2b74c6a8c01726cce4f390a1533/openapi/src/schemas/SecretName.yaml. @@ -27,7 +27,7 @@ class SecretNameObj(OrquestraBaseModel): name: SecretName -class SecretValueObj(OrquestraBaseModel): +class SecretValueObj(BaseModel): """ Model for: https://github.com/zapatacomputing/config-service/blob/3f275a52149fb2b74c6a8c01726cce4f390a1533/openapi/src/schemas/SecretValue.yaml. @@ -39,7 +39,7 @@ class SecretValueObj(OrquestraBaseModel): value: SecretValue -class SecretDefinition(OrquestraBaseModel): +class SecretDefinition(BaseModel): """ Model for: https://github.com/zapatacomputing/config-service/blob/3f275a52149fb2b74c6a8c01726cce4f390a1533/openapi/src/schemas/SecretDefinition.yaml. @@ -50,7 +50,7 @@ class SecretDefinition(OrquestraBaseModel): resourceGroup: Optional[ResourceGroup] = None -class ListSecretsRequest(OrquestraBaseModel): +class ListSecretsRequest(BaseModel): """ Model for: https://github.com/zapatacomputing/config-service/blob/fbfc4627450bc9a460278b242738e55210e7bf03/openapi/src/parameters/query/workspace.yaml.