From 4fe9dc2916d11bac75f32377f755cc12f8ae340b Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Sun, 24 May 2026 21:11:32 -0400 Subject: [PATCH 1/3] finish1 --- python/tvm/tirx/__init__.py | 7 +- python/tvm/tirx/functor.py | 394 ++++++++++++++++++ .../python/tirx-transform/test_tir_functor.py | 203 +++++++++ 3 files changed, 603 insertions(+), 1 deletion(-) diff --git a/python/tvm/tirx/__init__.py b/python/tvm/tirx/__init__.py index 00a3522238af..4d805ad077b2 100644 --- a/python/tvm/tirx/__init__.py +++ b/python/tvm/tirx/__init__.py @@ -104,7 +104,12 @@ from . import backend from . import stmt_functor -from .functor import PyStmtExprVisitor, PyStmtExprMutator +from .functor import ( + PyStmtExprVisitor, + PyStmtExprMutator, + PyStmtExprVisitorWithAnalyzer, + PyStmtExprMutatorWithAnalyzer, +) # Compiler-only submodules. Skip under `TVM_USE_RUNTIME_LIB=1` since they # perform compiler-side FFI at module load (schema engine looks up diff --git a/python/tvm/tirx/functor.py b/python/tvm/tirx/functor.py index 4619c0b51fbb..1a379cdec8f3 100644 --- a/python/tvm/tirx/functor.py +++ b/python/tvm/tirx/functor.py @@ -38,12 +38,14 @@ Broadcast, BufferLoad, Call, + CallEffectKind, Cast, Div, FloatImm, FloorDiv, FloorMod, IntImm, + IterVar, Let, Max, Min, @@ -945,6 +947,219 @@ def visit_string_imm_(self, op: StringImm) -> None: _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore +class _AnalyzerContextMixin: + """Shared analyzer context helpers for Python functors.""" + + def _init_analyzer(self, analyzer): + if analyzer is None: + from tvm.arith import Analyzer # pylint: disable=import-outside-toplevel + + analyzer = Analyzer() + self.analyzer = analyzer + self._constraint_scopes = [] + + def _real_condition(self, condition): + if isinstance(condition, Call) and getattr(condition.op, "name", None) == "tirx.likely": + return condition.args[0] + return condition + + def _negated_condition(self, condition): + return self.analyzer.rewrite_simplify(Not(condition)) + + def _is_pure(self, expr): + if isinstance(expr, BufferLoad | ProducerLoad): + return False + if isinstance(expr, Call): + try: + effect = expr.op.get_attr("TCallEffectKind") + except AttributeError: + return False + if effect is None: + return False + effect_value = getattr(effect, "value", effect) + return effect_value <= CallEffectKind.Pure.value and all( + self._is_pure(arg) for arg in expr.args + ) + if isinstance(expr, Let): + return self._is_pure(expr.value) and self._is_pure(expr.body) + if isinstance(expr, Reduce): + return ( + all(self._is_pure(source) for source in expr.source) + and all(self._is_pure(init) for init in expr.init) + and self._is_pure(expr.condition) + ) + if isinstance(expr, Select): + return ( + self._is_pure(expr.condition) + and self._is_pure(expr.true_value) + and self._is_pure(expr.false_value) + ) + if isinstance(expr, Ramp): + return ( + self._is_pure(expr.base) + and self._is_pure(expr.stride) + and self._is_pure(expr.lanes) + ) + if isinstance(expr, Broadcast): + return self._is_pure(expr.value) and self._is_pure(expr.lanes) + if isinstance(expr, Shuffle): + return all(self._is_pure(vec) for vec in expr.vectors) and all( + self._is_pure(index) for index in expr.indices + ) + if isinstance(expr, Cast): + return self._is_pure(expr.value) + if isinstance(expr, Not): + return self._is_pure(expr.a) + if isinstance( + expr, + Add + | Sub + | Mul + | Div + | Mod + | FloorDiv + | FloorMod + | Min + | Max + | EQ + | NE + | LT + | LE + | GT + | GE + | And + | Or, + ): + return self._is_pure(expr.a) and self._is_pure(expr.b) + + return isinstance(expr, Var | SizeVar | IntImm | FloatImm | StringImm) + + def _push_constraint(self, constraint): + scope = self.analyzer.constraint_scope(constraint) + scope.__enter__() + self._constraint_scopes.append(scope) + + def _pop_constraints(self, depth): + while len(self._constraint_scopes) > depth: + self._constraint_scopes.pop().__exit__(None, None, None) + + +class PyStmtExprVisitorWithAnalyzer(PyStmtExprVisitor, _AnalyzerContextMixin): + """A Python StmtExprVisitor that maintains an arithmetic analyzer context. + + The analyzer is available as ``self.analyzer`` from user callbacks. The default + traversal binds loop variables, block iter variables, let variables, and branch + conditions before visiting nested nodes, so callbacks can query the analyzer + using the surrounding IR context. + """ + + def __init__(self, analyzer=None): + super().__init__() + self._init_analyzer(analyzer) + + def visit_for_(self, op: For) -> None: + from tvm.ir import Range # pylint: disable=import-outside-toplevel + + depth = len(self._constraint_scopes) + self.visit_expr(op.min) + self.visit_expr(op.extent) + if op.step is not None: + self.visit_expr(op.step) + self.analyzer.bind(op.loop_var, Range.from_min_extent(op.min, op.extent)) + try: + self.visit_stmt(op.body) + finally: + self._pop_constraints(depth) + + def visit_attr_stmt_(self, op: AttrStmt) -> None: + from tvm.ir import Range # pylint: disable=import-outside-toplevel + + depth = len(self._constraint_scopes) + self.visit_expr(op.value) + if op.attr_key in ("thread_extent", "virtual_thread") and isinstance(op.node, IterVar): + self.analyzer.bind( + op.node.var, Range.from_min_extent(IntImm(op.value.dtype, 0), op.value) + ) + try: + self.visit_stmt(op.body) + finally: + self._pop_constraints(depth) + + def visit_sblock_(self, op: SBlock) -> None: + depth = len(self._constraint_scopes) + try: + for iter_var in op.iter_vars: + self.analyzer.bind(iter_var.var, iter_var.dom) + _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore + finally: + self._pop_constraints(depth) + + def visit_bind_(self, op: Bind) -> None: + self.visit_expr(op.value) + if self._is_pure(op.value): + self.analyzer.bind(op.var, op.value) + + def visit_seq_stmt_(self, op: SeqStmt) -> None: + depth = len(self._constraint_scopes) + try: + _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore + finally: + self._pop_constraints(depth) + + def visit_if_then_else_(self, op: IfThenElse) -> None: + condition = self._real_condition(op.condition) + self.visit_expr(op.condition) + depth = len(self._constraint_scopes) + with self.analyzer.constraint_scope(condition): + try: + self.visit_stmt(op.then_case) + finally: + self._pop_constraints(depth) + if op.else_case is not None: + with self.analyzer.constraint_scope(self._negated_condition(condition)): + try: + self.visit_stmt(op.else_case) + finally: + self._pop_constraints(depth) + + def visit_assert_stmt_(self, op: AssertStmt) -> None: + self.visit_expr(op.condition) + self._push_constraint(op.condition) + self.visit_expr(op.error_kind) + for msg in op.message_parts: + self.visit_expr(msg) + + def visit_let_(self, op: Let) -> None: + self.visit_expr(op.value) + if self._is_pure(op.value): + self.analyzer.bind(op.var, op.value) + self.visit_expr(op.body) + + def visit_call_(self, op: Call) -> None: + if getattr(op.op, "name", None) != "tirx.if_then_else": + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + return + + condition = op.args[0] + self.visit_expr(condition) + with self.analyzer.constraint_scope(condition): + self.visit_expr(op.args[1]) + with self.analyzer.constraint_scope(self._negated_condition(condition)): + self.visit_expr(op.args[2]) + + def visit_select_(self, op: Select) -> None: + self.visit_expr(op.condition) + with self.analyzer.constraint_scope(op.condition): + self.visit_expr(op.true_value) + with self.analyzer.constraint_scope(self._negated_condition(op.condition)): + self.visit_expr(op.false_value) + + def visit_reduce_(self, op: Reduce) -> None: + for iter_var in op.axis: + self.analyzer.bind(iter_var.var, iter_var.dom) + _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + + @tvm_ffi.register_object("tirx.PyStmtExprMutator") class _PyStmtExprMutator(tvm_ffi.core.Object): """ @@ -1976,3 +2191,182 @@ def visit_string_imm_(self, op: StringImm) -> PrimExpr: The mutated PrimExpr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + +class PyStmtExprMutatorWithAnalyzer(PyStmtExprMutator, _AnalyzerContextMixin): + """A Python StmtExprMutator that maintains an arithmetic analyzer context. + + The analyzer is available as ``self.analyzer`` from user callbacks. The default + mutation binds loop variables, block iter variables, let variables, and branch + conditions before mutating nested nodes, so callbacks can query the analyzer + using the surrounding IR context. + """ + + def __init__(self, analyzer=None): + super().__init__() + self._init_analyzer(analyzer) + + def visit_for_(self, op: For) -> Stmt: + from tvm.ir import Range # pylint: disable=import-outside-toplevel + + depth = len(self._constraint_scopes) + min_value = self.visit_expr(op.min) + extent = self.visit_expr(op.extent) + step = self.visit_expr(op.step) if op.step is not None else None + self.analyzer.bind(op.loop_var, Range.from_min_extent(min_value, extent)) + try: + body = self.visit_stmt(op.body) + finally: + self._pop_constraints(depth) + if ( + min_value.same_as(op.min) + and extent.same_as(op.extent) + and body.same_as(op.body) + and ( + (step is None and op.step is None) + or (step is not None and op.step is not None and step.same_as(op.step)) + ) + ): + return op + return For( + op.loop_var, + min_value, + extent, + op.kind, + body, + op.thread_binding, + op.annotations, + step, + op.span, + ) + + def visit_attr_stmt_(self, op: AttrStmt) -> Stmt: + from tvm.ir import Range # pylint: disable=import-outside-toplevel + + depth = len(self._constraint_scopes) + value = self.visit_expr(op.value) + if op.attr_key in ("thread_extent", "virtual_thread") and isinstance(op.node, IterVar): + self.analyzer.bind(op.node.var, Range.from_min_extent(IntImm(value.dtype, 0), value)) + try: + body = self.visit_stmt(op.body) + finally: + self._pop_constraints(depth) + if value.same_as(op.value) and body.same_as(op.body): + return op + return AttrStmt(op.node, op.attr_key, value, body, op.span) + + def visit_sblock_(self, op: SBlock) -> Stmt: + depth = len(self._constraint_scopes) + try: + for iter_var in op.iter_vars: + self.analyzer.bind(iter_var.var, iter_var.dom) + return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op) # type: ignore + finally: + self._pop_constraints(depth) + + def visit_bind_(self, op: Bind) -> Stmt: + value = self.visit_expr(op.value) + if self._is_pure(value): + self.analyzer.bind(op.var, value) + if value.same_as(op.value): + return op + return Bind(op.var, value, op.span) + + def visit_seq_stmt_(self, op: SeqStmt) -> Stmt: + depth = len(self._constraint_scopes) + try: + seq = [self.visit_stmt(stmt) for stmt in op.seq] + finally: + self._pop_constraints(depth) + if all(new.same_as(old) for new, old in zip(seq, op.seq)): + return op + return SeqStmt(seq, op.span) + + def visit_if_then_else_(self, op: IfThenElse) -> Stmt: + condition = self.visit_expr(op.condition) + real_condition = self._real_condition(condition) + depth = len(self._constraint_scopes) + with self.analyzer.constraint_scope(real_condition): + try: + then_case = self.visit_stmt(op.then_case) + finally: + self._pop_constraints(depth) + else_case = None + if op.else_case is not None: + with self.analyzer.constraint_scope(self._negated_condition(real_condition)): + try: + else_case = self.visit_stmt(op.else_case) + finally: + self._pop_constraints(depth) + if ( + condition.same_as(op.condition) + and then_case.same_as(op.then_case) + and ( + (else_case is None and op.else_case is None) + or ( + else_case is not None + and op.else_case is not None + and else_case.same_as(op.else_case) + ) + ) + ): + return op + return IfThenElse(condition, then_case, else_case, op.span) + + def visit_assert_stmt_(self, op: AssertStmt) -> Stmt: + condition = self.visit_expr(op.condition) + self._push_constraint(condition) + message_parts = [self.visit_expr(msg) for msg in op.message_parts] + if condition.same_as(op.condition) and all( + new.same_as(old) for new, old in zip(message_parts, op.message_parts) + ): + return op + return AssertStmt(condition, op.error_kind, message_parts, op.span) + + def visit_let_(self, op: Let) -> PrimExpr: + value = self.visit_expr(op.value) + if self._is_pure(value): + self.analyzer.bind(op.var, value) + body = self.visit_expr(op.body) + if value.same_as(op.value) and body.same_as(op.body): + return op + return Let(op.var, value, body, op.span) + + def visit_reduce_(self, op: Reduce) -> PrimExpr: + for iter_var in op.axis: + self.analyzer.bind(iter_var.var, iter_var.dom) + return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + def visit_call_(self, op: Call) -> PrimExpr: + if getattr(op.op, "name", None) != "tirx.if_then_else": + args = [self.visit_expr(arg) for arg in op.args] + if all(new.same_as(old) for new, old in zip(args, op.args)): + return op + return Call(op.dtype, op.op, args, op.annotations, op.span) + + condition = self.visit_expr(op.args[0]) + with self.analyzer.constraint_scope(condition): + true_value = self.visit_expr(op.args[1]) + with self.analyzer.constraint_scope(self._negated_condition(condition)): + false_value = self.visit_expr(op.args[2]) + if ( + condition.same_as(op.args[0]) + and true_value.same_as(op.args[1]) + and false_value.same_as(op.args[2]) + ): + return op + return Call(op.dtype, op.op, [condition, true_value, false_value], op.annotations, op.span) + + def visit_select_(self, op: Select) -> PrimExpr: + condition = self.visit_expr(op.condition) + with self.analyzer.constraint_scope(condition): + true_value = self.visit_expr(op.true_value) + with self.analyzer.constraint_scope(self._negated_condition(condition)): + false_value = self.visit_expr(op.false_value) + if ( + condition.same_as(op.condition) + and true_value.same_as(op.true_value) + and false_value.same_as(op.false_value) + ): + return op + return Select(condition, true_value, false_value, op.span) diff --git a/tests/python/tirx-transform/test_tir_functor.py b/tests/python/tirx-transform/test_tir_functor.py index 021acd8fb60b..fb23bc14adf2 100644 --- a/tests/python/tirx-transform/test_tir_functor.py +++ b/tests/python/tirx-transform/test_tir_functor.py @@ -21,8 +21,10 @@ from tvm import tirx from tvm.tirx import ( EQ, + GE, LT, Add, + AssertStmt, Cast, Evaluate, FloatImm, @@ -33,7 +35,11 @@ Min, Mul, PyStmtExprMutator, + PyStmtExprMutatorWithAnalyzer, PyStmtExprVisitor, + PyStmtExprVisitorWithAnalyzer, + Select, + SeqStmt, StringImm, Sub, Var, @@ -202,6 +208,65 @@ def visit_add_(self, op: Add): return Add(Mul(a, IntImm("int32", 2)), b) +@tirx.functor.visitor +class AnalyzerAwareVisitor(PyStmtExprVisitorWithAnalyzer): + """Record analyzer facts visible from Python visitor callbacks.""" + + def __init__(self, var): + super().__init__() + self.var = var + self.facts = [] + + def visit_evaluate_(self, op: Evaluate): + if op.value.same_as(self.var): + self.facts.append(self.analyzer.can_prove(GE(self.var, IntImm("int32", 0)))) + self.facts.append(self.analyzer.can_prove(LT(self.var, IntImm("int32", 10)))) + super().visit_evaluate_(op) + + +@tirx.functor.mutator +class AnalyzerAwareMutator(PyStmtExprMutatorWithAnalyzer): + """Use branch constraints from the analyzer to rewrite proven predicates.""" + + def _rewrite_if_proven(self, value): + if value.dtype == "bool" and self.analyzer.can_prove(value): + return IntImm("bool", True) + return value + + def visit_lt_(self, op: LT): + a = self.visit_expr(op.a) + b = self.visit_expr(op.b) + value = op if a.same_as(op.a) and b.same_as(op.b) else LT(a, b, op.span) + return self._rewrite_if_proven(value) + + def visit_ge_(self, op: GE): + a = self.visit_expr(op.a) + b = self.visit_expr(op.b) + value = op if a.same_as(op.a) and b.same_as(op.b) else GE(a, b, op.span) + return self._rewrite_if_proven(value) + + def visit_evaluate_(self, op: Evaluate): + value = self.visit_expr(op.value) + value = self._rewrite_if_proven(value) + if value.same_as(op.value): + return op + return Evaluate(value, op.span) + + +@tirx.functor.visitor +class PredicateVisitor(PyStmtExprVisitorWithAnalyzer): + """Record whether boolean Evaluate nodes are provable in analyzer context.""" + + def __init__(self): + super().__init__() + self.facts = [] + + def visit_evaluate_(self, op: Evaluate): + if op.value.dtype == "bool": + self.facts.append(self.analyzer.can_prove(op.value)) + super().visit_evaluate_(op) + + def test_basic_visitor(): """Test the basic AST printer visitor""" expr = Add(Var("x", dtype="int32"), Var("y", dtype="int32")) @@ -330,6 +395,144 @@ def test_complex_mutator(): assert isinstance(modified_expr.a, Mul) # First operand should be multiplied by 2 +def test_analyzer_aware_visitor_loop_context(): + """Test that analyzer-aware visitors expose loop bounds to Python callbacks.""" + i = Var("i", dtype="int32") + stmt = For( + i, + IntImm("int32", 0), + IntImm("int32", 10), + tirx.ForKind.SERIAL, + Evaluate(i), + ) + + visitor = AnalyzerAwareVisitor(i) + visitor.visit_stmt(stmt) + + assert visitor.facts == [True, True] + + +def test_analyzer_aware_mutator_branch_context(): + """Test that analyzer-aware mutators expose branch predicates to Python callbacks.""" + x = Var("x", dtype="int32") + stmt = IfThenElse( + LT(x, IntImm("int32", 4)), + Evaluate(LT(x, IntImm("int32", 8))), + Evaluate(GE(x, IntImm("int32", 4))), + ) + + result = AnalyzerAwareMutator().visit_stmt(stmt) + + assert isinstance(result, IfThenElse) + assert isinstance(result.then_case.value, IntImm) + assert isinstance(result.else_case.value, IntImm) + assert result.then_case.value.value == 1 + assert result.else_case.value.value == 1 + + +def test_analyzer_aware_visitor_assert_context(): + """Test that assert constraints are visible to later statements in the same sequence.""" + x = Var("x", dtype="int32") + stmt = SeqStmt( + [ + AssertStmt(LT(x, IntImm("int32", 4)), StringImm("ValueError")), + Evaluate(LT(x, IntImm("int32", 8))), + ] + ) + + visitor = PredicateVisitor() + visitor.visit_stmt(stmt) + + assert visitor.facts == [True] + + +def test_analyzer_aware_visitor_pure_bind_context(): + """Test that pure Bind values are visible to later statements in the same sequence.""" + x = Var("x", dtype="int32") + stmt = SeqStmt( + [ + tirx.Bind(x, IntImm("int32", 4)), + Evaluate(GE(x, IntImm("int32", 4))), + ] + ) + + visitor = PredicateVisitor() + visitor.visit_stmt(stmt) + + assert visitor.facts == [True] + + +def test_analyzer_aware_mutator_skips_opaque_bind_context(): + """Test that opaque Bind values are not inserted into the analyzer context.""" + h = Var("h", dtype="handle") + stmt = SeqStmt( + [ + tirx.Bind(h, tirx.tvm_stack_alloca("tvm_ffi_any", 1)), + Evaluate(IntImm("int32", 0)), + ] + ) + + result = AnalyzerAwareMutator().visit_stmt(stmt) + + assert isinstance(result, SeqStmt) + + +def test_analyzer_aware_visitor_branch_assert_does_not_leak(): + """Test that assert constraints inside a branch do not leak to following statements.""" + x = Var("x", dtype="int32") + stmt = SeqStmt( + [ + IfThenElse( + LT(x, IntImm("int32", 4)), + AssertStmt(LT(x, IntImm("int32", 1)), StringImm("ValueError")), + None, + ), + Evaluate(LT(x, IntImm("int32", 1))), + ] + ) + + visitor = PredicateVisitor() + visitor.visit_stmt(stmt) + + assert visitor.facts == [False] + + +def test_analyzer_aware_mutator_select_context(): + """Test that analyzer-aware mutators expose Select branch predicates.""" + x = Var("x", dtype="int32") + stmt = Evaluate( + Select( + LT(x, IntImm("int32", 4)), + LT(x, IntImm("int32", 8)), + GE(x, IntImm("int32", 4)), + ) + ) + + result = AnalyzerAwareMutator().visit_stmt(stmt) + + assert isinstance(result.value, IntImm) + assert result.value.value == 1 + + +def test_analyzer_aware_mutator_if_then_else_call_context(): + """Test that analyzer-aware mutators expose tirx.if_then_else expression predicates.""" + x = Var("x", dtype="int32") + stmt = Evaluate( + tirx.if_then_else( + LT(x, IntImm("int32", 4)), + LT(x, IntImm("int32", 8)), + GE(x, IntImm("int32", 4)), + ) + ) + + result = AnalyzerAwareMutator().visit_stmt(stmt) + + assert isinstance(result.value.args[1], IntImm) + assert isinstance(result.value.args[2], IntImm) + assert result.value.args[1].value == 1 + assert result.value.args[2].value == 1 + + def test_different_expr_types(): """Test visitor with various expression types""" x = Var("x", dtype="int32") From 1fadf0454bb210a85a5ec4c2a30f76bc1242dd48 Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Sun, 24 May 2026 22:06:49 -0400 Subject: [PATCH 2/3] finish3 --- python/tvm/tirx/functor.py | 760 +++++++------- src/tirx/ir/py_functor.cc | 924 +++++++++++++++++- .../python/tirx-transform/test_tir_functor.py | 41 +- 3 files changed, 1344 insertions(+), 381 deletions(-) diff --git a/python/tvm/tirx/functor.py b/python/tvm/tirx/functor.py index 1a379cdec8f3..007b94083422 100644 --- a/python/tvm/tirx/functor.py +++ b/python/tvm/tirx/functor.py @@ -38,14 +38,12 @@ Broadcast, BufferLoad, Call, - CallEffectKind, Cast, Div, FloatImm, FloorDiv, FloorMod, IntImm, - IterVar, Let, Max, Min, @@ -266,6 +264,121 @@ def __init__( ) +@tvm_ffi.register_object("tirx.PyStmtExprVisitorWithAnalyzer") +class _PyStmtExprVisitorWithAnalyzer(tvm_ffi.core.Object): + """ + An internal C++-backed wrapper for analyzer-aware Python StmtExprVisitor. + """ + + def __init__( + self, + f_visit_stmt: Callable | None = None, + f_visit_expr: Callable | None = None, + # Stmt + f_visit_bind: Callable | None = None, + f_visit_attr_stmt: Callable | None = None, + f_visit_if_then_else: Callable | None = None, + f_visit_for: Callable | None = None, + f_visit_while: Callable | None = None, + f_visit_alloc_buffer: Callable | None = None, + f_visit_decl_buffer: Callable | None = None, + f_visit_buffer_store: Callable | None = None, + f_visit_assert_stmt: Callable | None = None, + f_visit_seq_stmt: Callable | None = None, + f_visit_evaluate: Callable | None = None, + f_visit_block: Callable | None = None, + f_visit_sblock_realize: Callable | None = None, + # PrimExpr + f_visit_var: Callable | None = None, + f_visit_size_var: Callable | None = None, + f_visit_buffer_load: Callable | None = None, + f_visit_producer_load: Callable | None = None, + f_visit_let: Callable | None = None, + f_visit_call: Callable | None = None, + f_visit_add: Callable | None = None, + f_visit_sub: Callable | None = None, + f_visit_mul: Callable | None = None, + f_visit_div: Callable | None = None, + f_visit_mod: Callable | None = None, + f_visit_floor_div: Callable | None = None, + f_visit_floor_mod: Callable | None = None, + f_visit_min: Callable | None = None, + f_visit_max: Callable | None = None, + f_visit_eq: Callable | None = None, + f_visit_ne: Callable | None = None, + f_visit_lt: Callable | None = None, + f_visit_le: Callable | None = None, + f_visit_gt: Callable | None = None, + f_visit_ge: Callable | None = None, + f_visit_and: Callable | None = None, + f_visit_or: Callable | None = None, + f_visit_reduce: Callable | None = None, + f_visit_cast: Callable | None = None, + f_visit_not: Callable | None = None, + f_visit_select: Callable | None = None, + f_visit_ramp: Callable | None = None, + f_visit_broadcast: Callable | None = None, + f_visit_shuffle: Callable | None = None, + f_visit_int_imm: Callable | None = None, + f_visit_float_imm: Callable | None = None, + f_visit_string_imm: Callable | None = None, + ) -> None: + """Constructor.""" + self.__init_handle_by_constructor__( + _ffi_api.MakePyStmtExprVisitorWithAnalyzer, # type: ignore + [ + f_visit_stmt, + f_visit_expr, + f_visit_bind, + f_visit_attr_stmt, + f_visit_if_then_else, + f_visit_for, + f_visit_while, + f_visit_alloc_buffer, + f_visit_decl_buffer, + f_visit_buffer_store, + f_visit_assert_stmt, + f_visit_seq_stmt, + f_visit_evaluate, + f_visit_block, + f_visit_sblock_realize, + f_visit_var, + f_visit_size_var, + f_visit_buffer_load, + f_visit_producer_load, + f_visit_let, + f_visit_call, + f_visit_add, + f_visit_sub, + f_visit_mul, + f_visit_div, + f_visit_mod, + f_visit_floor_div, + f_visit_floor_mod, + f_visit_min, + f_visit_max, + f_visit_eq, + f_visit_ne, + f_visit_lt, + f_visit_le, + f_visit_gt, + f_visit_ge, + f_visit_and, + f_visit_or, + f_visit_reduce, + f_visit_cast, + f_visit_not, + f_visit_select, + f_visit_ramp, + f_visit_broadcast, + f_visit_shuffle, + f_visit_int_imm, + f_visit_float_imm, + f_visit_string_imm, + ], + ) + + class PyStmtExprVisitor: """ A Python StmtExprVisitor to define custom visitor for both Stmt and PrimExpr. @@ -947,217 +1060,53 @@ def visit_string_imm_(self, op: StringImm) -> None: _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore -class _AnalyzerContextMixin: - """Shared analyzer context helpers for Python functors.""" - - def _init_analyzer(self, analyzer): +def _analyzer_from_module(mod): + from tvm.arith.analyzer import Analyzer # pylint: disable=import-outside-toplevel + + analyzer = Analyzer.__new__(Analyzer) + analyzer._const_int_bound = mod("const_int_bound") + analyzer._const_int_bound_update = mod("const_int_bound_update") + analyzer._const_int_bound_is_bound = mod("const_int_bound_is_bound") + analyzer._bind = mod("bind") + analyzer._modular_set = mod("modular_set") + analyzer._simplify = mod("Simplify") + analyzer._rewrite_simplify = mod("rewrite_simplify") + analyzer._get_rewrite_simplify_stats = mod("get_rewrite_simplify_stats") + analyzer._reset_rewrite_simplify_stats = mod("reset_rewrite_simplify_stats") + analyzer._canonical_simplify = mod("canonical_simplify") + analyzer._int_set = mod("int_set") + analyzer._enter_constraint_context = mod("enter_constraint_context") + analyzer._can_prove_equal = mod("can_prove_equal") + analyzer._can_prove = mod("can_prove") + analyzer._get_enabled_extensions = mod("get_enabled_extensions") + analyzer._set_enabled_extensions = mod("set_enabled_extensions") + return analyzer + + +class _AnalyzerBackedVisitorMixin: + @property + def analyzer(self): + analyzer = getattr(self, "_analyzer", None) if analyzer is None: - from tvm.arith import Analyzer # pylint: disable=import-outside-toplevel - - analyzer = Analyzer() - self.analyzer = analyzer - self._constraint_scopes = [] - - def _real_condition(self, condition): - if isinstance(condition, Call) and getattr(condition.op, "name", None) == "tirx.likely": - return condition.args[0] - return condition - - def _negated_condition(self, condition): - return self.analyzer.rewrite_simplify(Not(condition)) - - def _is_pure(self, expr): - if isinstance(expr, BufferLoad | ProducerLoad): - return False - if isinstance(expr, Call): - try: - effect = expr.op.get_attr("TCallEffectKind") - except AttributeError: - return False - if effect is None: - return False - effect_value = getattr(effect, "value", effect) - return effect_value <= CallEffectKind.Pure.value and all( - self._is_pure(arg) for arg in expr.args - ) - if isinstance(expr, Let): - return self._is_pure(expr.value) and self._is_pure(expr.body) - if isinstance(expr, Reduce): - return ( - all(self._is_pure(source) for source in expr.source) - and all(self._is_pure(init) for init in expr.init) - and self._is_pure(expr.condition) - ) - if isinstance(expr, Select): - return ( - self._is_pure(expr.condition) - and self._is_pure(expr.true_value) - and self._is_pure(expr.false_value) - ) - if isinstance(expr, Ramp): - return ( - self._is_pure(expr.base) - and self._is_pure(expr.stride) - and self._is_pure(expr.lanes) - ) - if isinstance(expr, Broadcast): - return self._is_pure(expr.value) and self._is_pure(expr.lanes) - if isinstance(expr, Shuffle): - return all(self._is_pure(vec) for vec in expr.vectors) and all( - self._is_pure(index) for index in expr.indices - ) - if isinstance(expr, Cast): - return self._is_pure(expr.value) - if isinstance(expr, Not): - return self._is_pure(expr.a) - if isinstance( - expr, - Add - | Sub - | Mul - | Div - | Mod - | FloorDiv - | FloorMod - | Min - | Max - | EQ - | NE - | LT - | LE - | GT - | GE - | And - | Or, - ): - return self._is_pure(expr.a) and self._is_pure(expr.b) - - return isinstance(expr, Var | SizeVar | IntImm | FloatImm | StringImm) - - def _push_constraint(self, constraint): - scope = self.analyzer.constraint_scope(constraint) - scope.__enter__() - self._constraint_scopes.append(scope) - - def _pop_constraints(self, depth): - while len(self._constraint_scopes) > depth: - self._constraint_scopes.pop().__exit__(None, None, None) - - -class PyStmtExprVisitorWithAnalyzer(PyStmtExprVisitor, _AnalyzerContextMixin): - """A Python StmtExprVisitor that maintains an arithmetic analyzer context. - - The analyzer is available as ``self.analyzer`` from user callbacks. The default - traversal binds loop variables, block iter variables, let variables, and branch - conditions before visiting nested nodes, so callbacks can query the analyzer - using the surrounding IR context. - """ - - def __init__(self, analyzer=None): - super().__init__() - self._init_analyzer(analyzer) - - def visit_for_(self, op: For) -> None: - from tvm.ir import Range # pylint: disable=import-outside-toplevel - - depth = len(self._constraint_scopes) - self.visit_expr(op.min) - self.visit_expr(op.extent) - if op.step is not None: - self.visit_expr(op.step) - self.analyzer.bind(op.loop_var, Range.from_min_extent(op.min, op.extent)) - try: - self.visit_stmt(op.body) - finally: - self._pop_constraints(depth) + mod = _ffi_api.PyStmtExprVisitorWithAnalyzerGetAnalyzer(self._outer()) # type: ignore + analyzer = _analyzer_from_module(mod) + self._analyzer = analyzer + return analyzer - def visit_attr_stmt_(self, op: AttrStmt) -> None: - from tvm.ir import Range # pylint: disable=import-outside-toplevel - - depth = len(self._constraint_scopes) - self.visit_expr(op.value) - if op.attr_key in ("thread_extent", "virtual_thread") and isinstance(op.node, IterVar): - self.analyzer.bind( - op.node.var, Range.from_min_extent(IntImm(op.value.dtype, 0), op.value) - ) - try: - self.visit_stmt(op.body) - finally: - self._pop_constraints(depth) - - def visit_sblock_(self, op: SBlock) -> None: - depth = len(self._constraint_scopes) - try: - for iter_var in op.iter_vars: - self.analyzer.bind(iter_var.var, iter_var.dom) - _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore - finally: - self._pop_constraints(depth) - - def visit_bind_(self, op: Bind) -> None: - self.visit_expr(op.value) - if self._is_pure(op.value): - self.analyzer.bind(op.var, op.value) - - def visit_seq_stmt_(self, op: SeqStmt) -> None: - depth = len(self._constraint_scopes) - try: - _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore - finally: - self._pop_constraints(depth) - - def visit_if_then_else_(self, op: IfThenElse) -> None: - condition = self._real_condition(op.condition) - self.visit_expr(op.condition) - depth = len(self._constraint_scopes) - with self.analyzer.constraint_scope(condition): - try: - self.visit_stmt(op.then_case) - finally: - self._pop_constraints(depth) - if op.else_case is not None: - with self.analyzer.constraint_scope(self._negated_condition(condition)): - try: - self.visit_stmt(op.else_case) - finally: - self._pop_constraints(depth) - - def visit_assert_stmt_(self, op: AssertStmt) -> None: - self.visit_expr(op.condition) - self._push_constraint(op.condition) - self.visit_expr(op.error_kind) - for msg in op.message_parts: - self.visit_expr(msg) - def visit_let_(self, op: Let) -> None: - self.visit_expr(op.value) - if self._is_pure(op.value): - self.analyzer.bind(op.var, op.value) - self.visit_expr(op.body) - - def visit_call_(self, op: Call) -> None: - if getattr(op.op, "name", None) != "tirx.if_then_else": - _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore - return +class PyStmtExprVisitorWithAnalyzer(PyStmtExprVisitor, _AnalyzerBackedVisitorMixin): + """A C++-backed Python StmtExprVisitor with an arithmetic analyzer context.""" - condition = op.args[0] - self.visit_expr(condition) - with self.analyzer.constraint_scope(condition): - self.visit_expr(op.args[1]) - with self.analyzer.constraint_scope(self._negated_condition(condition)): - self.visit_expr(op.args[2]) + _tvm_metadata = { + **PyStmtExprVisitor._tvm_metadata, + "cls": _PyStmtExprVisitorWithAnalyzer, + } - def visit_select_(self, op: Select) -> None: - self.visit_expr(op.condition) - with self.analyzer.constraint_scope(op.condition): - self.visit_expr(op.true_value) - with self.analyzer.constraint_scope(self._negated_condition(op.condition)): - self.visit_expr(op.false_value) + def visit_stmt(self, stmt: Stmt) -> None: + _ffi_api.PyStmtExprVisitorWithAnalyzerVisitStmt(self._outer(), stmt) # type: ignore - def visit_reduce_(self, op: Reduce) -> None: - for iter_var in op.axis: - self.analyzer.bind(iter_var.var, iter_var.dom) - _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore + def visit_expr(self, expr: PrimExpr) -> None: + _ffi_api.PyStmtExprVisitorWithAnalyzerVisitExpr(self._outer(), expr) # type: ignore @tvm_ffi.register_object("tirx.PyStmtExprMutator") @@ -1280,6 +1229,121 @@ def __init__( ) +@tvm_ffi.register_object("tirx.PyStmtExprMutatorWithAnalyzer") +class _PyStmtExprMutatorWithAnalyzer(tvm_ffi.core.Object): + """ + An internal C++-backed wrapper for analyzer-aware Python StmtExprMutator. + """ + + def __init__( + self, + f_visit_stmt: Callable | None = None, + f_visit_expr: Callable | None = None, + # Stmt + f_visit_bind: Callable | None = None, + f_visit_attr_stmt: Callable | None = None, + f_visit_if_then_else: Callable | None = None, + f_visit_for: Callable | None = None, + f_visit_while: Callable | None = None, + f_visit_alloc_buffer: Callable | None = None, + f_visit_decl_buffer: Callable | None = None, + f_visit_buffer_store: Callable | None = None, + f_visit_assert_stmt: Callable | None = None, + f_visit_seq_stmt: Callable | None = None, + f_visit_evaluate: Callable | None = None, + f_visit_block: Callable | None = None, + f_visit_sblock_realize: Callable | None = None, + # PrimExpr + f_visit_var: Callable | None = None, + f_visit_size_var: Callable | None = None, + f_visit_buffer_load: Callable | None = None, + f_visit_producer_load: Callable | None = None, + f_visit_let: Callable | None = None, + f_visit_call: Callable | None = None, + f_visit_add: Callable | None = None, + f_visit_sub: Callable | None = None, + f_visit_mul: Callable | None = None, + f_visit_div: Callable | None = None, + f_visit_mod: Callable | None = None, + f_visit_floor_div: Callable | None = None, + f_visit_floor_mod: Callable | None = None, + f_visit_min: Callable | None = None, + f_visit_max: Callable | None = None, + f_visit_eq: Callable | None = None, + f_visit_ne: Callable | None = None, + f_visit_lt: Callable | None = None, + f_visit_le: Callable | None = None, + f_visit_gt: Callable | None = None, + f_visit_ge: Callable | None = None, + f_visit_and: Callable | None = None, + f_visit_or: Callable | None = None, + f_visit_reduce: Callable | None = None, + f_visit_cast: Callable | None = None, + f_visit_not: Callable | None = None, + f_visit_select: Callable | None = None, + f_visit_ramp: Callable | None = None, + f_visit_broadcast: Callable | None = None, + f_visit_shuffle: Callable | None = None, + f_visit_int_imm: Callable | None = None, + f_visit_float_imm: Callable | None = None, + f_visit_string_imm: Callable | None = None, + ) -> None: + """Constructor.""" + self.__init_handle_by_constructor__( + _ffi_api.MakePyStmtExprMutatorWithAnalyzer, # type: ignore + [ + f_visit_stmt, + f_visit_expr, + f_visit_bind, + f_visit_attr_stmt, + f_visit_if_then_else, + f_visit_for, + f_visit_while, + f_visit_alloc_buffer, + f_visit_decl_buffer, + f_visit_buffer_store, + f_visit_assert_stmt, + f_visit_seq_stmt, + f_visit_evaluate, + f_visit_block, + f_visit_sblock_realize, + f_visit_var, + f_visit_size_var, + f_visit_buffer_load, + f_visit_producer_load, + f_visit_let, + f_visit_call, + f_visit_add, + f_visit_sub, + f_visit_mul, + f_visit_div, + f_visit_mod, + f_visit_floor_div, + f_visit_floor_mod, + f_visit_min, + f_visit_max, + f_visit_eq, + f_visit_ne, + f_visit_lt, + f_visit_le, + f_visit_gt, + f_visit_ge, + f_visit_and, + f_visit_or, + f_visit_reduce, + f_visit_cast, + f_visit_not, + f_visit_select, + f_visit_ramp, + f_visit_broadcast, + f_visit_shuffle, + f_visit_int_imm, + f_visit_float_imm, + f_visit_string_imm, + ], + ) + + class PyStmtExprMutator: """ A Python StmtExprMutator to define custom mutator for both Stmt and PrimExpr. @@ -2193,180 +2257,118 @@ def visit_string_imm_(self, op: StringImm) -> PrimExpr: return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore -class PyStmtExprMutatorWithAnalyzer(PyStmtExprMutator, _AnalyzerContextMixin): - """A Python StmtExprMutator that maintains an arithmetic analyzer context. - - The analyzer is available as ``self.analyzer`` from user callbacks. The default - mutation binds loop variables, block iter variables, let variables, and branch - conditions before mutating nested nodes, so callbacks can query the analyzer - using the surrounding IR context. - """ - - def __init__(self, analyzer=None): - super().__init__() - self._init_analyzer(analyzer) - - def visit_for_(self, op: For) -> Stmt: - from tvm.ir import Range # pylint: disable=import-outside-toplevel - - depth = len(self._constraint_scopes) - min_value = self.visit_expr(op.min) - extent = self.visit_expr(op.extent) - step = self.visit_expr(op.step) if op.step is not None else None - self.analyzer.bind(op.loop_var, Range.from_min_extent(min_value, extent)) - try: - body = self.visit_stmt(op.body) - finally: - self._pop_constraints(depth) - if ( - min_value.same_as(op.min) - and extent.same_as(op.extent) - and body.same_as(op.body) - and ( - (step is None and op.step is None) - or (step is not None and op.step is not None and step.same_as(op.step)) - ) - ): - return op - return For( - op.loop_var, - min_value, - extent, - op.kind, - body, - op.thread_binding, - op.annotations, - step, - op.span, - ) - - def visit_attr_stmt_(self, op: AttrStmt) -> Stmt: - from tvm.ir import Range # pylint: disable=import-outside-toplevel - - depth = len(self._constraint_scopes) - value = self.visit_expr(op.value) - if op.attr_key in ("thread_extent", "virtual_thread") and isinstance(op.node, IterVar): - self.analyzer.bind(op.node.var, Range.from_min_extent(IntImm(value.dtype, 0), value)) - try: - body = self.visit_stmt(op.body) - finally: - self._pop_constraints(depth) - if value.same_as(op.value) and body.same_as(op.body): - return op - return AttrStmt(op.node, op.attr_key, value, body, op.span) - - def visit_sblock_(self, op: SBlock) -> Stmt: - depth = len(self._constraint_scopes) - try: - for iter_var in op.iter_vars: - self.analyzer.bind(iter_var.var, iter_var.dom) - return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op) # type: ignore - finally: - self._pop_constraints(depth) - - def visit_bind_(self, op: Bind) -> Stmt: - value = self.visit_expr(op.value) - if self._is_pure(value): - self.analyzer.bind(op.var, value) - if value.same_as(op.value): - return op - return Bind(op.var, value, op.span) +class _AnalyzerBackedMutatorMixin: + @property + def analyzer(self): + analyzer = getattr(self, "_analyzer", None) + if analyzer is None: + mod = _ffi_api.PyStmtExprMutatorWithAnalyzerGetAnalyzer(self._outer()) # type: ignore + analyzer = _analyzer_from_module(mod) + self._analyzer = analyzer + return analyzer - def visit_seq_stmt_(self, op: SeqStmt) -> Stmt: - depth = len(self._constraint_scopes) - try: - seq = [self.visit_stmt(stmt) for stmt in op.seq] - finally: - self._pop_constraints(depth) - if all(new.same_as(old) for new, old in zip(seq, op.seq)): - return op - return SeqStmt(seq, op.span) - def visit_if_then_else_(self, op: IfThenElse) -> Stmt: - condition = self.visit_expr(op.condition) - real_condition = self._real_condition(condition) - depth = len(self._constraint_scopes) - with self.analyzer.constraint_scope(real_condition): - try: - then_case = self.visit_stmt(op.then_case) - finally: - self._pop_constraints(depth) - else_case = None - if op.else_case is not None: - with self.analyzer.constraint_scope(self._negated_condition(real_condition)): - try: - else_case = self.visit_stmt(op.else_case) - finally: - self._pop_constraints(depth) - if ( - condition.same_as(op.condition) - and then_case.same_as(op.then_case) - and ( - (else_case is None and op.else_case is None) - or ( - else_case is not None - and op.else_case is not None - and else_case.same_as(op.else_case) - ) - ) - ): - return op - return IfThenElse(condition, then_case, else_case, op.span) - - def visit_assert_stmt_(self, op: AssertStmt) -> Stmt: - condition = self.visit_expr(op.condition) - self._push_constraint(condition) - message_parts = [self.visit_expr(msg) for msg in op.message_parts] - if condition.same_as(op.condition) and all( - new.same_as(old) for new, old in zip(message_parts, op.message_parts) - ): - return op - return AssertStmt(condition, op.error_kind, message_parts, op.span) +class PyStmtExprMutatorWithAnalyzer(PyStmtExprMutator, _AnalyzerBackedMutatorMixin): + """A C++-backed Python StmtExprMutator with an arithmetic analyzer context.""" - def visit_let_(self, op: Let) -> PrimExpr: - value = self.visit_expr(op.value) - if self._is_pure(value): - self.analyzer.bind(op.var, value) - body = self.visit_expr(op.body) - if value.same_as(op.value) and body.same_as(op.body): - return op - return Let(op.var, value, body, op.span) - - def visit_reduce_(self, op: Reduce) -> PrimExpr: - for iter_var in op.axis: - self.analyzer.bind(iter_var.var, iter_var.dom) - return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + _tvm_metadata = { + **PyStmtExprMutator._tvm_metadata, + "cls": _PyStmtExprMutatorWithAnalyzer, + } - def visit_call_(self, op: Call) -> PrimExpr: - if getattr(op.op, "name", None) != "tirx.if_then_else": - args = [self.visit_expr(arg) for arg in op.args] - if all(new.same_as(old) for new, old in zip(args, op.args)): - return op - return Call(op.dtype, op.op, args, op.annotations, op.span) - - condition = self.visit_expr(op.args[0]) - with self.analyzer.constraint_scope(condition): - true_value = self.visit_expr(op.args[1]) - with self.analyzer.constraint_scope(self._negated_condition(condition)): - false_value = self.visit_expr(op.args[2]) - if ( - condition.same_as(op.args[0]) - and true_value.same_as(op.args[1]) - and false_value.same_as(op.args[2]) - ): - return op - return Call(op.dtype, op.op, [condition, true_value, false_value], op.annotations, op.span) + def visit_expr(self, expr: PrimExpr) -> PrimExpr: + return _ffi_api.PyStmtExprMutatorWithAnalyzerVisitExpr(self._outer(), expr) # type: ignore - def visit_select_(self, op: Select) -> PrimExpr: - condition = self.visit_expr(op.condition) - with self.analyzer.constraint_scope(condition): - true_value = self.visit_expr(op.true_value) - with self.analyzer.constraint_scope(self._negated_condition(condition)): - false_value = self.visit_expr(op.false_value) - if ( - condition.same_as(op.condition) - and true_value.same_as(op.true_value) - and false_value.same_as(op.false_value) - ): - return op - return Select(condition, true_value, false_value, op.span) + def visit_stmt(self, stmt: Stmt) -> Stmt: + return _ffi_api.PyStmtExprMutatorWithAnalyzerVisitStmt(self._outer(), stmt) # type: ignore + + +def _make_analyzer_visitor_default_stmt(): + def visit(self, op): + _ffi_api.PyStmtExprVisitorWithAnalyzerDefaultVisitStmt(self._outer(), op) # type: ignore + + return visit + + +def _make_analyzer_visitor_default_expr(): + def visit(self, op): + _ffi_api.PyStmtExprVisitorWithAnalyzerDefaultVisitExpr(self._outer(), op) # type: ignore + + return visit + + +def _make_analyzer_mutator_default_stmt(): + def visit(self, op): + return _ffi_api.PyStmtExprMutatorWithAnalyzerDefaultVisitStmt(self._outer(), op) # type: ignore + + return visit + + +def _make_analyzer_mutator_default_expr(): + def visit(self, op): + return _ffi_api.PyStmtExprMutatorWithAnalyzerDefaultVisitExpr(self._outer(), op) # type: ignore + + return visit + + +_STMT_VISIT_METHODS = [ + "visit_bind_", + "visit_attr_stmt_", + "visit_if_then_else_", + "visit_for_", + "visit_while_", + "visit_alloc_buffer_", + "visit_decl_buffer_", + "visit_buffer_store_", + "visit_assert_stmt_", + "visit_seq_stmt_", + "visit_evaluate_", + "visit_sblock_", + "visit_sblock_realize_", +] + +_EXPR_VISIT_METHODS = [ + "visit_var_", + "visit_size_var_", + "visit_buffer_load_", + "visit_producer_load_", + "visit_let_", + "visit_call_", + "visit_add_", + "visit_sub_", + "visit_mul_", + "visit_div_", + "visit_mod_", + "visit_floor_div_", + "visit_floor_mod_", + "visit_min_", + "visit_max_", + "visit_eq_", + "visit_ne_", + "visit_lt_", + "visit_le_", + "visit_gt_", + "visit_ge_", + "visit_and_", + "visit_or_", + "visit_reduce_", + "visit_cast_", + "visit_not_", + "visit_select_", + "visit_ramp_", + "visit_broadcast_", + "visit_shuffle_", + "visit_int_imm_", + "visit_float_imm_", + "visit_string_imm_", +] + +for _method in _STMT_VISIT_METHODS: + setattr(PyStmtExprVisitorWithAnalyzer, _method, _make_analyzer_visitor_default_stmt()) + setattr(PyStmtExprMutatorWithAnalyzer, _method, _make_analyzer_mutator_default_stmt()) + +for _method in _EXPR_VISIT_METHODS: + setattr(PyStmtExprVisitorWithAnalyzer, _method, _make_analyzer_visitor_default_expr()) + setattr(PyStmtExprMutatorWithAnalyzer, _method, _make_analyzer_mutator_default_expr()) + +del _method diff --git a/src/tirx/ir/py_functor.cc b/src/tirx/ir/py_functor.cc index 65d7c3c1b45b..e24618a33e29 100644 --- a/src/tirx/ir/py_functor.cc +++ b/src/tirx/ir/py_functor.cc @@ -23,13 +23,131 @@ * StmtExprVisitor/StmtExprMutator. */ +#include #include +#include +#include +#include +#include +#include #include +#include #include +#include +#include +#include + namespace tvm { namespace tirx { +namespace { + +ffi::Function MakeAnalyzerModule(std::shared_ptr analyzer) { + using ffi::Function; + using ffi::TypedFunction; + auto f = [analyzer](std::string name) -> ffi::Function { + if (name == "const_int_bound") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + *ret = analyzer->const_int_bound(args[0].cast()); + }); + } else if (name == "modular_set") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + *ret = analyzer->modular_set(args[0].cast()); + }); + } else if (name == "const_int_bound_update") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + analyzer->const_int_bound.Update(args[0].cast(), args[1].cast(), + args[2].cast()); + }); + } else if (name == "const_int_bound_is_bound") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + *ret = analyzer->const_int_bound.IsBound(args[0].cast()); + }); + } else if (name == "Simplify") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + if (args.size() == 1) { + *ret = analyzer->Simplify(args[0].cast()); + } else if (args.size() == 2) { + *ret = analyzer->Simplify(args[0].cast(), args[1].cast()); + } else { + TVM_FFI_THROW(InternalError) << "Invalid size of argument (" << args.size() << ")"; + } + }); + } else if (name == "rewrite_simplify") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + *ret = analyzer->rewrite_simplify(args[0].cast()); + }); + } else if (name == "get_rewrite_simplify_stats") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + *ret = analyzer->rewrite_simplify.GetStatsCounters(); + }); + } else if (name == "reset_rewrite_simplify_stats") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + analyzer->rewrite_simplify.ResetStatsCounters(); + }); + } else if (name == "canonical_simplify") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + *ret = analyzer->canonical_simplify(args[0].cast()); + }); + } else if (name == "int_set") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + *ret = analyzer->int_set(args[0].cast(), + args[1].cast>()); + }); + } else if (name == "bind") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + bool allow_override = args.size() >= 3 && args[2].cast(); + if (auto opt_range = args[1].try_cast()) { + analyzer->Bind(args[0].cast(), opt_range.value(), allow_override); + } else { + analyzer->Bind(args[0].cast(), args[1].cast(), allow_override); + } + }); + } else if (name == "can_prove") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + int strength = args[1].cast(); + *ret = analyzer->CanProve(args[0].cast(), + static_cast(strength)); + }); + } else if (name == "enter_constraint_context") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + auto ctx = std::shared_ptr>( + new With(analyzer.get(), args[0].cast())); + auto fexit = [ctx](ffi::PackedArgs, ffi::Any*) mutable { ctx.reset(); }; + *ret = ffi::Function::FromPacked(fexit); + }); + } else if (name == "can_prove_equal") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + *ret = analyzer->CanProveEqual(args[0].cast(), args[1].cast()); + }); + } else if (name == "get_enabled_extensions") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + *ret = static_cast(analyzer->rewrite_simplify.GetEnabledExtensions()); + }); + } else if (name == "set_enabled_extensions") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + int64_t flags = args[0].cast(); + analyzer->rewrite_simplify.SetEnabledExtensions( + static_cast(flags)); + }); + } + return ffi::Function(); + }; + return ffi::TypedFunction(f); +} + +PrimExpr ExtractRealCondition(PrimExpr condition) { + if (auto call = condition.as()) { + if (call->op.same_as(builtin::likely())) { + return call->args[0]; + } + } + return condition; +} + +} // namespace + // ================================================ // Helper Macros // ================================================ @@ -800,6 +918,757 @@ class PyStmtExprMutator : public ffi::ObjectRef { PyStmtExprMutatorNode); }; +#define PY_STMT_EXPR_FUNCTOR_CALLBACKS(V) \ + V(f_visit_stmt) \ + V(f_visit_expr) \ + V(f_visit_bind) \ + V(f_visit_attr_stmt) \ + V(f_visit_if_then_else) \ + V(f_visit_for) \ + V(f_visit_while) \ + V(f_visit_alloc_buffer) \ + V(f_visit_decl_buffer) \ + V(f_visit_buffer_store) \ + V(f_visit_assert_stmt) \ + V(f_visit_seq_stmt) \ + V(f_visit_evaluate) \ + V(f_visit_block) \ + V(f_visit_sblock_realize) \ + V(f_visit_var) \ + V(f_visit_size_var) \ + V(f_visit_buffer_load) \ + V(f_visit_producer_load) \ + V(f_visit_let) \ + V(f_visit_call) \ + V(f_visit_add) \ + V(f_visit_sub) \ + V(f_visit_mul) \ + V(f_visit_div) \ + V(f_visit_mod) \ + V(f_visit_floor_div) \ + V(f_visit_floor_mod) \ + V(f_visit_min) \ + V(f_visit_max) \ + V(f_visit_eq) \ + V(f_visit_ne) \ + V(f_visit_lt) \ + V(f_visit_le) \ + V(f_visit_gt) \ + V(f_visit_ge) \ + V(f_visit_and) \ + V(f_visit_or) \ + V(f_visit_reduce) \ + V(f_visit_cast) \ + V(f_visit_not) \ + V(f_visit_select) \ + V(f_visit_ramp) \ + V(f_visit_broadcast) \ + V(f_visit_shuffle) \ + V(f_visit_int_imm) \ + V(f_visit_float_imm) \ + V(f_visit_string_imm) + +template +void SetStmtExprFunctorCallbacks(TNode* node, + const ffi::Array>& callbacks) { + int index = 0; +#define SET_CALLBACK(FIELD) node->FIELD = callbacks[index++].value_or(ffi::Function(nullptr)); + PY_STMT_EXPR_FUNCTOR_CALLBACKS(SET_CALLBACK) +#undef SET_CALLBACK + TVM_FFI_ICHECK_EQ(index, callbacks.size()); +} + +#define PY_ANALYZER_STMT_VISITOR_DEFAULT_DISPATCH(OP, METHOD) \ + vtable.template set_dispatch([](const ffi::ObjectRef& n, TSelf* self) { \ + self->METHOD(static_cast(n.get())); \ + }); + +#define PY_ANALYZER_EXPR_VISITOR_DEFAULT_DISPATCH(OP, METHOD) \ + vtable.template set_dispatch([](const ffi::ObjectRef& n, TSelf* self) { \ + self->METHOD(static_cast(n.get())); \ + }); + +#define PY_ANALYZER_STMT_VISITOR_BASE_DISPATCH(OP) \ + vtable.template set_dispatch([](const ffi::ObjectRef& n, TSelf* self) { \ + self->StmtExprVisitor::VisitStmt_(static_cast(n.get())); \ + }); + +#define PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(OP) \ + vtable.template set_dispatch([](const ffi::ObjectRef& n, TSelf* self) { \ + self->StmtExprVisitor::VisitExpr_(static_cast(n.get())); \ + }); + +#define PY_ANALYZER_STMT_MUTATOR_DEFAULT_DISPATCH(OP, METHOD) \ + vtable.template set_dispatch([](const ffi::ObjectRef& n, TSelf* self) { \ + return self->METHOD(static_cast(n.get())); \ + }); + +#define PY_ANALYZER_EXPR_MUTATOR_DEFAULT_DISPATCH(OP, METHOD) \ + vtable.template set_dispatch([](const ffi::ObjectRef& n, TSelf* self) { \ + return self->METHOD(static_cast(n.get())); \ + }); + +#define PY_ANALYZER_STMT_MUTATOR_BASE_DISPATCH(OP) \ + vtable.template set_dispatch([](const ffi::ObjectRef& n, TSelf* self) { \ + return self->StmtExprMutator::VisitStmt_(static_cast(n.get())); \ + }); + +#define PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(OP) \ + vtable.template set_dispatch([](const ffi::ObjectRef& n, TSelf* self) { \ + return self->StmtExprMutator::VisitExpr_(static_cast(n.get())); \ + }); + +class PyStmtExprVisitorWithAnalyzerNode : public PyStmtExprVisitorNode { + private: + using TSelf = PyStmtExprVisitorWithAnalyzerNode; + using FExprType = tvm::NodeFunctor; + using FStmtType = tvm::NodeFunctor; + + public: + PyStmtExprVisitorWithAnalyzerNode() = default; + PyStmtExprVisitorWithAnalyzerNode(const PyStmtExprVisitorWithAnalyzerNode& other) + : PyStmtExprVisitorNode(other) {} + + ffi::Function GetAnalyzer() { return MakeAnalyzerModule(analyzer_); } + + void DefaultVisitExprWithAnalyzer(const PrimExpr& expr) { + static FExprType vtable = InitExprVTable(); + vtable(expr, this); + } + + void DefaultVisitStmtWithAnalyzer(const Stmt& stmt) { + static FStmtType vtable = InitStmtVTable(); + vtable(stmt, this); + } + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("tirx.PyStmtExprVisitorWithAnalyzer", + PyStmtExprVisitorWithAnalyzerNode, PyStmtExprVisitorNode); + + private: + std::shared_ptr analyzer_{std::make_shared()}; + ScopeStack> constraint_scope_; + + void VisitStmt_(const ForNode* op) override { + if (f_visit_for != nullptr) { + f_visit_for(op); + } else { + DefaultVisitFor(op); + } + } + + void VisitStmt_(const SBlockNode* op) override { + if (f_visit_block != nullptr) { + f_visit_block(op); + } else { + DefaultVisitSBlock(op); + } + } + + void VisitStmt_(const BindNode* op) override { + if (f_visit_bind != nullptr) { + f_visit_bind(op); + } else { + DefaultVisitBind(op); + } + } + + void VisitStmt_(const IfThenElseNode* op) override { + if (f_visit_if_then_else != nullptr) { + f_visit_if_then_else(op); + } else { + DefaultVisitIfThenElse(op); + } + } + + void VisitStmt_(const AttrStmtNode* op) override { + if (f_visit_attr_stmt != nullptr) { + f_visit_attr_stmt(op); + } else { + DefaultVisitAttrStmt(op); + } + } + + void VisitStmt_(const AssertStmtNode* op) override { + if (f_visit_assert_stmt != nullptr) { + f_visit_assert_stmt(op); + } else { + DefaultVisitAssertStmt(op); + } + } + + void VisitStmt_(const SeqStmtNode* op) override { + if (f_visit_seq_stmt != nullptr) { + f_visit_seq_stmt(op); + } else { + DefaultVisitSeqStmt(op); + } + } + + void VisitExpr_(const CallNode* op) override { + if (f_visit_call != nullptr) { + f_visit_call(op); + } else { + DefaultVisitCall(op); + } + } + + void VisitExpr_(const LetNode* op) override { + if (f_visit_let != nullptr) { + f_visit_let(op); + } else { + DefaultVisitLet(op); + } + } + + void VisitExpr_(const ReduceNode* op) override { + if (f_visit_reduce != nullptr) { + f_visit_reduce(op); + } else { + DefaultVisitReduce(op); + } + } + + void DefaultVisitFor(const ForNode* op) { + constraint_scope_.WithNewScope([&]() { + analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); + StmtExprVisitor::VisitStmt_(op); + }); + } + + void DefaultVisitSBlock(const SBlockNode* op) { + constraint_scope_.WithNewScope([&]() { + for (const IterVar& iter_var : op->iter_vars) { + analyzer_->Bind(iter_var->var, iter_var->dom); + } + StmtExprVisitor::VisitStmt_(op); + }); + } + + void DefaultVisitBind(const BindNode* op) { + this->VisitExpr(op->value); + if (SideEffect(op->value) <= CallEffectKind::kPure) { + analyzer_->Bind(op->var, op->value); + } + } + + void DefaultVisitIfThenElse(const IfThenElseNode* op) { + constraint_scope_.WithNewScope([&]() { + this->VisitExpr(op->condition); + PrimExpr real_condition = ExtractRealCondition(op->condition); + constraint_scope_.WithNewScope([&]() { + constraint_scope_.Current().Emplace(analyzer_.get(), real_condition); + this->VisitStmt(op->then_case); + }); + if (op->else_case) { + constraint_scope_.WithNewScope([&]() { + constraint_scope_.Current().Emplace(analyzer_.get(), + analyzer_->rewrite_simplify(Not(real_condition))); + this->VisitStmt(op->else_case.value()); + }); + } + }); + } + + void DefaultVisitAttrStmt(const AttrStmtNode* op) { + constraint_scope_.WithNewScope([&]() { + if (op->attr_key == tirx::attr::thread_extent || + op->attr_key == s_tir::attr::virtual_thread) { + IterVar iv = Downcast(op->node); + TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U); + analyzer_->Bind(iv->var, Range::FromMinExtent(IntImm(op->value->dtype, 0), op->value)); + } + StmtExprVisitor::VisitStmt_(op); + }); + } + + void DefaultVisitAssertStmt(const AssertStmtNode* op) { + this->VisitExpr(op->condition); + constraint_scope_.Current().Emplace(analyzer_.get(), op->condition); + this->VisitExpr(op->error_kind); + for (const StringImm& message : op->message_parts) { + this->VisitExpr(message); + } + } + + void DefaultVisitSeqStmt(const SeqStmtNode* op) { StmtExprVisitor::VisitStmt_(op); } + + void DefaultVisitCall(const CallNode* op) { + static auto op_if_then_else = Op::Get("tirx.if_then_else"); + if (op->op.same_as(op_if_then_else)) { + PrimExpr cond = op->args[0]; + this->VisitExpr(op->args[0]); + constraint_scope_.WithNewScope([&]() { + constraint_scope_.Current().Emplace(analyzer_.get(), cond); + this->VisitExpr(op->args[1]); + }); + constraint_scope_.WithNewScope([&]() { + constraint_scope_.Current().Emplace(analyzer_.get(), + analyzer_->rewrite_simplify(Not(cond))); + this->VisitExpr(op->args[2]); + }); + } else { + StmtExprVisitor::VisitExpr_(op); + } + } + + void DefaultVisitLet(const LetNode* op) { + this->VisitExpr(op->value); + if (SideEffect(op->value) <= CallEffectKind::kPure) { + analyzer_->Bind(op->var, op->value); + } + this->VisitExpr(op->body); + } + + void DefaultVisitReduce(const ReduceNode* op) { + for (const IterVar& iv : op->axis) { + analyzer_->Bind(iv->var, iv->dom); + } + StmtExprVisitor::VisitExpr_(op); + } + + static FStmtType InitStmtVTable() { + FStmtType vtable; + PY_ANALYZER_STMT_VISITOR_DEFAULT_DISPATCH(BindNode, DefaultVisitBind); + PY_ANALYZER_STMT_VISITOR_DEFAULT_DISPATCH(AttrStmtNode, DefaultVisitAttrStmt); + PY_ANALYZER_STMT_VISITOR_DEFAULT_DISPATCH(IfThenElseNode, DefaultVisitIfThenElse); + PY_ANALYZER_STMT_VISITOR_DEFAULT_DISPATCH(ForNode, DefaultVisitFor); + PY_ANALYZER_STMT_VISITOR_BASE_DISPATCH(WhileNode); + PY_ANALYZER_STMT_VISITOR_BASE_DISPATCH(AllocBufferNode); + PY_ANALYZER_STMT_VISITOR_BASE_DISPATCH(DeclBufferNode); + PY_ANALYZER_STMT_VISITOR_BASE_DISPATCH(BufferStoreNode); + PY_ANALYZER_STMT_VISITOR_DEFAULT_DISPATCH(AssertStmtNode, DefaultVisitAssertStmt); + PY_ANALYZER_STMT_VISITOR_DEFAULT_DISPATCH(SeqStmtNode, DefaultVisitSeqStmt); + PY_ANALYZER_STMT_VISITOR_BASE_DISPATCH(EvaluateNode); + PY_ANALYZER_STMT_VISITOR_DEFAULT_DISPATCH(SBlockNode, DefaultVisitSBlock); + PY_ANALYZER_STMT_VISITOR_BASE_DISPATCH(SBlockRealizeNode); + vtable.Finalize(); + return vtable; + } + + static FExprType InitExprVTable() { + FExprType vtable; + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(VarNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(SizeVarNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(BufferLoadNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(ProducerLoadNode); + PY_ANALYZER_EXPR_VISITOR_DEFAULT_DISPATCH(LetNode, DefaultVisitLet); + PY_ANALYZER_EXPR_VISITOR_DEFAULT_DISPATCH(CallNode, DefaultVisitCall); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(AddNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(SubNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(MulNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(DivNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(ModNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(FloorDivNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(FloorModNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(MinNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(MaxNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(EQNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(NENode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(LTNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(LENode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(GTNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(GENode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(AndNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(OrNode); + PY_ANALYZER_EXPR_VISITOR_DEFAULT_DISPATCH(ReduceNode, DefaultVisitReduce); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(CastNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(NotNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(SelectNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(RampNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(ShuffleNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(BroadcastNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(IntImmNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(FloatImmNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(StringImmNode); + vtable.Finalize(); + return vtable; + } +}; + +class PyStmtExprVisitorWithAnalyzer : public ffi::ObjectRef { + public: + explicit PyStmtExprVisitorWithAnalyzer(ffi::ObjectPtr data) + : ffi::ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } + + TVM_DLL static PyStmtExprVisitorWithAnalyzer MakePyStmtExprVisitorWithAnalyzer( + ffi::Array> callbacks) { + ffi::ObjectPtr n = + ffi::make_object(); + SetStmtExprFunctorCallbacks(n.get(), callbacks); + return PyStmtExprVisitorWithAnalyzer(n); + } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PyStmtExprVisitorWithAnalyzer, ffi::ObjectRef, + PyStmtExprVisitorWithAnalyzerNode); +}; + +class PyStmtExprMutatorWithAnalyzerNode : public PyStmtExprMutatorNode { + private: + using TSelf = PyStmtExprMutatorWithAnalyzerNode; + using FExprType = tvm::NodeFunctor; + using FStmtType = tvm::NodeFunctor; + + public: + PyStmtExprMutatorWithAnalyzerNode() = default; + PyStmtExprMutatorWithAnalyzerNode(const PyStmtExprMutatorWithAnalyzerNode& other) + : PyStmtExprMutatorNode(other) {} + + ffi::Function GetAnalyzer() { return MakeAnalyzerModule(analyzer_); } + + PrimExpr DefaultVisitExprWithAnalyzer(const PrimExpr& expr) { + static FExprType vtable = InitExprVTable(); + return vtable(expr, this); + } + + Stmt DefaultVisitStmtWithAnalyzer(const Stmt& stmt) { + static FStmtType vtable = InitStmtVTable(); + return vtable(stmt, this); + } + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("tirx.PyStmtExprMutatorWithAnalyzer", + PyStmtExprMutatorWithAnalyzerNode, PyStmtExprMutatorNode); + + private: + std::shared_ptr analyzer_{std::make_shared()}; + ScopeStack> constraint_scope_; + + Stmt VisitStmt_(const ForNode* op) override { + if (f_visit_for != nullptr) { + return f_visit_for(op).cast(); + } + return DefaultVisitFor(op); + } + + Stmt VisitStmt_(const SBlockNode* op) override { + if (f_visit_block != nullptr) { + return f_visit_block(op).cast(); + } + return DefaultVisitSBlock(op); + } + + Stmt VisitStmt_(const BindNode* op) override { + if (f_visit_bind != nullptr) { + return f_visit_bind(op).cast(); + } + return DefaultVisitBind(op); + } + + Stmt VisitStmt_(const IfThenElseNode* op) override { + if (f_visit_if_then_else != nullptr) { + return f_visit_if_then_else(op).cast(); + } + return DefaultVisitIfThenElse(op); + } + + Stmt VisitStmt_(const AttrStmtNode* op) override { + if (f_visit_attr_stmt != nullptr) { + return f_visit_attr_stmt(op).cast(); + } + return DefaultVisitAttrStmt(op); + } + + Stmt VisitStmt_(const AssertStmtNode* op) override { + if (f_visit_assert_stmt != nullptr) { + return f_visit_assert_stmt(op).cast(); + } + return DefaultVisitAssertStmt(op); + } + + Stmt VisitStmt_(const SeqStmtNode* op) override { + if (f_visit_seq_stmt != nullptr) { + return f_visit_seq_stmt(op).cast(); + } + return DefaultVisitSeqStmt(op); + } + + PrimExpr VisitExpr_(const CallNode* op) override { + if (f_visit_call != nullptr) { + return f_visit_call(op).cast(); + } + return DefaultVisitCall(op); + } + + PrimExpr VisitExpr_(const LetNode* op) override { + if (f_visit_let != nullptr) { + return f_visit_let(op).cast(); + } + return DefaultVisitLet(op); + } + + PrimExpr VisitExpr_(const SelectNode* op) override { + if (f_visit_select != nullptr) { + return f_visit_select(op).cast(); + } + return DefaultVisitSelect(op); + } + + PrimExpr VisitExpr_(const ReduceNode* op) override { + if (f_visit_reduce != nullptr) { + return f_visit_reduce(op).cast(); + } + return DefaultVisitReduce(op); + } + + Stmt DefaultVisitFor(const ForNode* op) { + return constraint_scope_.WithNewScope([&]() -> Stmt { + Range dom = Range::FromMinExtent(op->min, op->extent); + analyzer_->Bind(op->loop_var, dom); + return StmtExprMutator::VisitStmt_(op); + }); + } + + Stmt DefaultVisitSBlock(const SBlockNode* op) { + return constraint_scope_.WithNewScope([&]() -> Stmt { + for (const IterVar& iter_var : op->iter_vars) { + analyzer_->Bind(iter_var->var, iter_var->dom); + } + return StmtExprMutator::VisitStmt_(op); + }); + } + + Stmt DefaultVisitBind(const BindNode* op) { + PrimExpr value = this->VisitExpr(op->value); + if (SideEffect(value) <= CallEffectKind::kPure) { + analyzer_->Bind(op->var, value); + } + if (value.same_as(op->value)) { + return ffi::GetRef(op); + } + auto n = this->CopyOnWrite(op); + n->value = std::move(value); + return Stmt(n); + } + + Stmt DefaultVisitIfThenElse(const IfThenElseNode* op) { + return constraint_scope_.WithNewScope([&]() -> Stmt { + PrimExpr condition = this->VisitExpr(op->condition); + PrimExpr real_condition = ExtractRealCondition(condition); + Stmt then_case; + ffi::Optional else_case; + constraint_scope_.WithNewScope([&]() { + constraint_scope_.Current().Emplace(analyzer_.get(), real_condition); + then_case = this->VisitStmt(op->then_case); + }); + if (op->else_case) { + PrimExpr neg_condition = analyzer_->rewrite_simplify(Not(real_condition)); + constraint_scope_.WithNewScope([&]() { + constraint_scope_.Current().Emplace(analyzer_.get(), neg_condition); + else_case = this->VisitStmt(op->else_case.value()); + }); + } + if (is_one(real_condition)) return then_case; + if (is_zero(real_condition)) return else_case.value_or(Evaluate(0)); + if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && + else_case.same_as(op->else_case)) { + return ffi::GetRef(op); + } + auto n = this->CopyOnWrite(op); + n->condition = std::move(condition); + n->then_case = std::move(then_case); + n->else_case = std::move(else_case); + return Stmt(n); + }); + } + + Stmt DefaultVisitAttrStmt(const AttrStmtNode* op) { + return constraint_scope_.WithNewScope([&]() -> Stmt { + if (op->attr_key == tirx::attr::thread_extent || + op->attr_key == s_tir::attr::virtual_thread) { + IterVar iv = Downcast(op->node); + TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U); + analyzer_->Bind(iv->var, Range::FromMinExtent(make_zero(op->value.dtype()), op->value)); + } + return StmtExprMutator::VisitStmt_(op); + }); + } + + Stmt DefaultVisitAssertStmt(const AssertStmtNode* op) { + PrimExpr condition = this->VisitExpr(op->condition); + constraint_scope_.Current().Emplace(analyzer_.get(), condition); + PrimExpr error_kind = this->VisitExpr(op->error_kind); + ffi::Array message_parts; + bool message_parts_same = true; + for (const StringImm& message : op->message_parts) { + StringImm new_message = Downcast(this->VisitExpr(message)); + if (!new_message.same_as(message)) { + message_parts_same = false; + } + message_parts.push_back(std::move(new_message)); + } + if (condition.same_as(op->condition) && error_kind.same_as(op->error_kind) && + message_parts_same) { + return ffi::GetRef(op); + } + auto n = this->CopyOnWrite(op); + n->condition = std::move(condition); + n->error_kind = Downcast(std::move(error_kind)); + if (!message_parts_same) { + n->message_parts = std::move(message_parts); + } + return Stmt(n); + } + + Stmt DefaultVisitSeqStmt(const SeqStmtNode* op) { return StmtExprMutator::VisitStmt_(op); } + + PrimExpr DefaultVisitCall(const CallNode* op) { + static auto op_if_then_else = Op::Get("tirx.if_then_else"); + if (op->op.same_as(op_if_then_else)) { + PrimExpr cond = this->VisitExpr(op->args[0]); + PrimExpr true_value; + PrimExpr false_value; + constraint_scope_.WithNewScope([&]() { + constraint_scope_.Current().Emplace(analyzer_.get(), cond); + true_value = this->VisitExpr(op->args[1]); + }); + constraint_scope_.WithNewScope([&]() { + PrimExpr not_cond = analyzer_->rewrite_simplify(Not(cond)); + constraint_scope_.Current().Emplace(analyzer_.get(), not_cond); + false_value = this->VisitExpr(op->args[2]); + }); + if (is_zero(cond)) return false_value; + if (is_one(cond)) return true_value; + if (cond.same_as(op->args[0]) && true_value.same_as(op->args[1]) && + false_value.same_as(op->args[2])) { + return ffi::GetRef(op); + } + return Call(op->dtype, op->op, {cond, true_value, false_value}, op->annotations, op->span); + } + return StmtExprMutator::VisitExpr_(op); + } + + PrimExpr DefaultVisitLet(const LetNode* op) { + PrimExpr value = this->VisitExpr(op->value); + if (SideEffect(value) <= CallEffectKind::kPure) { + analyzer_->Bind(op->var, value); + } + PrimExpr body = this->VisitExpr(op->body); + if (value.same_as(op->value) && body.same_as(op->body)) { + return ffi::GetRef(op); + } + return Let(op->var, value, body); + } + + PrimExpr DefaultVisitSelect(const SelectNode* op) { + PrimExpr cond = this->VisitExpr(op->condition); + PrimExpr true_value; + PrimExpr false_value; + constraint_scope_.WithNewScope([&]() { + constraint_scope_.Current().Emplace(analyzer_.get(), cond); + true_value = this->VisitExpr(op->true_value); + }); + constraint_scope_.WithNewScope([&]() { + PrimExpr neg_cond = analyzer_->rewrite_simplify(Not(cond)); + constraint_scope_.Current().Emplace(analyzer_.get(), neg_cond); + false_value = this->VisitExpr(op->false_value); + }); + if (is_zero(cond)) return false_value; + if (is_one(cond)) return true_value; + if (cond.same_as(op->condition) && true_value.same_as(op->true_value) && + false_value.same_as(op->false_value)) { + return ffi::GetRef(op); + } + return Select(cond, true_value, false_value); + } + + PrimExpr DefaultVisitReduce(const ReduceNode* op) { + for (const IterVar& iv : op->axis) { + analyzer_->Bind(iv->var, iv->dom); + } + return StmtExprMutator::VisitExpr_(op); + } + + static FStmtType InitStmtVTable() { + FStmtType vtable; + PY_ANALYZER_STMT_MUTATOR_DEFAULT_DISPATCH(BindNode, DefaultVisitBind); + PY_ANALYZER_STMT_MUTATOR_DEFAULT_DISPATCH(AttrStmtNode, DefaultVisitAttrStmt); + PY_ANALYZER_STMT_MUTATOR_DEFAULT_DISPATCH(IfThenElseNode, DefaultVisitIfThenElse); + PY_ANALYZER_STMT_MUTATOR_DEFAULT_DISPATCH(ForNode, DefaultVisitFor); + PY_ANALYZER_STMT_MUTATOR_BASE_DISPATCH(WhileNode); + PY_ANALYZER_STMT_MUTATOR_BASE_DISPATCH(AllocBufferNode); + PY_ANALYZER_STMT_MUTATOR_BASE_DISPATCH(DeclBufferNode); + PY_ANALYZER_STMT_MUTATOR_BASE_DISPATCH(BufferStoreNode); + PY_ANALYZER_STMT_MUTATOR_DEFAULT_DISPATCH(AssertStmtNode, DefaultVisitAssertStmt); + PY_ANALYZER_STMT_MUTATOR_DEFAULT_DISPATCH(SeqStmtNode, DefaultVisitSeqStmt); + PY_ANALYZER_STMT_MUTATOR_BASE_DISPATCH(EvaluateNode); + PY_ANALYZER_STMT_MUTATOR_DEFAULT_DISPATCH(SBlockNode, DefaultVisitSBlock); + PY_ANALYZER_STMT_MUTATOR_BASE_DISPATCH(SBlockRealizeNode); + vtable.Finalize(); + return vtable; + } + + static FExprType InitExprVTable() { + FExprType vtable; + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(VarNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(SizeVarNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(BufferLoadNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(ProducerLoadNode); + PY_ANALYZER_EXPR_MUTATOR_DEFAULT_DISPATCH(LetNode, DefaultVisitLet); + PY_ANALYZER_EXPR_MUTATOR_DEFAULT_DISPATCH(CallNode, DefaultVisitCall); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(AddNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(SubNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(MulNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(DivNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(ModNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(FloorDivNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(FloorModNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(MinNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(MaxNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(EQNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(NENode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(LTNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(LENode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(GTNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(GENode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(AndNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(OrNode); + PY_ANALYZER_EXPR_MUTATOR_DEFAULT_DISPATCH(ReduceNode, DefaultVisitReduce); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(CastNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(NotNode); + PY_ANALYZER_EXPR_MUTATOR_DEFAULT_DISPATCH(SelectNode, DefaultVisitSelect); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(RampNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(ShuffleNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(BroadcastNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(IntImmNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(FloatImmNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(StringImmNode); + vtable.Finalize(); + return vtable; + } +}; + +class PyStmtExprMutatorWithAnalyzer : public ffi::ObjectRef { + public: + explicit PyStmtExprMutatorWithAnalyzer(ffi::ObjectPtr data) + : ffi::ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } + + TVM_DLL static PyStmtExprMutatorWithAnalyzer MakePyStmtExprMutatorWithAnalyzer( + ffi::Array> callbacks) { + ffi::ObjectPtr n = + ffi::make_object(); + SetStmtExprFunctorCallbacks(n.get(), callbacks); + return PyStmtExprMutatorWithAnalyzer(n); + } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PyStmtExprMutatorWithAnalyzer, ffi::ObjectRef, + PyStmtExprMutatorWithAnalyzerNode); +}; + // ================================================ // TVM Register // ================================================ @@ -807,13 +1676,19 @@ class PyStmtExprMutator : public ffi::ObjectRef { TVM_FFI_STATIC_INIT_BLOCK() { PyStmtExprVisitorNode::RegisterReflection(); PyStmtExprMutatorNode::RegisterReflection(); + PyStmtExprVisitorWithAnalyzerNode::RegisterReflection(); + PyStmtExprMutatorWithAnalyzerNode::RegisterReflection(); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tirx.MakePyStmtExprVisitor", PyStmtExprVisitor::MakePyStmtExprVisitor) - .def("tirx.MakePyStmtExprMutator", PyStmtExprMutator::MakePyStmtExprMutator); + .def("tirx.MakePyStmtExprMutator", PyStmtExprMutator::MakePyStmtExprMutator) + .def("tirx.MakePyStmtExprVisitorWithAnalyzer", + PyStmtExprVisitorWithAnalyzer::MakePyStmtExprVisitorWithAnalyzer) + .def("tirx.MakePyStmtExprMutatorWithAnalyzer", + PyStmtExprMutatorWithAnalyzer::MakePyStmtExprMutatorWithAnalyzer); } // StmtExprVisitor @@ -830,6 +1705,29 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](PyStmtExprVisitor visitor, const PrimExpr& expr) { visitor->VisitExpr(expr); }); } +// StmtExprVisitorWithAnalyzer +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tirx.PyStmtExprVisitorWithAnalyzerDefaultVisitExpr", + [](PyStmtExprVisitorWithAnalyzer visitor, const PrimExpr& expr) { + visitor->DefaultVisitExprWithAnalyzer(expr); + }) + .def("tirx.PyStmtExprVisitorWithAnalyzerDefaultVisitStmt", + [](PyStmtExprVisitorWithAnalyzer visitor, const Stmt& stmt) { + visitor->DefaultVisitStmtWithAnalyzer(stmt); + }) + .def( + "tirx.PyStmtExprVisitorWithAnalyzerVisitStmt", + [](PyStmtExprVisitorWithAnalyzer visitor, const Stmt& stmt) { visitor->VisitStmt(stmt); }) + .def("tirx.PyStmtExprVisitorWithAnalyzerVisitExpr", + [](PyStmtExprVisitorWithAnalyzer visitor, const PrimExpr& expr) { + visitor->VisitExpr(expr); + }) + .def("tirx.PyStmtExprVisitorWithAnalyzerGetAnalyzer", + [](PyStmtExprVisitorWithAnalyzer visitor) { return visitor->GetAnalyzer(); }); +} + // StmtExprMutator TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; @@ -848,5 +1746,29 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](PyStmtExprMutator mutator, const Stmt& stmt) { return mutator->VisitStmt(stmt); }); } +// StmtExprMutatorWithAnalyzer +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tirx.PyStmtExprMutatorWithAnalyzerDefaultVisitExpr", + [](PyStmtExprMutatorWithAnalyzer mutator, const PrimExpr& expr) { + return mutator->DefaultVisitExprWithAnalyzer(expr); + }) + .def("tirx.PyStmtExprMutatorWithAnalyzerDefaultVisitStmt", + [](PyStmtExprMutatorWithAnalyzer mutator, const Stmt& stmt) { + return mutator->DefaultVisitStmtWithAnalyzer(stmt); + }) + .def("tirx.PyStmtExprMutatorWithAnalyzerVisitExpr", + [](PyStmtExprMutatorWithAnalyzer mutator, const PrimExpr& expr) { + return mutator->VisitExpr(expr); + }) + .def("tirx.PyStmtExprMutatorWithAnalyzerVisitStmt", + [](PyStmtExprMutatorWithAnalyzer mutator, const Stmt& stmt) { + return mutator->VisitStmt(stmt); + }) + .def("tirx.PyStmtExprMutatorWithAnalyzerGetAnalyzer", + [](PyStmtExprMutatorWithAnalyzer mutator) { return mutator->GetAnalyzer(); }); +} + } // namespace tirx } // namespace tvm diff --git a/tests/python/tirx-transform/test_tir_functor.py b/tests/python/tirx-transform/test_tir_functor.py index fb23bc14adf2..c9d2d0324e91 100644 --- a/tests/python/tirx-transform/test_tir_functor.py +++ b/tests/python/tirx-transform/test_tir_functor.py @@ -267,6 +267,28 @@ def visit_evaluate_(self, op: Evaluate): super().visit_evaluate_(op) +@tirx.functor.visitor +class AssertMessageVisitor(PyStmtExprVisitorWithAnalyzer): + """Record string immediates reached through analyzer-aware assert traversal.""" + + def __init__(self): + super().__init__() + self.strings = [] + + def visit_string_imm_(self, op: StringImm): + self.strings.append(op.value) + + +@tirx.functor.mutator +class AssertMessageMutator(PyStmtExprMutatorWithAnalyzer): + """Rewrite assert message strings through analyzer-aware traversal.""" + + def visit_string_imm_(self, op: StringImm): + if op.value == "bad": + return StringImm("rewritten") + return op + + def test_basic_visitor(): """Test the basic AST printer visitor""" expr = Add(Var("x", dtype="int32"), Var("y", dtype="int32")) @@ -446,6 +468,22 @@ def test_analyzer_aware_visitor_assert_context(): assert visitor.facts == [True] +def test_analyzer_aware_assert_visits_error_kind_and_message(): + """Test that analyzer-aware assert traversal covers error kind and message parts.""" + x = Var("x", dtype="int32") + stmt = AssertStmt(LT(x, IntImm("int32", 4)), StringImm("ValueError"), [StringImm("bad")]) + + visitor = AssertMessageVisitor() + visitor.visit_stmt(stmt) + + assert visitor.strings == ["ValueError", "bad"] + + result = AssertMessageMutator().visit_stmt(stmt) + + assert result.error_kind.value == "ValueError" + assert [part.value for part in result.message_parts] == ["rewritten"] + + def test_analyzer_aware_visitor_pure_bind_context(): """Test that pure Bind values are visible to later statements in the same sequence.""" x = Var("x", dtype="int32") @@ -474,7 +512,8 @@ def test_analyzer_aware_mutator_skips_opaque_bind_context(): result = AnalyzerAwareMutator().visit_stmt(stmt) - assert isinstance(result, SeqStmt) + assert isinstance(result, tirx.Bind) + assert result.var.same_as(h) def test_analyzer_aware_visitor_branch_assert_does_not_leak(): From a2e8773f6d0ddd68000fa88a52b0840ebd40beb4 Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Sun, 24 May 2026 22:33:15 -0400 Subject: [PATCH 3/3] finish5 --- python/tvm/arith/analyzer.py | 11 ++++++++++- python/tvm/tirx/functor.py | 19 +------------------ src/tirx/ir/py_functor.cc | 4 ++-- 3 files changed, 13 insertions(+), 21 deletions(-) diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index ea70c4de3d0f..1564a73e4c46 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -108,7 +108,16 @@ class Analyzer: """ def __init__(self): - _mod = _ffi_api.CreateAnalyzer() + self._init_from_module(_ffi_api.CreateAnalyzer()) + + @classmethod + def _from_module(cls, mod): + analyzer = cls.__new__(cls) + analyzer._init_from_module(mod) + return analyzer + + def _init_from_module(self, mod): + _mod = mod self._const_int_bound = _mod("const_int_bound") self._const_int_bound_update = _mod("const_int_bound_update") self._const_int_bound_is_bound = _mod("const_int_bound_is_bound") diff --git a/python/tvm/tirx/functor.py b/python/tvm/tirx/functor.py index 007b94083422..c2969ea5a8a1 100644 --- a/python/tvm/tirx/functor.py +++ b/python/tvm/tirx/functor.py @@ -1063,24 +1063,7 @@ def visit_string_imm_(self, op: StringImm) -> None: def _analyzer_from_module(mod): from tvm.arith.analyzer import Analyzer # pylint: disable=import-outside-toplevel - analyzer = Analyzer.__new__(Analyzer) - analyzer._const_int_bound = mod("const_int_bound") - analyzer._const_int_bound_update = mod("const_int_bound_update") - analyzer._const_int_bound_is_bound = mod("const_int_bound_is_bound") - analyzer._bind = mod("bind") - analyzer._modular_set = mod("modular_set") - analyzer._simplify = mod("Simplify") - analyzer._rewrite_simplify = mod("rewrite_simplify") - analyzer._get_rewrite_simplify_stats = mod("get_rewrite_simplify_stats") - analyzer._reset_rewrite_simplify_stats = mod("reset_rewrite_simplify_stats") - analyzer._canonical_simplify = mod("canonical_simplify") - analyzer._int_set = mod("int_set") - analyzer._enter_constraint_context = mod("enter_constraint_context") - analyzer._can_prove_equal = mod("can_prove_equal") - analyzer._can_prove = mod("can_prove") - analyzer._get_enabled_extensions = mod("get_enabled_extensions") - analyzer._set_enabled_extensions = mod("set_enabled_extensions") - return analyzer + return Analyzer._from_module(mod) class _AnalyzerBackedVisitorMixin: diff --git a/src/tirx/ir/py_functor.cc b/src/tirx/ir/py_functor.cc index e24618a33e29..fcb0586c0391 100644 --- a/src/tirx/ir/py_functor.cc +++ b/src/tirx/ir/py_functor.cc @@ -112,8 +112,8 @@ ffi::Function MakeAnalyzerModule(std::shared_ptr analyzer) { }); } else if (name == "enter_constraint_context") { return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { - auto ctx = std::shared_ptr>( - new With(analyzer.get(), args[0].cast())); + auto ctx = std::make_shared>(analyzer.get(), + args[0].cast()); auto fexit = [ctx](ffi::PackedArgs, ffi::Any*) mutable { ctx.reset(); }; *ret = ffi::Function::FromPacked(fexit); });