Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1233,6 +1233,10 @@ def generate_attribute_list_from_dataclass_json_mixin(
defs = schema.get("$defs", schema.get("definitions", {}))
if ref_name in defs:
ref_schema = defs[ref_name].copy()
# Check if the $ref points to an enum definition (no properties)
if ref_schema.get("enum"):
attribute_list.append((property_key, str))
continue
# Include $defs so nested models can resolve their own $refs
if "$defs" not in ref_schema and defs:
ref_schema["$defs"] = defs
Expand Down
45 changes: 45 additions & 0 deletions tests/flytekit/unit/core/test_enum_in_pydantic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import dataclasses
from enum import Enum

from pydantic import BaseModel

from flytekit.core.context_manager import FlyteContext
from flytekit.core.type_engine import TypeEngine


class Status(str, Enum):
PENDING = "pending"
APPROVED = "approved"
REJECTED = "rejected"


class Job(BaseModel):
name: str
status: Status


def test_pydantic_model_with_enum_ref():
"""Test that a Pydantic model with an enum field (which produces a $ref in
the JSON schema) can be round-tripped through the type engine and that
guess_python_type reconstructs a valid dataclass."""
ctx = FlyteContext.current_context()
input = Job(name="test-job", status=Status.PENDING)

lt = TypeEngine.to_literal_type(Job)
lv = TypeEngine.to_literal(ctx, input, Job, lt)

assert lt
assert lv

# Roundtrip via the real Pydantic model
pv = TypeEngine.to_python_value(ctx, lv, Job)
assert pv == input

# Guess python type from the schema (simulates pyflyte run behaviour)
guessed = TypeEngine.guess_python_type(lt)
assert dataclasses.is_dataclass(guessed)

# The enum field should be reconstructed as str (enum $ref resolved)
v = guessed(name="test-job", status="pending")
assert v.name == "test-job"
assert v.status == "pending"
Loading