diff --git a/policy/BUILD.bazel b/policy/BUILD.bazel index 5979f1ba7..bce68f001 100644 --- a/policy/BUILD.bazel +++ b/policy/BUILD.bazel @@ -59,3 +59,9 @@ java_library( name = "compiler_builder", exports = ["//policy/src/main/java/dev/cel/policy:compiler_builder"], ) + +java_library( + name = "rule_composer", + visibility = ["//:internal"], + exports = ["//policy/src/main/java/dev/cel/policy:rule_composer"], +) diff --git a/policy/src/main/java/dev/cel/policy/BUILD.bazel b/policy/src/main/java/dev/cel/policy/BUILD.bazel index a7bb90ffe..e0d6af461 100644 --- a/policy/src/main/java/dev/cel/policy/BUILD.bazel +++ b/policy/src/main/java/dev/cel/policy/BUILD.bazel @@ -244,7 +244,6 @@ java_library( java_library( name = "rule_composer", srcs = ["RuleComposer.java"], - visibility = ["//visibility:private"], deps = [ ":compiled_rule", "//bundle:cel", @@ -257,6 +256,8 @@ java_library( "//common/ast:mutable_expr", "//common/formats:value_string", "//common/navigation:mutable_navigation", + "//common/types:cel_types", + "//common/types:type_providers", "//extensions:optional_library", "//optimizer:ast_optimizer", "//optimizer:mutable_ast", diff --git a/policy/src/main/java/dev/cel/policy/RuleComposer.java b/policy/src/main/java/dev/cel/policy/RuleComposer.java index 5fa0957f5..73d31a4ee 100644 --- a/policy/src/main/java/dev/cel/policy/RuleComposer.java +++ b/policy/src/main/java/dev/cel/policy/RuleComposer.java @@ -18,6 +18,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.stream.Collectors.toCollection; +import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import dev.cel.bundle.Cel; @@ -32,11 +33,14 @@ import dev.cel.common.formats.ValueString; import dev.cel.common.navigation.CelNavigableMutableAst; import dev.cel.common.navigation.CelNavigableMutableExpr; +import dev.cel.common.types.CelType; +import dev.cel.common.types.CelTypes; import dev.cel.extensions.CelOptionalLibrary.Function; import dev.cel.optimizer.AstMutator; import dev.cel.optimizer.CelAstOptimizer; import dev.cel.policy.CelCompiledRule.CelCompiledMatch; import dev.cel.policy.CelCompiledRule.CelCompiledMatch.OutputValue; +import dev.cel.policy.CelCompiledRule.CelCompiledMatch.Result; import dev.cel.policy.CelCompiledRule.CelCompiledVariable; import java.util.ArrayList; import java.util.Arrays; @@ -74,11 +78,15 @@ private Step optimizeRule(Cel cel, CelCompiledRule compiledRule) { } long lastOutputId = 0; + // The expected output type of the rule, used to verify that all branches agree on the type. + CelType lastOutputType = null; for (CelCompiledMatch match : Lists.reverse(compiledRule.matches())) { CelAbstractSyntaxTree conditionAst = match.condition(); boolean isTriviallyTrue = match.isConditionTriviallyTrue(); CelMutableAst condAst = CelMutableAst.fromCelAst(conditionAst); + long currentSourceId = lastOutputId; + switch (match.result().kind()) { case OUTPUT: // If the match has an output, then it is considered a non-optional output since @@ -86,42 +94,54 @@ private Step optimizeRule(Cel cel, CelCompiledRule compiledRule) { // of output being optional.none() will convert the non-optional value to an optional // one. OutputValue matchOutput = match.result().output(); - CelMutableAst outAst = CelMutableAst.fromCelAst(matchOutput.ast()); - Step step = Step.newNonOptionalStep(!isTriviallyTrue, condAst, outAst); + Step step = + Step.newNonOptionalStep( + !isTriviallyTrue, condAst, CelMutableAst.fromCelAst(matchOutput.ast())); + currentSourceId = matchOutput.sourceId(); + output = combine(astMutator, step, output); - assertComposedAstIsValid( - cel, - output.expr, - "incompatible output types found.", - matchOutput.sourceId(), - lastOutputId); - lastOutputId = matchOutput.sourceId(); + String outputFailureMessage = + String.format( + "incompatible output types: block has output type %s, but previous outputs have" + + " type %s", + lastOutputType == null ? "" : CelTypes.format(lastOutputType), + CelTypes.format(matchOutput.ast().getResultType())); + lastOutputType = + assertComposedAstIsValid( + cel, output.expr, outputFailureMessage, currentSourceId, lastOutputId) + .getResultType(); + break; case RULE: // If the match has a nested rule, then compute the rule and whether it has // an optional return value. CelCompiledRule matchNestedRule = match.result().rule(); Step nestedRule = optimizeRule(cel, matchNestedRule); - boolean nestedHasOptional = matchNestedRule.hasOptionalOutput(); - Step ruleStep = - nestedHasOptional - ? Step.newOptionalStep(!isTriviallyTrue, condAst, nestedRule.expr) - : Step.newNonOptionalStep(!isTriviallyTrue, condAst, nestedRule.expr); + new Step( + matchNestedRule.hasOptionalOutput(), !isTriviallyTrue, condAst, nestedRule.expr); + currentSourceId = getFirstOutputSourceId(matchNestedRule); + output = combine(astMutator, ruleStep, output); - assertComposedAstIsValid( - cel, - output.expr, - String.format( - "failed composing the subrule '%s' due to incompatible output types.", - matchNestedRule.ruleId().map(ValueString::value).orElse("")), - lastOutputId); + lastOutputType = + assertComposedAstIsValid( + cel, + output.expr, + String.format( + "failed composing the subrule '%s' due to incompatible output types.", + matchNestedRule.ruleId().map(ValueString::value).orElse("")), + currentSourceId, + lastOutputId) + .getResultType(); break; } + + lastOutputId = currentSourceId; } + Preconditions.checkState(output != null, "Policy contains no outputs."); CelMutableAst resultExpr = output.expr; resultExpr = inlineCompiledVariables(resultExpr, compiledRule.variables()); resultExpr = astMutator.renumberIdsConsecutively(resultExpr); @@ -266,21 +286,34 @@ private CelMutableAst inlineCompiledVariables( return mutatedAst; } - private void assertComposedAstIsValid( + private CelAbstractSyntaxTree assertComposedAstIsValid( Cel cel, CelMutableAst composedAst, String failureMessage, Long... ids) { - assertComposedAstIsValid(cel, composedAst, failureMessage, Arrays.asList(ids)); + return assertComposedAstIsValid(cel, composedAst, failureMessage, Arrays.asList(ids)); } - private void assertComposedAstIsValid( + private CelAbstractSyntaxTree assertComposedAstIsValid( Cel cel, CelMutableAst composedAst, String failureMessage, List ids) { try { - cel.check(composedAst.toParsedAst()).getAst(); + return cel.check(composedAst.toParsedAst()).getAst(); } catch (CelValidationException e) { ids = ids.stream().filter(id -> id > 0).collect(toCollection(ArrayList::new)); throw new RuleCompositionException(failureMessage, e, ids); } } + private static long getFirstOutputSourceId(CelCompiledRule rule) { + for (CelCompiledMatch match : rule.matches()) { + if (match.result().kind() == Result.Kind.OUTPUT) { + return match.result().output().sourceId(); + } else if (match.result().kind() == Result.Kind.RULE) { + return getFirstOutputSourceId(match.result().rule()); + } + } + + // Fallback to the nested rule ID if the policy is invalid and contains no output + return rule.sourceId(); + } + // Step represents an intermediate stage of rule and match expression composition. // // The CelCompiledRule and CelCompiledMatch types are meant to represent standalone tuples of @@ -311,11 +344,6 @@ private Step( this.expr = expr; } - private static Step newOptionalStep( - boolean isConditional, CelMutableAst cond, CelMutableAst expr) { - return new Step(/* isOptional= */ true, isConditional, cond, expr); - } - private static Step newNonOptionalStep( boolean isConditional, CelMutableAst cond, CelMutableAst expr) { return new Step(/* isOptional= */ false, isConditional, cond, expr); diff --git a/policy/src/test/java/dev/cel/policy/BUILD.bazel b/policy/src/test/java/dev/cel/policy/BUILD.bazel index 3089a3849..bc8a5d4b4 100644 --- a/policy/src/test/java/dev/cel/policy/BUILD.bazel +++ b/policy/src/test/java/dev/cel/policy/BUILD.bazel @@ -27,9 +27,11 @@ java_library( "//parser:parser_factory", "//parser:unparser", "//policy", + "//policy:compiled_rule", "//policy:compiler_factory", "//policy:parser", "//policy:parser_factory", + "//policy:rule_composer", "//policy:source", "//policy:validation_exception", "//policy/testing:k8s_test_tag_handler", diff --git a/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java b/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java index 416e3b95f..b4065b60c 100644 --- a/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java +++ b/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java @@ -31,6 +31,7 @@ import dev.cel.bundle.CelEnvironmentYamlParser; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelOptions; +import dev.cel.common.formats.ValueString; import dev.cel.common.types.OptionalType; import dev.cel.common.types.SimpleType; import dev.cel.expr.conformance.proto3.TestAllTypes; @@ -356,6 +357,24 @@ public void evaluateYamlPolicy_withSimpleVariable() throws Exception { assertThat(evalResult).isFalse(); } + @Test + public void compose_ruleWithNoOutputs_throws() throws Exception { + Cel cel = newCel(); + CelCompiledRule emptyRule = + CelCompiledRule.create( + 1L, + Optional.of(ValueString.of(2L, "empty_rule")), + ImmutableList.of(), + ImmutableList.of(), + cel); + RuleComposer composer = RuleComposer.newInstance(emptyRule, "variables.", 1000); + CelAbstractSyntaxTree ast = cel.compile("true").getAst(); + + IllegalStateException e = + assertThrows(IllegalStateException.class, () -> composer.optimize(ast, cel)); + assertThat(e).hasMessageThat().isEqualTo("Policy contains no outputs."); + } + private static final class EvaluablePolicyTestData { private final TestYamlPolicy yamlPolicy; private final PolicyTestCase testCase; diff --git a/testing/src/test/resources/policy/compose_errors_conflicting_output/expected_errors.baseline b/testing/src/test/resources/policy/compose_errors_conflicting_output/expected_errors.baseline index 0facbbe2e..bc205c2ab 100644 --- a/testing/src/test/resources/policy/compose_errors_conflicting_output/expected_errors.baseline +++ b/testing/src/test/resources/policy/compose_errors_conflicting_output/expected_errors.baseline @@ -1,6 +1,6 @@ -ERROR: compose_errors_conflicting_output/policy.yaml:22:14: incompatible output types found. +ERROR: compose_errors_conflicting_output/policy.yaml:22:14: incompatible output types: block has output type map(string, bool), but previous outputs have type bool | output: "false" | .............^ -ERROR: compose_errors_conflicting_output/policy.yaml:23:14: incompatible output types found. +ERROR: compose_errors_conflicting_output/policy.yaml:23:14: incompatible output types: block has output type map(string, bool), but previous outputs have type bool | - output: "{'banned': true}" | .............^ \ No newline at end of file diff --git a/testing/src/test/resources/policy/compose_errors_conflicting_subrule/expected_errors.baseline b/testing/src/test/resources/policy/compose_errors_conflicting_subrule/expected_errors.baseline index 92ddff311..66e48ea57 100644 --- a/testing/src/test/resources/policy/compose_errors_conflicting_subrule/expected_errors.baseline +++ b/testing/src/test/resources/policy/compose_errors_conflicting_subrule/expected_errors.baseline @@ -1,3 +1,6 @@ +ERROR: compose_errors_conflicting_subrule/policy.yaml:34:18: failed composing the subrule 'banned regions' due to incompatible output types. + | output: "true" + | .................^ ERROR: compose_errors_conflicting_subrule/policy.yaml:36:14: failed composing the subrule 'banned regions' due to incompatible output types. | output: "{'banned': false}" | .............^ \ No newline at end of file