Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 120 additions & 7 deletions compiler/fory_compiler/generators/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,12 +691,13 @@ def get_field_eq_expression(
) -> str:
member_name = self.get_field_member_name(field)
other_member = f"other.{member_name}"
if isinstance(field.field_type, PrimitiveType) and (
field.field_type.kind == PrimitiveKind.ANY
):
return (
f"((!{member_name}.has_value() && !{other_member}.has_value()) || "
f"({member_name}.type() == {other_member}.type()))"
if self.field_type_has_any(field.field_type):
return self.get_value_eq_expression(
field.field_type,
member_name,
other_member,
nullable=field.optional,
element_optional=field.element_optional,
)
if self.is_message_type(
field.field_type, parent_stack
Expand All @@ -716,6 +717,101 @@ def get_field_eq_expression(
)
return f"{member_name} == {other_member}"

def field_type_has_any(self, field_type: FieldType) -> bool:
"""Recursively check whether a field type contains `any`."""
if isinstance(field_type, PrimitiveType):
return field_type.kind == PrimitiveKind.ANY
if isinstance(field_type, ListType):
return self.field_type_has_any(field_type.element_type)
if isinstance(field_type, ArrayType):
return self.field_type_has_any(field_type.element_type)
if isinstance(field_type, MapType):
# `any` as map key will be rejected, so we don't check here.
return self.field_type_has_any(field_type.value_type)
return False

def get_value_eq_expression(
self,
field_type: FieldType,
left: str,
right: str,
nullable: bool = False,
element_optional: bool = False,
depth: int = 0,
) -> str:
if nullable:
# outer is std::optional<T>.
inner = self.get_value_eq_expression(
field_type,
f"*({left})",
f"*({right})",
element_optional=element_optional,
depth=depth + 1,
)
return (
f"([&]() {{ if (({left}).has_value() != ({right}).has_value()) "
f"{{ return false; }} if (!({left}).has_value()) {{ return true; }} "
f"return {inner}; }}())"
)

if not self.field_type_has_any(field_type):
# Compare directly.
return f"({left} == {right})"

if isinstance(field_type, PrimitiveType):
# This is only a rough compare.
return (
f"((!({left}).has_value() && !({right}).has_value()) || "
f"(({left}).type() == ({right}).type()))"
)

if isinstance(field_type, (ListType, ArrayType)):
# First compare size, then compare each element.
element_type = field_type.element_type
effective_element_optional = element_optional
if isinstance(field_type, ListType):
# Array doesn't allow optional element, so we only check List.
effective_element_optional = (
effective_element_optional or field_type.element_optional
)
left_it = f"_fory_left_it_{depth}"
right_it = f"_fory_right_it_{depth}"
element_expr = self.get_value_eq_expression(
element_type,
f"*{left_it}",
f"*{right_it}",
nullable=effective_element_optional,
depth=depth + 1,
)
return (
f"([&]() {{ if (({left}).size() != ({right}).size()) "
f"{{ return false; }} auto {left_it} = ({left}).begin(); "
f"auto {right_it} = ({right}).begin(); for (; {left_it} != "
f"({left}).end(); ++{left_it}, ++{right_it}) {{ if (!({element_expr})) "
f"{{ return false; }} }} return true; }}())"
)

if isinstance(field_type, MapType):
# First compare size, then compare values for the same key.
entry = f"_fory_left_entry_{depth}"
right_it = f"_fory_right_it_{depth}"
value_expr = self.get_value_eq_expression(
field_type.value_type,
f"{entry}.second",
f"{right_it}->second",
nullable=field_type.value_optional,
depth=depth + 1,
)
return (
f"([&]() {{ if (({left}).size() != ({right}).size()) "
f"{{ return false; }} for (const auto& {entry} : ({left})) "
f"{{ auto {right_it} = ({right}).find({entry}.first); "
f"if ({right_it} == ({right}).end()) {{ return false; }} "
f"if (!({value_expr})) {{ return false; }} }} return true; }}())"
)

return f"({left} == {right})"

def is_numeric_field(self, field: Field) -> bool:
if not isinstance(field.field_type, PrimitiveType):
return False
Expand Down Expand Up @@ -1072,7 +1168,24 @@ def generate_union_definition(
lines.append(
f"{body_indent} bool operator==(const {class_name}& other) const {{"
)
lines.append(f"{body_indent} return value_ == other.value_;")
lines.append(f"{body_indent} if (value_.index() != other.value_.index()) {{")
lines.append(f"{body_indent} return false;")
lines.append(f"{body_indent} }}")
lines.append(f"{body_indent} switch (value_.index()) {{")
for index, field in enumerate(union.fields):
left_value = f"std::get<{index}>(value_)"
right_value = f"std::get<{index}>(other.value_)"
eq_expression = self.get_value_eq_expression(
field.field_type,
left_value,
right_value,
element_optional=field.element_optional,
)
lines.append(f"{body_indent} case {index}:")
lines.append(f"{body_indent} return {eq_expression};")
lines.append(f"{body_indent} default:")
lines.append(f"{body_indent} return false;")
lines.append(f"{body_indent} }}")
lines.append(f"{body_indent} }}")
lines.append("")

Expand Down
Loading