diff --git a/common/BUILD b/common/BUILD index 0ead8b15a..846b6f77f 100644 --- a/common/BUILD +++ b/common/BUILD @@ -68,7 +68,10 @@ cc_test( srcs = ["expr_test.cc"], deps = [ ":expr", + ":expr_factory", "//internal:testing", + "//parser:macro_expr_factory", + "@com_google_absl//absl/strings:string_view", ], ) diff --git a/common/expr_factory.h b/common/expr_factory.h index b9769b457..773217ad9 100644 --- a/common/expr_factory.h +++ b/common/expr_factory.h @@ -352,6 +352,29 @@ class ExprFactory { return expr; } + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + Expr NewBind(NextIdFunc next_id, BindVar bind_var, BindExpr bind_expr, + RestExpr rest_expr) { + Expr expr; + expr.set_id(next_id()); + auto& comprehension_expr = expr.mutable_comprehension_expr(); + comprehension_expr.set_iter_var("#unused"); + comprehension_expr.set_iter_range( + NewList(next_id(), std::vector{})); + comprehension_expr.set_accu_var(bind_var); + comprehension_expr.set_accu_init(std::move(bind_expr)); + comprehension_expr.set_loop_condition(NewBoolConst(next_id(), false)); + comprehension_expr.set_loop_step(NewIdent(next_id(), bind_var)); + comprehension_expr.set_result(std::move(rest_expr)); + return expr; + } + private: friend class MacroExprFactory; friend class ParserMacroExprFactory; diff --git a/common/expr_test.cc b/common/expr_test.cc index 4f30dbd6f..0416c6961 100644 --- a/common/expr_test.cc +++ b/common/expr_test.cc @@ -14,11 +14,46 @@ #include "common/expr.h" +#include #include +#include "absl/strings/string_view.h" +#include "common/expr_factory.h" #include "internal/testing.h" +#include "parser/macro_expr_factory.h" namespace cel { +class TestMacroExprFactory final : public MacroExprFactory { + public: + TestMacroExprFactory() = default; + + Expr ReportError(absl::string_view) override { + return NewUnspecified(NextId()); + } + + Expr ReportErrorAt(const Expr&, absl::string_view) override { + return NewUnspecified(NextId()); + } + + using MacroExprFactory::NewBind; + using MacroExprFactory::NewBoolConst; + using MacroExprFactory::NewIdent; + using MacroExprFactory::NewList; + + protected: + ExprId NextId() override { return id_++; } + + ExprId CopyId(ExprId id) override { + if (id == 0) { + return 0; + } + return NextId(); + } + + private: + int64_t id_ = 1; +}; + namespace { using ::testing::_; @@ -670,5 +705,52 @@ TEST(Expr, Id) { EXPECT_THAT(expr.id(), Eq(1)); } +TEST(ExprFactory, NewBind) { + TestMacroExprFactory factory; + Expr bind_expr = factory.NewIdent(10, "x"); + Expr rest_expr = factory.NewIdent(20, "y"); + + auto next_id = [id = 100]() mutable { return id++; }; + + Expr expr = + factory.NewBind(next_id, "a", std::move(bind_expr), std::move(rest_expr)); + + EXPECT_EQ(expr.id(), 100); + ASSERT_TRUE(expr.has_comprehension_expr()); + + const auto& comp = expr.comprehension_expr(); + EXPECT_EQ(comp.iter_var(), "#unused"); + + ASSERT_TRUE(comp.has_iter_range()); + EXPECT_EQ(comp.iter_range().id(), 101); + EXPECT_EQ(comp.iter_range().kind_case(), ExprKindCase::kListExpr); + EXPECT_THAT(comp.iter_range().list_expr().elements(), IsEmpty()); + + EXPECT_EQ(comp.accu_var(), "a"); + + ASSERT_TRUE(comp.has_accu_init()); + Expr expected_bind_expr; + expected_bind_expr.set_id(10); + expected_bind_expr.mutable_ident_expr().set_name("x"); + EXPECT_EQ(comp.accu_init(), expected_bind_expr); + + ASSERT_TRUE(comp.has_loop_condition()); + EXPECT_EQ(comp.loop_condition().id(), 102); + EXPECT_EQ(comp.loop_condition().kind_case(), ExprKindCase::kConstant); + EXPECT_TRUE(comp.loop_condition().const_expr().has_bool_value()); + EXPECT_FALSE(comp.loop_condition().const_expr().bool_value()); + + ASSERT_TRUE(comp.has_loop_step()); + EXPECT_EQ(comp.loop_step().id(), 103); + EXPECT_EQ(comp.loop_step().kind_case(), ExprKindCase::kIdentExpr); + EXPECT_EQ(comp.loop_step().ident_expr().name(), "a"); + + ASSERT_TRUE(comp.has_result()); + Expr expected_rest_expr; + expected_rest_expr.set_id(20); + expected_rest_expr.mutable_ident_expr().set_name("y"); + EXPECT_EQ(comp.result(), expected_rest_expr); +} + } // namespace } // namespace cel