Skip to content

Commit

Permalink
Union/enum handling (#2845) (#2851)
Browse files Browse the repository at this point in the history
backport of #2845
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored Oct 22, 2024
1 parent a47dbb6 commit 83474c6
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 8 deletions.
36 changes: 28 additions & 8 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,10 @@ def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T):
else:
for f in dataclasses.fields(type(v)): # type: ignore
original_type = f.type
if f.name not in expected_fields_dict:
raise TypeTransformerFailedError(
f"Field '{f.name}' is not present in the expected dataclass fields {expected_type.__name__}"
)
expected_type = expected_fields_dict[f.name]

if UnionTransformer.is_optional_type(original_type):
Expand Down Expand Up @@ -796,7 +800,7 @@ def to_literal(
if type(python_val).__class__ != enum.EnumMeta:
raise TypeTransformerFailedError("Expected an enum")
if type(python_val.value) != str:
raise TypeTransformerFailedError("Only string-valued enums are supportedd")
raise TypeTransformerFailedError("Only string-valued enums are supported")

return Literal(scalar=Scalar(primitive=Primitive(string_value=python_val.value))) # type: ignore

Expand All @@ -808,6 +812,18 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[enum.Enum]:
return enum.Enum("DynamicEnum", {f"{i}": i for i in literal_type.enum_type.values}) # type: ignore
raise ValueError(f"Enum transformer cannot reverse {literal_type}")

def assert_type(self, t: Type[enum.Enum], v: T):
if sys.version_info < (3, 10):
if not isinstance(v, enum.Enum):
raise TypeTransformerFailedError(f"Value {v} needs to be an Enum in 3.9")
if not isinstance(v, t):
raise TypeTransformerFailedError(f"Value {v} is not in Enum {t}")
return

val = v.value if isinstance(v, enum.Enum) else v
if val not in t:
raise TypeTransformerFailedError(f"Value {v} is not in Enum {t}")


def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name: typing.Any):
attribute_list = []
Expand Down Expand Up @@ -1193,7 +1209,7 @@ def literal_map_to_kwargs(
raise ValueError("At least one of python_types or literal_types must be provided")

if literal_types:
python_interface_inputs = {
python_interface_inputs: dict[str, Type[T]] = {
name: TypeEngine.guess_python_type(lt.type) for name, lt in literal_types.items()
}
else:
Expand Down Expand Up @@ -1272,7 +1288,7 @@ def guess_python_types(
return python_types

@classmethod
def guess_python_type(cls, flyte_type: LiteralType) -> type:
def guess_python_type(cls, flyte_type: LiteralType) -> Type[T]:
"""
Transforms a flyte-specific ``LiteralType`` to a regular python value.
"""
Expand Down Expand Up @@ -1542,13 +1558,17 @@ def assert_type(self, t: Type[T], v: T):
# this is an edge case
return
try:
super().assert_type(sub_type, v)
return
sub_trans: TypeTransformer = TypeEngine.get_transformer(sub_type)
if sub_trans.type_assertions_enabled:
sub_trans.assert_type(sub_type, v)
return
else:
return
except TypeTransformerFailedError:
continue
except TypeError:
continue
raise TypeTransformerFailedError(f"Value {v} is not of type {t}")
else:
super().assert_type(t, v)

def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]:
t = get_underlying_type(t)
Expand Down Expand Up @@ -1806,7 +1826,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:

def guess_python_type(self, literal_type: LiteralType) -> Union[Type[dict], typing.Dict[Type, Type]]:
if literal_type.map_value_type:
mt = TypeEngine.guess_python_type(literal_type.map_value_type)
mt: Type = TypeEngine.guess_python_type(literal_type.map_value_type)
return typing.Dict[str, mt] # type: ignore

if literal_type.simple == SimpleType.STRUCT:
Expand Down
93 changes: 93 additions & 0 deletions tests/flytekit/unit/core/test_unions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import typing
from dataclasses import dataclass
from enum import Enum
import sys
import pytest

from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformerFailedError


def test_asserting():
@dataclass
class A:
a: str = None

@dataclass
class B:
b: str = None

@dataclass
class C:
c: str = None

ctx = FlyteContextManager.current_context()

pt = typing.Union[A, B, str]
lt = TypeEngine.to_literal_type(pt)
# mimic a register/remote fetch
guessed = TypeEngine.guess_python_type(lt)

TypeEngine.to_literal(ctx, A("a"), guessed, lt)
TypeEngine.to_literal(ctx, B(b="bb"), guessed, lt)
TypeEngine.to_literal(ctx, "hello", guessed, lt)

with pytest.raises(TypeTransformerFailedError):
TypeEngine.to_literal(ctx, C("cc"), guessed, lt)

with pytest.raises(TypeTransformerFailedError):
TypeEngine.to_literal(ctx, 3, guessed, lt)


@pytest.mark.skipif(
sys.version_info < (3, 10), reason="enum checking only works in 3.10+"
)
def test_asserting_enum():
class Color(Enum):
RED = "one"
GREEN = "two"
BLUE = "blue"

lt = TypeEngine.to_literal_type(Color)
guessed = TypeEngine.guess_python_type(lt)
tf = TypeEngine.get_transformer(guessed)
tf.assert_type(guessed, "one")
tf.assert_type(guessed, guessed("two"))
tf.assert_type(Color, "one")

guessed2 = TypeEngine.guess_python_type(lt)
tf.assert_type(guessed, guessed2("two"))


@pytest.mark.skipif(
sys.version_info >= (3, 10), reason="3.9 enum testing"
)
def test_asserting_enum_39():
class Color(Enum):
RED = "one"
GREEN = "two"
BLUE = "blue"

lt = TypeEngine.to_literal_type(Color)
guessed = TypeEngine.guess_python_type(lt)
tf = TypeEngine.get_transformer(guessed)
tf.assert_type(guessed, guessed("two"))
tf.assert_type(Color, Color.GREEN)


@pytest.mark.sandbox_test
def test_with_remote():
from flytekit.remote.remote import FlyteRemote
from typing_extensions import Annotated, get_args
from flytekit.configuration import Config, Image, ImageConfig, SerializationSettings

r = FlyteRemote(
Config.auto(config_file="/Users/ytong/.flyte/config-sandbox.yaml"),
default_project="flytesnacks",
default_domain="development",
)
lp = r.fetch_launch_plan(name="yt_dbg.scratchpad.union_enums.wf", version="oppOd5jst-LWExhTLM0F2w")
guessed_union_type = TypeEngine.guess_python_type(lp.interface.inputs["x"].type)
guessed_enum = get_args(guessed_union_type)[0]
val = guessed_enum("one")
r.execute(lp, inputs={"x": val})

0 comments on commit 83474c6

Please sign in to comment.