Skip to content

Commit df94e1c

Browse files
mg515eapolinario
andauthored
Add flytekit-omegaconf plugin (flyteorg#2299)
* add flytekit-hydra Signed-off-by: mg515 <miha.garafolj@gmail.com> * fix small typo readme Signed-off-by: mg515 <miha.garafolj@gmail.com> * ruff ruff Signed-off-by: mg515 <miha.garafolj@gmail.com> * lint more Signed-off-by: mg515 <miha.garafolj@gmail.com> * rename plugin into flytekit-omegaconf Signed-off-by: mg515 <miha.garafolj@gmail.com> * lint sort imports Signed-off-by: mg515 <miha.garafolj@gmail.com> * use flytekit logger Signed-off-by: mg515 <miha.garafolj@gmail.com> * use flytekit logger #2 Signed-off-by: mg515 <miha.garafolj@gmail.com> * fix typing info in is_flatable Signed-off-by: mg515 <miha.garafolj@gmail.com> * use default_factory instead of mutable default value Signed-off-by: mg515 <miha.garafolj@gmail.com> * add python3.11 and python3.12 to setup.py Signed-off-by: mg515 <miha.garafolj@gmail.com> * make fmt Signed-off-by: mg515 <miha.garafolj@gmail.com> * define error message only once Signed-off-by: mg515 <miha.garafolj@gmail.com> * add docstring Signed-off-by: mg515 <miha.garafolj@gmail.com> * remove GenericEnumTransformer and tests Signed-off-by: mg515 <miha.garafolj@gmail.com> * fallback to TypeEngine.get_transformer(node_type) to find suitable transformer Signed-off-by: mg515 <miha.garafolj@gmail.com> * explicit valueerrors instead of asserts Signed-off-by: mg515 <miha.garafolj@gmail.com> * minor style improvements Signed-off-by: mg515 <miha.garafolj@gmail.com> * remove obsolete warnings Signed-off-by: mg515 <miha.garafolj@gmail.com> * import flytekit logger instead of instantiating our own Signed-off-by: mg515 <miha.garafolj@gmail.com> * docstrings in reST format Signed-off-by: mg515 <miha.garafolj@gmail.com> * refactor transformer mode Signed-off-by: mg515 <miha.garafolj@gmail.com> * improve docs Signed-off-by: mg515 <miha.garafolj@gmail.com> * refactor dictconfig class into smaller methods Signed-off-by: mg515 <miha.garafolj@gmail.com> * add unit tests for dictconfig transformer Signed-off-by: mg515 <miha.garafolj@gmail.com> * refactor of parse_type_description() Signed-off-by: mg515 <miha.garafolj@gmail.com> * add omegaconf plugin to pythonbuild.yaml --------- Signed-off-by: mg515 <miha.garafolj@gmail.com> Signed-off-by: Eduardo Apolinario <eapolinario@users.noreply.github.com> Co-authored-by: Eduardo Apolinario <eapolinario@users.noreply.github.com>
1 parent 3549597 commit df94e1c

14 files changed

Lines changed: 981 additions & 0 deletions

File tree

.github/workflows/pythonbuild.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ jobs:
346346
# onnx-tensorflow needs a version of tensorflow that does not work with protobuf>4.
347347
# The issue is being tracked on the tensorflow side in https://github.com/tensorflow/tensorflow/issues/53234#issuecomment-1330111693
348348
# flytekit-onnx-tensorflow
349+
- flytekit-omegaconf
349350
- flytekit-openai
350351
- flytekit-pandera
351352
- flytekit-papermill
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Flytekit OmegaConf Plugin
2+
3+
Flytekit python natively supports serialization of many data types for exchanging information between tasks.
4+
The Flytekit OmegaConf Plugin extends these by the `DictConfig` type from the
5+
[OmegaConf package](https://omegaconf.readthedocs.io/) as well as related types
6+
that are being used by the [hydra package](https://hydra.cc/) for configuration management.
7+
8+
## Task example
9+
```
10+
from dataclasses import dataclass
11+
import flytekitplugins.omegaconf # noqa F401
12+
from flytekit import task, workflow
13+
from omegaconf import DictConfig
14+
15+
@dataclass
16+
class MySimpleConf:
17+
_target_: str = "lightning_module.MyEncoderModule"
18+
learning_rate: float = 0.0001
19+
20+
@task
21+
def my_task(cfg: DictConfig) -> None:
22+
print(f"Doing things with {cfg.learning_rate=}")
23+
24+
25+
@workflow
26+
def pipeline(cfg: DictConfig) -> None:
27+
my_task(cfg=cfg)
28+
29+
30+
if __name__ == "__main__":
31+
from omegaconf import OmegaConf
32+
33+
cfg = OmegaConf.structured(MySimpleConf)
34+
pipeline(cfg=cfg)
35+
```
36+
37+
## Transformer configuration
38+
39+
The transformer can be set to one of three modes:
40+
41+
`Dataclass` - This mode should be used with a StructuredConfig and will reconstruct the config from the matching dataclass
42+
during deserialisation in order to make typing information from the dataclass and continued validation thereof available.
43+
This requires the dataclass definition to be available via python import in the Flyte execution environment in which
44+
objects are (de-)serialised.
45+
46+
`DictConfig` - This mode will deserialize the config into a DictConfig object. In particular, dataclasses are translated
47+
into DictConfig objects and only primitive types are being checked. The definition of underlying dataclasses for
48+
structured configs is only required during the initial serialization for this mode.
49+
50+
`Auto` - This mode will try to deserialize according to the Dataclass mode and fall back to the DictConfig mode if the
51+
dataclass definition is not available. This is the default mode.
52+
53+
You can set the transformer mode globally or for the current context only the following ways:
54+
```python
55+
from flytekitplugins.omegaconf import set_transformer_mode, set_local_transformer_mode, OmegaConfTransformerMode
56+
57+
# Set the global transformer mode using the new function
58+
set_transformer_mode(OmegaConfTransformerMode.DictConfig)
59+
60+
# You can also the mode for the current context only
61+
with set_local_transformer_mode(OmegaConfTransformerMode.Dataclass):
62+
# This will use the Dataclass mode
63+
pass
64+
```
65+
66+
```note
67+
Since the DictConfig is flattened and keys transformed into dot notation, the keys of the DictConfig must not contain
68+
dots.
69+
```
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from contextlib import contextmanager
2+
3+
from flytekitplugins.omegaconf.config import OmegaConfTransformerMode
4+
from flytekitplugins.omegaconf.dictconfig_transformer import DictConfigTransformer # noqa: F401
5+
from flytekitplugins.omegaconf.listconfig_transformer import ListConfigTransformer # noqa: F401
6+
7+
_TRANSFORMER_MODE = OmegaConfTransformerMode.Auto
8+
9+
10+
def set_transformer_mode(mode: OmegaConfTransformerMode) -> None:
11+
"""Set the global serialization mode for OmegaConf objects."""
12+
global _TRANSFORMER_MODE
13+
_TRANSFORMER_MODE = mode
14+
15+
16+
def get_transformer_mode() -> OmegaConfTransformerMode:
17+
"""Get the global serialization mode for OmegaConf objects."""
18+
return _TRANSFORMER_MODE
19+
20+
21+
@contextmanager
22+
def local_transformer_mode(mode: OmegaConfTransformerMode):
23+
"""Context manager to set a local serialization mode for OmegaConf objects."""
24+
global _TRANSFORMER_MODE
25+
previous_mode = _TRANSFORMER_MODE
26+
set_transformer_mode(mode)
27+
try:
28+
yield
29+
finally:
30+
set_transformer_mode(previous_mode)
31+
32+
33+
__all__ = ["set_transformer_mode", "get_transformer_mode", "local_transformer_mode", "OmegaConfTransformerMode"]
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from enum import Enum
2+
3+
4+
class OmegaConfTransformerMode(Enum):
5+
"""
6+
Operation Mode indicating whether a (potentially unannotated) DictConfig object or a structured config using the
7+
underlying dataclass is returned.
8+
9+
Note: We define a single shared config across all transformers as recursive calls should refer to the same config
10+
Note: The latter requires the use of structured configs.
11+
"""
12+
13+
DictConfig = "DictConfig"
14+
DataClass = "DataClass"
15+
Auto = "Auto"
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
import importlib
2+
import re
3+
from typing import Any, Dict, Type, TypeVar
4+
5+
import flatten_dict
6+
import flytekitplugins.omegaconf
7+
from flyteidl.core.literals_pb2 import Literal as PB_Literal
8+
from flytekitplugins.omegaconf.config import OmegaConfTransformerMode
9+
from flytekitplugins.omegaconf.type_information import extract_node_type
10+
from google.protobuf.json_format import MessageToDict, ParseDict
11+
from google.protobuf.struct_pb2 import Struct
12+
13+
import omegaconf
14+
from flytekit import FlyteContext
15+
from flytekit.core.type_engine import TypeTransformerFailedError
16+
from flytekit.extend import TypeEngine, TypeTransformer
17+
from flytekit.loggers import logger
18+
from flytekit.models.literals import Literal, Scalar
19+
from flytekit.models.types import LiteralType, SimpleType
20+
from omegaconf import DictConfig, OmegaConf
21+
22+
T = TypeVar("T")
23+
NoneType = type(None)
24+
25+
26+
class DictConfigTransformer(TypeTransformer[DictConfig]):
27+
def __init__(self):
28+
"""Construct DictConfigTransformer."""
29+
super().__init__(name="OmegaConf DictConfig", t=DictConfig)
30+
31+
def get_literal_type(self, t: Type[DictConfig]) -> LiteralType:
32+
"""
33+
Provide type hint for Flytekit type system.
34+
35+
To support the multivariate typing of nodes in a DictConfig, we encode them as binaries (no introspection)
36+
with multiple files.
37+
"""
38+
return LiteralType(simple=SimpleType.STRUCT)
39+
40+
def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
41+
"""Convert from given python type object ``DictConfig`` to the Literal representation."""
42+
check_if_valid_dictconfig(python_val)
43+
44+
base_config = OmegaConf.get_type(python_val)
45+
type_map, value_map = extract_type_and_value_maps(ctx, python_val)
46+
wrapper = create_struct(type_map, value_map, base_config)
47+
48+
return Literal(scalar=Scalar(generic=wrapper))
49+
50+
def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[DictConfig]) -> DictConfig:
51+
"""Re-hydrate the custom object from Flyte Literal value."""
52+
if lv and lv.scalar is not None:
53+
nested_dict = flatten_dict.unflatten(MessageToDict(lv.scalar.generic), splitter="dot")
54+
cfg_dict = {}
55+
for key, type_desc in nested_dict["types"].items():
56+
cfg_dict[key] = parse_node_value(ctx, key, type_desc, nested_dict)
57+
58+
return handle_base_dataclass(ctx, nested_dict, cfg_dict)
59+
raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")
60+
61+
62+
def is_flattenable(d: DictConfig) -> bool:
63+
"""Check if a DictConfig can be properly flattened and unflattened, i.e. keys do not contain dots."""
64+
return all(
65+
isinstance(k, str) # keys are strings ...
66+
and "." not in k # ... and do not contain dots
67+
and (
68+
OmegaConf.is_missing(d, k) # values are either MISSING ...
69+
or not isinstance(d[k], DictConfig) # ... not nested Dictionaries ...
70+
or is_flattenable(d[k])
71+
) # or flattenable themselves
72+
for k in d.keys()
73+
)
74+
75+
76+
def check_if_valid_dictconfig(python_val: DictConfig) -> None:
77+
"""Validate the DictConfig to ensure it's serializable."""
78+
if not isinstance(python_val, DictConfig):
79+
raise ValueError(f"Invalid type {type(python_val)}, can only serialize DictConfigs")
80+
if not is_flattenable(python_val):
81+
raise ValueError(f"{python_val} cannot be flattened as it contains non-string keys or keys containing dots.")
82+
83+
84+
def extract_type_and_value_maps(ctx: FlyteContext, python_val: DictConfig) -> (Dict[str, str], Dict[str, Any]):
85+
"""Extract type and value maps from the DictConfig."""
86+
type_map = {}
87+
value_map = {}
88+
for key in python_val.keys():
89+
if OmegaConf.is_missing(python_val, key):
90+
type_map[key] = "MISSING"
91+
else:
92+
node_type, type_name = extract_node_type(python_val, key)
93+
type_map[key] = type_name
94+
95+
transformer = TypeEngine.get_transformer(node_type)
96+
literal_type = transformer.get_literal_type(node_type)
97+
98+
value_map[key] = MessageToDict(
99+
transformer.to_literal(ctx, python_val[key], node_type, literal_type).to_flyte_idl()
100+
)
101+
return type_map, value_map
102+
103+
104+
def create_struct(type_map: Dict[str, str], value_map: Dict[str, Any], base_config: Type) -> Struct:
105+
"""Create a protobuf Struct object from type and value maps."""
106+
wrapper = Struct()
107+
wrapper.update(
108+
flatten_dict.flatten(
109+
{
110+
"types": type_map,
111+
"values": value_map,
112+
"base_dataclass": f"{base_config.__module__}.{base_config.__name__}",
113+
},
114+
reducer="dot",
115+
keep_empty_types=(dict,),
116+
)
117+
)
118+
return wrapper
119+
120+
121+
def parse_type_description(type_desc: str) -> Type:
122+
"""Parse the type description and return the corresponding type."""
123+
generic_pattern = re.compile(r"(?P<type>[^\[\]]+)\[(?P<args>[^\[\]]+)\]")
124+
match = generic_pattern.match(type_desc)
125+
126+
if match:
127+
origin_type = match.group("type")
128+
args = match.group("args").split(", ")
129+
130+
origin_module, origin_class = origin_type.rsplit(".", 1)
131+
origin = importlib.import_module(origin_module).__getattribute__(origin_class)
132+
133+
sub_types = []
134+
for arg in args:
135+
if arg == "NoneType":
136+
sub_types.append(type(None))
137+
else:
138+
module_name, class_name = arg.rsplit(".", 1)
139+
sub_type = importlib.import_module(module_name).__getattribute__(class_name)
140+
sub_types.append(sub_type)
141+
142+
if origin_class == "Optional":
143+
return origin[sub_types[0]]
144+
return origin[tuple(sub_types)]
145+
else:
146+
module_name, class_name = type_desc.rsplit(".", 1)
147+
return importlib.import_module(module_name).__getattribute__(class_name)
148+
149+
150+
def parse_node_value(ctx: FlyteContext, key: str, type_desc: str, nested_dict: Dict[str, Any]) -> Any:
151+
"""Parse the node value from the nested dictionary."""
152+
if type_desc == "MISSING":
153+
return omegaconf.MISSING
154+
155+
node_type = parse_type_description(type_desc)
156+
transformer = TypeEngine.get_transformer(node_type)
157+
value_literal = Literal.from_flyte_idl(ParseDict(nested_dict["values"][key], PB_Literal()))
158+
return transformer.to_python_value(ctx, value_literal, node_type)
159+
160+
161+
def handle_base_dataclass(ctx: FlyteContext, nested_dict: Dict[str, Any], cfg_dict: Dict[str, Any]) -> DictConfig:
162+
"""Handle the base dataclass and create the DictConfig."""
163+
if (
164+
nested_dict["base_dataclass"] != "builtins.dict"
165+
and flytekitplugins.omegaconf.get_transformer_mode() != OmegaConfTransformerMode.DictConfig
166+
):
167+
# Explicitly instantiate dataclass and create DictConfig from there in order to have typing information
168+
module_name, class_name = nested_dict["base_dataclass"].rsplit(".", 1)
169+
try:
170+
return OmegaConf.structured(importlib.import_module(module_name).__getattribute__(class_name)(**cfg_dict))
171+
except (ModuleNotFoundError, AttributeError) as e:
172+
logger.error(
173+
f"Could not import module {module_name}. If you want to deserialise to DictConfig, "
174+
f"set the mode to DictConfigTransformerMode.DictConfig."
175+
)
176+
if flytekitplugins.omegaconf.get_transformer_mode() == OmegaConfTransformerMode.DataClass:
177+
raise e
178+
return OmegaConf.create(cfg_dict)
179+
180+
181+
TypeEngine.register(DictConfigTransformer())

0 commit comments

Comments
 (0)