From 1b41163d912132311a09e1aa9b6a8e993e6e8493 Mon Sep 17 00:00:00 2001 From: Peiyang He Date: Fri, 3 Jul 2026 13:18:43 +0800 Subject: [PATCH] fix(compiler): never generate C++ equality methods for message and union containing any --- compiler/fory_compiler/generators/cpp.py | 148 +++++++++++++++--- .../tests/test_generated_code.py | 88 +++++++++++ integration_tests/idl_tests/cpp/main.cc | 39 ++++- 3 files changed, 246 insertions(+), 29 deletions(-) diff --git a/compiler/fory_compiler/generators/cpp.py b/compiler/fory_compiler/generators/cpp.py index 91e6b49fdd..f7302dd150 100644 --- a/compiler/fory_compiler/generators/cpp.py +++ b/compiler/fory_compiler/generators/cpp.py @@ -691,13 +691,6 @@ def get_field_eq_expression( ) -> str: member_name = self.get_field_member_name(field) other_member = f"other.{member_name}" - if isinstance(field.field_type, PrimitiveType) and ( - field.field_type.kind == PrimitiveKind.ANY - ): - return ( - f"((!{member_name}.has_value() && !{other_member}.has_value()) || " - f"({member_name}.type() == {other_member}.type()))" - ) if self.is_message_type( field.field_type, parent_stack ) and self.get_field_weak_ref(field): @@ -716,6 +709,104 @@ def get_field_eq_expression( ) return f"{member_name} == {other_member}" + def message_has_any( + self, + message: Message, + parent_stack: Optional[List[Message]] = None, + visiting: Optional[Set[Tuple[str, int]]] = None, + ) -> bool: + if visiting is None: + visiting = set() + key = ("message", id(message)) + if key in visiting: + return False + visiting.add(key) + try: + lineage = (parent_stack or []) + [message] + return any( + self.field_type_has_any(field.field_type, lineage, visiting) + for field in message.fields + ) + finally: + visiting.remove(key) + + def union_has_any( + self, + union: Union, + parent_stack: Optional[List[Message]] = None, + visiting: Optional[Set[Tuple[str, int]]] = None, + ) -> bool: + if visiting is None: + visiting = set() + key = ("union", id(union)) + if key in visiting: + return False + visiting.add(key) + try: + return any( + self.field_type_has_any(field.field_type, parent_stack, visiting) + for field in union.fields + ) + finally: + visiting.remove(key) + + def field_type_has_any( + self, + field_type: FieldType, + parent_stack: Optional[List[Message]] = None, + visiting: Optional[Set[Tuple[str, int]]] = None, + ) -> bool: + """Return True when a field type or its children contain `any`.""" + if isinstance(field_type, PrimitiveType): + return field_type.kind == PrimitiveKind.ANY + if isinstance(field_type, ListType): + return self.field_type_has_any( + field_type.element_type, parent_stack, visiting + ) + if isinstance(field_type, ArrayType): + return self.field_type_has_any( + field_type.element_type, parent_stack, visiting + ) + if isinstance(field_type, MapType): + # `any` is not allowed as map key (rejected first by the validator), + # so we only check map value here. + return self.field_type_has_any( + field_type.value_type, parent_stack, visiting + ) + if isinstance(field_type, NamedType): + named_type = self.resolve_named_type(field_type.name, parent_stack) + if isinstance(named_type, Message): + return self.message_has_any( + named_type, self._parent_stack_for_type(named_type), visiting + ) + if isinstance(named_type, Union): + return self.union_has_any( + named_type, self._parent_stack_for_type(named_type), visiting + ) + return False + + def _parent_stack_for_type(self, type_def: object) -> List[Message]: + def visit(message: Message, parents: List[Message]) -> Optional[List[Message]]: + if message is type_def: + return parents + for nested_union in message.nested_unions: + if nested_union is type_def: + return parents + [message] + for nested_enum in message.nested_enums: + if nested_enum is type_def: + return parents + [message] + for nested_message in message.nested_messages: + found = visit(nested_message, parents + [message]) + if found is not None: + return found + return None + + for top in self.schema.messages: + found = visit(top, []) + if found is not None: + return found + return [] + def is_numeric_field(self, field: Field) -> bool: if not isinstance(field.field_type, PrimitiveType): return False @@ -914,19 +1005,23 @@ def generate_message_definition( lines.append("") lines.append("") - lines.append( - f"{body_indent}bool operator==(const {class_name}& other) const {{" - ) - if message.fields: - conditions = [ - self.get_field_eq_expression(field, lineage) for field in message.fields - ] - lines.append(f"{body_indent} return {' && '.join(conditions)};") - else: - lines.append(f"{body_indent} return true;") - lines.append(f"{body_indent}}}") + # We don't generate equality method for message containing `any` + # since C++ doesn't support std::any == std::any. + if not self.message_has_any(message, parent_stack): + lines.append( + f"{body_indent}bool operator==(const {class_name}& other) const {{" + ) + if message.fields: + conditions = [ + self.get_field_eq_expression(field, lineage) + for field in message.fields + ] + lines.append(f"{body_indent} return {' && '.join(conditions)};") + else: + lines.append(f"{body_indent} return true;") + lines.append(f"{body_indent}}}") + lines.append("") - lines.append("") lines.extend(self.generate_bytes_methods(class_name, body_indent)) struct_type_name = self.get_qualified_type_name(message.name, parent_stack) @@ -1069,12 +1164,15 @@ def generate_union_definition( ) lines.append(f"{body_indent} }}") lines.append("") - lines.append( - f"{body_indent} bool operator==(const {class_name}& other) const {{" - ) - lines.append(f"{body_indent} return value_ == other.value_;") - lines.append(f"{body_indent} }}") - lines.append("") + # We don't generate equality method for union containing `any` + # since C++ doesn't support std::any == std::any. + if not self.union_has_any(union, parent_stack): + lines.append( + f"{body_indent} bool operator==(const {class_name}& other) const {{" + ) + lines.append(f"{body_indent} return value_ == other.value_;") + lines.append(f"{body_indent} }}") + lines.append("") lines.extend(self.generate_bytes_methods(class_name, f"{body_indent} ")) diff --git a/compiler/fory_compiler/tests/test_generated_code.py b/compiler/fory_compiler/tests/test_generated_code.py index d5f7b39602..16e8fca31e 100644 --- a/compiler/fory_compiler/tests/test_generated_code.py +++ b/compiler/fory_compiler/tests/test_generated_code.py @@ -1126,6 +1126,94 @@ def test_cpp_generator_supports_decimal_fields_and_unions(): assert "(amount, fory::serialization::Decimal, fory::F(1))" in cpp_output +def test_cpp_omits_equality_for_any_types(): + schema = parse_fdl( + dedent( + """ + package gen; + + message Inner { + any value = 1; + } + + union AnyChoice { + Inner inner = 1; + string name = 2; + } + + message DirectAny { + any value = 1; + } + + message AnyList { + list values = 1; + } + + message AnyMap { + map values = 1; + } + + union DirectChoice { + any payload = 1; + list values = 2; + string name = 3; + } + + message DirectOwner { + Inner inner = 1; + } + + message ListOwner { + list values = 1; + } + + message MapOwner { + map values = 1; + } + + message UnionOwner { + AnyChoice choice = 1; + } + + message DeclaresNestedOnly { + message Nested { + any value = 1; + } + + string name = 1; + } + + message Plain { + string name = 1; + list values = 2; + map counts = 3; + } + + union PlainChoice { + string name = 1; + int32 code = 2; + } + """ + ) + ) + + cpp_output = render_files(generate_files(schema, CppGenerator)) + assert "bool operator==(const Inner& other) const" not in cpp_output + assert "bool operator==(const AnyChoice& other) const" not in cpp_output + assert "bool operator==(const DirectAny& other) const" not in cpp_output + assert "bool operator==(const AnyList& other) const" not in cpp_output + assert "bool operator==(const AnyMap& other) const" not in cpp_output + assert "bool operator==(const DirectChoice& other) const" not in cpp_output + assert "bool operator==(const DirectOwner& other) const" not in cpp_output + assert "bool operator==(const ListOwner& other) const" not in cpp_output + assert "bool operator==(const MapOwner& other) const" not in cpp_output + assert "bool operator==(const UnionOwner& other) const" not in cpp_output + assert "bool operator==(const Nested& other) const" not in cpp_output + assert "bool operator==(const DeclaresNestedOnly& other) const" in cpp_output + assert "bool operator==(const Plain& other) const" in cpp_output + assert "bool operator==(const PlainChoice& other) const" in cpp_output + + def test_cpp_nested_container_ref_uses_correct_pointer_type(): schema = parse_fdl( dedent( diff --git a/integration_tests/idl_tests/cpp/main.cc b/integration_tests/idl_tests/cpp/main.cc index cb41a987a1..8d9ec6baff 100644 --- a/integration_tests/idl_tests/cpp/main.cc +++ b/integration_tests/idl_tests/cpp/main.cc @@ -1076,6 +1076,23 @@ fory::Result RunEvolvingRoundTrip() { using StringMap = std::unordered_map; +template +fory::Result +ValidateAnyField(const std::any &actual_any, const std::any &expected_any, + const std::string &field_name) { + const auto *actual = std::any_cast(&actual_any); + const auto *expected = std::any_cast(&expected_any); + if (actual == nullptr || expected == nullptr) { + return fory::Unexpected( + fory::Error::invalid("any holder " + field_name + " type mismatch")); + } + if (!(*actual == *expected)) { + return fory::Unexpected( + fory::Error::invalid("any holder " + field_name + " value mismatch")); + } + return fory::Result(); +} + fory::Result RunRoundTrip(bool compatible) { auto fory = fory::serialization::Fory::builder() .xlang(true) @@ -1479,10 +1496,24 @@ fory::Result RunRoundTrip(bool compatible) { FORY_TRY(any_roundtrip, fory.deserialize( any_bytes.data(), any_bytes.size())); - if (!(any_roundtrip == any_holder)) { - return fory::Unexpected( - fory::Error::invalid("any holder roundtrip mismatch")); - } + FORY_RETURN_IF_ERROR(ValidateAnyField( + any_roundtrip.bool_value(), any_holder.bool_value(), "bool_value")); + FORY_RETURN_IF_ERROR(ValidateAnyField( + any_roundtrip.string_value(), any_holder.string_value(), "string_value")); + FORY_RETURN_IF_ERROR(ValidateAnyField( + any_roundtrip.date_value(), any_holder.date_value(), "date_value")); + FORY_RETURN_IF_ERROR(ValidateAnyField( + any_roundtrip.timestamp_value(), any_holder.timestamp_value(), + "timestamp_value")); + FORY_RETURN_IF_ERROR(ValidateAnyField( + any_roundtrip.message_value(), any_holder.message_value(), + "message_value")); + FORY_RETURN_IF_ERROR(ValidateAnyField( + any_roundtrip.union_value(), any_holder.union_value(), "union_value")); + FORY_RETURN_IF_ERROR(ValidateAnyField>( + any_roundtrip.list_value(), any_holder.list_value(), "list_value")); + FORY_RETURN_IF_ERROR(ValidateAnyField( + any_roundtrip.map_value(), any_holder.map_value(), "map_value")); example_peer::ExampleMessage example_message = BuildExampleMessage(); FORY_TRY(example_bytes, fory.serialize(example_message));