diff --git a/compiler/fory_compiler/generators/cpp.py b/compiler/fory_compiler/generators/cpp.py index f7302dd150..c5d4de1a34 100644 --- a/compiler/fory_compiler/generators/cpp.py +++ b/compiler/fory_compiler/generators/cpp.py @@ -1063,9 +1063,19 @@ def generate_union_definition( body_indent = f"{indent} " case_enum = f"{class_name}Case" - case_types = [ + raw_case_types = [ self.get_union_case_type(field, parent_stack) for field in union.fields ] + case_aliases = [ + f"ForyCase{self.to_pascal_case(field.name)}Type" + if "," in case_type + else None + for field, case_type in zip(union.fields, raw_case_types) + ] + case_types = [ + alias if alias is not None else case_type + for alias, case_type in zip(case_aliases, raw_case_types) + ] variant_type = f"std::variant<{', '.join(case_types)}>" comment = self.format_type_id_comment(union, f"{indent}//") @@ -1080,6 +1090,12 @@ def generate_union_definition( lines.append(f"{body_indent} }};") lines.append("") + for alias, case_type in zip(case_aliases, raw_case_types): + if alias is not None: + lines.append(f"{body_indent} using {alias} = {case_type};") + if any(alias is not None for alias in case_aliases): + lines.append("") + lines.append(f"{body_indent} {class_name}() = default;") lines.append("") @@ -1204,15 +1220,8 @@ def generate_union_macros( union_type = self.get_namespaced_type_name(union.name, parent_stack) lines.append(f"FORY_UNION({union_type},") for index, field in enumerate(union.fields): - case_type = self.generate_namespaced_type( - field.field_type, - False, - field.ref, - field.element_optional, - field.element_ref, - False, - False, - parent_stack, + case_type = self.get_union_case_macro_type( + field, union_type, parent_stack ) case_ctor = self.to_snake_case(field.name) meta = self.get_union_field_meta(field) @@ -1225,16 +1234,7 @@ def generate_union_macros( case_ids = ", ".join(str(field.number) for field in union.fields) lines.append(f"FORY_UNION_IDS({union_type}, {case_ids});") for field in union.fields: - case_type = self.generate_namespaced_type( - field.field_type, - False, - field.ref, - field.element_optional, - field.element_ref, - False, - False, - parent_stack, - ) + case_type = self.get_union_case_macro_type(field, union_type, parent_stack) case_ctor = self.to_snake_case(field.name) meta = self.get_union_field_meta(field) lines.append( @@ -1243,6 +1243,29 @@ def generate_union_macros( return lines + def get_union_case_macro_type( + self, + field: Field, + union_type: str, + parent_stack: List[Message], + ) -> str: + """Return the C++ type name used in FORY_UNION and FORY_UNION_CASE macros.""" + case_type = self.generate_namespaced_type( + field.field_type, + False, + field.ref, + field.element_optional, + field.element_ref, + False, + False, + parent_stack, + ) + # FORY_UNION and FORY_UNION_CASE split macro arguments on commas, + # so raw template types such as std::unordered_map need an alias. + if "," in case_type: + return f"{union_type}::ForyCase{self.to_pascal_case(field.name)}Type" + return case_type + def get_union_case_type(self, field: Field, parent_stack: List[Message]) -> str: """Return the C++ type for a union case.""" return self.generate_type( diff --git a/compiler/fory_compiler/tests/test_generated_code.py b/compiler/fory_compiler/tests/test_generated_code.py index 16e8fca31e..c61c9c2ceb 100644 --- a/compiler/fory_compiler/tests/test_generated_code.py +++ b/compiler/fory_compiler/tests/test_generated_code.py @@ -1126,6 +1126,71 @@ def test_cpp_generator_supports_decimal_fields_and_unions(): assert "(amount, fory::serialization::Decimal, fory::F(1))" in cpp_output +def test_cpp_union_aliases_comma_payload_types(): + schema = parse_fdl( + dedent( + """ + package gen; + + union MapChoice { + map by_name = 1; + map counts = 2; + list values = 3; + string name = 4; + } + + union LargeChoice { + map counts = 1; + bool enabled = 2; + int8 i8 = 3; + int16 i16 = 4; + int32 i32 = 5; + int64 i64 = 6; + uint8 u8 = 7; + uint16 u16 = 8; + uint32 u32 = 9; + uint64 u64 = 10; + float32 f32 = 11; + float64 f64 = 12; + string name = 13; + bytes blob = 14; + decimal amount = 15; + date day = 16; + timestamp ts = 17; + } + """ + ) + ) + + cpp_output = render_files(generate_files(schema, CppGenerator)) + assert ( + "using ForyCaseByNameType = std::unordered_map;" + in cpp_output + ) + assert ( + "using ForyCaseCountsType = std::unordered_map;" + in cpp_output + ) + assert ( + "(by_name, gen::MapChoice::ForyCaseByNameType, " + "fory::F(1).map(fory::T::string(), fory::FieldNodeSpec{}))" in cpp_output + ) + assert ( + "(counts, gen::MapChoice::ForyCaseCountsType, " + "fory::F(2).map(fory::T::string(), fory::T::int32().varint()))" in cpp_output + ) + assert ( + "(values, std::vector, " + "fory::F(3).list(fory::FieldNodeSpec{}))" in cpp_output + ) + assert "(name, std::string, fory::F(4))" in cpp_output + assert ( + "FORY_UNION_CASE(gen::LargeChoice, 1, " + "gen::LargeChoice::ForyCaseCountsType, gen::LargeChoice::counts, " + "fory::F(1).map(fory::T::string(), fory::T::int32().varint()));" in cpp_output + ) + + def test_cpp_omits_equality_for_any_types(): schema = parse_fdl( dedent(