Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 43 additions & 20 deletions compiler/fory_compiler/generators/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,9 +1063,19 @@ def generate_union_definition(
body_indent = f"{indent} "

case_enum = f"{class_name}Case"
case_types = [
raw_case_types = [
self.get_union_case_type(field, parent_stack) for field in union.fields
]
case_aliases = [
f"ForyCase{self.to_pascal_case(field.name)}Type"
if "," in case_type
else None
for field, case_type in zip(union.fields, raw_case_types)
]
case_types = [
alias if alias is not None else case_type
for alias, case_type in zip(case_aliases, raw_case_types)
]
variant_type = f"std::variant<{', '.join(case_types)}>"

comment = self.format_type_id_comment(union, f"{indent}//")
Expand All @@ -1080,6 +1090,12 @@ def generate_union_definition(
lines.append(f"{body_indent} }};")
lines.append("")

for alias, case_type in zip(case_aliases, raw_case_types):
if alias is not None:
lines.append(f"{body_indent} using {alias} = {case_type};")
if any(alias is not None for alias in case_aliases):
lines.append("")

lines.append(f"{body_indent} {class_name}() = default;")
lines.append("")

Expand Down Expand Up @@ -1204,15 +1220,8 @@ def generate_union_macros(
union_type = self.get_namespaced_type_name(union.name, parent_stack)
lines.append(f"FORY_UNION({union_type},")
for index, field in enumerate(union.fields):
case_type = self.generate_namespaced_type(
field.field_type,
False,
field.ref,
field.element_optional,
field.element_ref,
False,
False,
parent_stack,
case_type = self.get_union_case_macro_type(
field, union_type, parent_stack
)
case_ctor = self.to_snake_case(field.name)
meta = self.get_union_field_meta(field)
Expand All @@ -1225,16 +1234,7 @@ def generate_union_macros(
case_ids = ", ".join(str(field.number) for field in union.fields)
lines.append(f"FORY_UNION_IDS({union_type}, {case_ids});")
for field in union.fields:
case_type = self.generate_namespaced_type(
field.field_type,
False,
field.ref,
field.element_optional,
field.element_ref,
False,
False,
parent_stack,
)
case_type = self.get_union_case_macro_type(field, union_type, parent_stack)
case_ctor = self.to_snake_case(field.name)
meta = self.get_union_field_meta(field)
lines.append(
Expand All @@ -1243,6 +1243,29 @@ def generate_union_macros(

return lines

def get_union_case_macro_type(
self,
field: Field,
union_type: str,
parent_stack: List[Message],
) -> str:
"""Return the C++ type name used in FORY_UNION and FORY_UNION_CASE macros."""
case_type = self.generate_namespaced_type(
field.field_type,
False,
field.ref,
field.element_optional,
field.element_ref,
False,
False,
parent_stack,
)
# FORY_UNION and FORY_UNION_CASE split macro arguments on commas,
# so raw template types such as std::unordered_map<K, V> need an alias.
if "," in case_type:
return f"{union_type}::ForyCase{self.to_pascal_case(field.name)}Type"
return case_type

def get_union_case_type(self, field: Field, parent_stack: List[Message]) -> str:
"""Return the C++ type for a union case."""
return self.generate_type(
Expand Down
65 changes: 65 additions & 0 deletions compiler/fory_compiler/tests/test_generated_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,6 +1126,71 @@ def test_cpp_generator_supports_decimal_fields_and_unions():
assert "(amount, fory::serialization::Decimal, fory::F(1))" in cpp_output


def test_cpp_union_aliases_comma_payload_types():
schema = parse_fdl(
dedent(
"""
package gen;

union MapChoice {
map<string, any> by_name = 1;
map<string, int32> counts = 2;
list<any> values = 3;
string name = 4;
}

union LargeChoice {
map<string, int32> counts = 1;
bool enabled = 2;
int8 i8 = 3;
int16 i16 = 4;
int32 i32 = 5;
int64 i64 = 6;
uint8 u8 = 7;
uint16 u16 = 8;
uint32 u32 = 9;
uint64 u64 = 10;
float32 f32 = 11;
float64 f64 = 12;
string name = 13;
bytes blob = 14;
decimal amount = 15;
date day = 16;
timestamp ts = 17;
}
"""
)
)

cpp_output = render_files(generate_files(schema, CppGenerator))
assert (
"using ForyCaseByNameType = std::unordered_map<std::string, std::any>;"
in cpp_output
)
assert (
"using ForyCaseCountsType = std::unordered_map<std::string, int32_t>;"
in cpp_output
)
assert (
"(by_name, gen::MapChoice::ForyCaseByNameType, "
"fory::F(1).map(fory::T::string(), fory::FieldNodeSpec{}))" in cpp_output
)
assert (
"(counts, gen::MapChoice::ForyCaseCountsType, "
"fory::F(2).map(fory::T::string(), fory::T::int32().varint()))" in cpp_output
)
assert (
"(values, std::vector<std::any>, "
"fory::F(3).list(fory::FieldNodeSpec{}))" in cpp_output
)
assert "(name, std::string, fory::F(4))" in cpp_output
assert (
"FORY_UNION_CASE(gen::LargeChoice, 1, "
"gen::LargeChoice::ForyCaseCountsType, gen::LargeChoice::counts, "
"fory::F(1).map(fory::T::string(), fory::T::int32().varint()));" in cpp_output
)


def test_cpp_omits_equality_for_any_types():
schema = parse_fdl(
dedent(
Expand Down
Loading