diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index ea70c4de3d0f..1564a73e4c46 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -108,7 +108,16 @@ class Analyzer: """ def __init__(self): - _mod = _ffi_api.CreateAnalyzer() + self._init_from_module(_ffi_api.CreateAnalyzer()) + + @classmethod + def _from_module(cls, mod): + analyzer = cls.__new__(cls) + analyzer._init_from_module(mod) + return analyzer + + def _init_from_module(self, mod): + _mod = mod self._const_int_bound = _mod("const_int_bound") self._const_int_bound_update = _mod("const_int_bound_update") self._const_int_bound_is_bound = _mod("const_int_bound_is_bound") diff --git a/python/tvm/tirx/__init__.py b/python/tvm/tirx/__init__.py index 00a3522238af..4d805ad077b2 100644 --- a/python/tvm/tirx/__init__.py +++ b/python/tvm/tirx/__init__.py @@ -104,7 +104,12 @@ from . import backend from . import stmt_functor -from .functor import PyStmtExprVisitor, PyStmtExprMutator +from .functor import ( + PyStmtExprVisitor, + PyStmtExprMutator, + PyStmtExprVisitorWithAnalyzer, + PyStmtExprMutatorWithAnalyzer, +) # Compiler-only submodules. Skip under `TVM_USE_RUNTIME_LIB=1` since they # perform compiler-side FFI at module load (schema engine looks up diff --git a/python/tvm/tirx/functor.py b/python/tvm/tirx/functor.py index 4619c0b51fbb..c2969ea5a8a1 100644 --- a/python/tvm/tirx/functor.py +++ b/python/tvm/tirx/functor.py @@ -264,6 +264,121 @@ def __init__( ) +@tvm_ffi.register_object("tirx.PyStmtExprVisitorWithAnalyzer") +class _PyStmtExprVisitorWithAnalyzer(tvm_ffi.core.Object): + """ + An internal C++-backed wrapper for analyzer-aware Python StmtExprVisitor. + """ + + def __init__( + self, + f_visit_stmt: Callable | None = None, + f_visit_expr: Callable | None = None, + # Stmt + f_visit_bind: Callable | None = None, + f_visit_attr_stmt: Callable | None = None, + f_visit_if_then_else: Callable | None = None, + f_visit_for: Callable | None = None, + f_visit_while: Callable | None = None, + f_visit_alloc_buffer: Callable | None = None, + f_visit_decl_buffer: Callable | None = None, + f_visit_buffer_store: Callable | None = None, + f_visit_assert_stmt: Callable | None = None, + f_visit_seq_stmt: Callable | None = None, + f_visit_evaluate: Callable | None = None, + f_visit_block: Callable | None = None, + f_visit_sblock_realize: Callable | None = None, + # PrimExpr + f_visit_var: Callable | None = None, + f_visit_size_var: Callable | None = None, + f_visit_buffer_load: Callable | None = None, + f_visit_producer_load: Callable | None = None, + f_visit_let: Callable | None = None, + f_visit_call: Callable | None = None, + f_visit_add: Callable | None = None, + f_visit_sub: Callable | None = None, + f_visit_mul: Callable | None = None, + f_visit_div: Callable | None = None, + f_visit_mod: Callable | None = None, + f_visit_floor_div: Callable | None = None, + f_visit_floor_mod: Callable | None = None, + f_visit_min: Callable | None = None, + f_visit_max: Callable | None = None, + f_visit_eq: Callable | None = None, + f_visit_ne: Callable | None = None, + f_visit_lt: Callable | None = None, + f_visit_le: Callable | None = None, + f_visit_gt: Callable | None = None, + f_visit_ge: Callable | None = None, + f_visit_and: Callable | None = None, + f_visit_or: Callable | None = None, + f_visit_reduce: Callable | None = None, + f_visit_cast: Callable | None = None, + f_visit_not: Callable | None = None, + f_visit_select: Callable | None = None, + f_visit_ramp: Callable | None = None, + f_visit_broadcast: Callable | None = None, + f_visit_shuffle: Callable | None = None, + f_visit_int_imm: Callable | None = None, + f_visit_float_imm: Callable | None = None, + f_visit_string_imm: Callable | None = None, + ) -> None: + """Constructor.""" + self.__init_handle_by_constructor__( + _ffi_api.MakePyStmtExprVisitorWithAnalyzer, # type: ignore + [ + f_visit_stmt, + f_visit_expr, + f_visit_bind, + f_visit_attr_stmt, + f_visit_if_then_else, + f_visit_for, + f_visit_while, + f_visit_alloc_buffer, + f_visit_decl_buffer, + f_visit_buffer_store, + f_visit_assert_stmt, + f_visit_seq_stmt, + f_visit_evaluate, + f_visit_block, + f_visit_sblock_realize, + f_visit_var, + f_visit_size_var, + f_visit_buffer_load, + f_visit_producer_load, + f_visit_let, + f_visit_call, + f_visit_add, + f_visit_sub, + f_visit_mul, + f_visit_div, + f_visit_mod, + f_visit_floor_div, + f_visit_floor_mod, + f_visit_min, + f_visit_max, + f_visit_eq, + f_visit_ne, + f_visit_lt, + f_visit_le, + f_visit_gt, + f_visit_ge, + f_visit_and, + f_visit_or, + f_visit_reduce, + f_visit_cast, + f_visit_not, + f_visit_select, + f_visit_ramp, + f_visit_broadcast, + f_visit_shuffle, + f_visit_int_imm, + f_visit_float_imm, + f_visit_string_imm, + ], + ) + + class PyStmtExprVisitor: """ A Python StmtExprVisitor to define custom visitor for both Stmt and PrimExpr. @@ -945,6 +1060,38 @@ def visit_string_imm_(self, op: StringImm) -> None: _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore +def _analyzer_from_module(mod): + from tvm.arith.analyzer import Analyzer # pylint: disable=import-outside-toplevel + + return Analyzer._from_module(mod) + + +class _AnalyzerBackedVisitorMixin: + @property + def analyzer(self): + analyzer = getattr(self, "_analyzer", None) + if analyzer is None: + mod = _ffi_api.PyStmtExprVisitorWithAnalyzerGetAnalyzer(self._outer()) # type: ignore + analyzer = _analyzer_from_module(mod) + self._analyzer = analyzer + return analyzer + + +class PyStmtExprVisitorWithAnalyzer(PyStmtExprVisitor, _AnalyzerBackedVisitorMixin): + """A C++-backed Python StmtExprVisitor with an arithmetic analyzer context.""" + + _tvm_metadata = { + **PyStmtExprVisitor._tvm_metadata, + "cls": _PyStmtExprVisitorWithAnalyzer, + } + + def visit_stmt(self, stmt: Stmt) -> None: + _ffi_api.PyStmtExprVisitorWithAnalyzerVisitStmt(self._outer(), stmt) # type: ignore + + def visit_expr(self, expr: PrimExpr) -> None: + _ffi_api.PyStmtExprVisitorWithAnalyzerVisitExpr(self._outer(), expr) # type: ignore + + @tvm_ffi.register_object("tirx.PyStmtExprMutator") class _PyStmtExprMutator(tvm_ffi.core.Object): """ @@ -1065,6 +1212,121 @@ def __init__( ) +@tvm_ffi.register_object("tirx.PyStmtExprMutatorWithAnalyzer") +class _PyStmtExprMutatorWithAnalyzer(tvm_ffi.core.Object): + """ + An internal C++-backed wrapper for analyzer-aware Python StmtExprMutator. + """ + + def __init__( + self, + f_visit_stmt: Callable | None = None, + f_visit_expr: Callable | None = None, + # Stmt + f_visit_bind: Callable | None = None, + f_visit_attr_stmt: Callable | None = None, + f_visit_if_then_else: Callable | None = None, + f_visit_for: Callable | None = None, + f_visit_while: Callable | None = None, + f_visit_alloc_buffer: Callable | None = None, + f_visit_decl_buffer: Callable | None = None, + f_visit_buffer_store: Callable | None = None, + f_visit_assert_stmt: Callable | None = None, + f_visit_seq_stmt: Callable | None = None, + f_visit_evaluate: Callable | None = None, + f_visit_block: Callable | None = None, + f_visit_sblock_realize: Callable | None = None, + # PrimExpr + f_visit_var: Callable | None = None, + f_visit_size_var: Callable | None = None, + f_visit_buffer_load: Callable | None = None, + f_visit_producer_load: Callable | None = None, + f_visit_let: Callable | None = None, + f_visit_call: Callable | None = None, + f_visit_add: Callable | None = None, + f_visit_sub: Callable | None = None, + f_visit_mul: Callable | None = None, + f_visit_div: Callable | None = None, + f_visit_mod: Callable | None = None, + f_visit_floor_div: Callable | None = None, + f_visit_floor_mod: Callable | None = None, + f_visit_min: Callable | None = None, + f_visit_max: Callable | None = None, + f_visit_eq: Callable | None = None, + f_visit_ne: Callable | None = None, + f_visit_lt: Callable | None = None, + f_visit_le: Callable | None = None, + f_visit_gt: Callable | None = None, + f_visit_ge: Callable | None = None, + f_visit_and: Callable | None = None, + f_visit_or: Callable | None = None, + f_visit_reduce: Callable | None = None, + f_visit_cast: Callable | None = None, + f_visit_not: Callable | None = None, + f_visit_select: Callable | None = None, + f_visit_ramp: Callable | None = None, + f_visit_broadcast: Callable | None = None, + f_visit_shuffle: Callable | None = None, + f_visit_int_imm: Callable | None = None, + f_visit_float_imm: Callable | None = None, + f_visit_string_imm: Callable | None = None, + ) -> None: + """Constructor.""" + self.__init_handle_by_constructor__( + _ffi_api.MakePyStmtExprMutatorWithAnalyzer, # type: ignore + [ + f_visit_stmt, + f_visit_expr, + f_visit_bind, + f_visit_attr_stmt, + f_visit_if_then_else, + f_visit_for, + f_visit_while, + f_visit_alloc_buffer, + f_visit_decl_buffer, + f_visit_buffer_store, + f_visit_assert_stmt, + f_visit_seq_stmt, + f_visit_evaluate, + f_visit_block, + f_visit_sblock_realize, + f_visit_var, + f_visit_size_var, + f_visit_buffer_load, + f_visit_producer_load, + f_visit_let, + f_visit_call, + f_visit_add, + f_visit_sub, + f_visit_mul, + f_visit_div, + f_visit_mod, + f_visit_floor_div, + f_visit_floor_mod, + f_visit_min, + f_visit_max, + f_visit_eq, + f_visit_ne, + f_visit_lt, + f_visit_le, + f_visit_gt, + f_visit_ge, + f_visit_and, + f_visit_or, + f_visit_reduce, + f_visit_cast, + f_visit_not, + f_visit_select, + f_visit_ramp, + f_visit_broadcast, + f_visit_shuffle, + f_visit_int_imm, + f_visit_float_imm, + f_visit_string_imm, + ], + ) + + class PyStmtExprMutator: """ A Python StmtExprMutator to define custom mutator for both Stmt and PrimExpr. @@ -1976,3 +2238,120 @@ def visit_string_imm_(self, op: StringImm) -> PrimExpr: The mutated PrimExpr. """ return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op) # type: ignore + + +class _AnalyzerBackedMutatorMixin: + @property + def analyzer(self): + analyzer = getattr(self, "_analyzer", None) + if analyzer is None: + mod = _ffi_api.PyStmtExprMutatorWithAnalyzerGetAnalyzer(self._outer()) # type: ignore + analyzer = _analyzer_from_module(mod) + self._analyzer = analyzer + return analyzer + + +class PyStmtExprMutatorWithAnalyzer(PyStmtExprMutator, _AnalyzerBackedMutatorMixin): + """A C++-backed Python StmtExprMutator with an arithmetic analyzer context.""" + + _tvm_metadata = { + **PyStmtExprMutator._tvm_metadata, + "cls": _PyStmtExprMutatorWithAnalyzer, + } + + def visit_expr(self, expr: PrimExpr) -> PrimExpr: + return _ffi_api.PyStmtExprMutatorWithAnalyzerVisitExpr(self._outer(), expr) # type: ignore + + def visit_stmt(self, stmt: Stmt) -> Stmt: + return _ffi_api.PyStmtExprMutatorWithAnalyzerVisitStmt(self._outer(), stmt) # type: ignore + + +def _make_analyzer_visitor_default_stmt(): + def visit(self, op): + _ffi_api.PyStmtExprVisitorWithAnalyzerDefaultVisitStmt(self._outer(), op) # type: ignore + + return visit + + +def _make_analyzer_visitor_default_expr(): + def visit(self, op): + _ffi_api.PyStmtExprVisitorWithAnalyzerDefaultVisitExpr(self._outer(), op) # type: ignore + + return visit + + +def _make_analyzer_mutator_default_stmt(): + def visit(self, op): + return _ffi_api.PyStmtExprMutatorWithAnalyzerDefaultVisitStmt(self._outer(), op) # type: ignore + + return visit + + +def _make_analyzer_mutator_default_expr(): + def visit(self, op): + return _ffi_api.PyStmtExprMutatorWithAnalyzerDefaultVisitExpr(self._outer(), op) # type: ignore + + return visit + + +_STMT_VISIT_METHODS = [ + "visit_bind_", + "visit_attr_stmt_", + "visit_if_then_else_", + "visit_for_", + "visit_while_", + "visit_alloc_buffer_", + "visit_decl_buffer_", + "visit_buffer_store_", + "visit_assert_stmt_", + "visit_seq_stmt_", + "visit_evaluate_", + "visit_sblock_", + "visit_sblock_realize_", +] + +_EXPR_VISIT_METHODS = [ + "visit_var_", + "visit_size_var_", + "visit_buffer_load_", + "visit_producer_load_", + "visit_let_", + "visit_call_", + "visit_add_", + "visit_sub_", + "visit_mul_", + "visit_div_", + "visit_mod_", + "visit_floor_div_", + "visit_floor_mod_", + "visit_min_", + "visit_max_", + "visit_eq_", + "visit_ne_", + "visit_lt_", + "visit_le_", + "visit_gt_", + "visit_ge_", + "visit_and_", + "visit_or_", + "visit_reduce_", + "visit_cast_", + "visit_not_", + "visit_select_", + "visit_ramp_", + "visit_broadcast_", + "visit_shuffle_", + "visit_int_imm_", + "visit_float_imm_", + "visit_string_imm_", +] + +for _method in _STMT_VISIT_METHODS: + setattr(PyStmtExprVisitorWithAnalyzer, _method, _make_analyzer_visitor_default_stmt()) + setattr(PyStmtExprMutatorWithAnalyzer, _method, _make_analyzer_mutator_default_stmt()) + +for _method in _EXPR_VISIT_METHODS: + setattr(PyStmtExprVisitorWithAnalyzer, _method, _make_analyzer_visitor_default_expr()) + setattr(PyStmtExprMutatorWithAnalyzer, _method, _make_analyzer_mutator_default_expr()) + +del _method diff --git a/src/tirx/ir/py_functor.cc b/src/tirx/ir/py_functor.cc index 65d7c3c1b45b..fcb0586c0391 100644 --- a/src/tirx/ir/py_functor.cc +++ b/src/tirx/ir/py_functor.cc @@ -23,13 +23,131 @@ * StmtExprVisitor/StmtExprMutator. */ +#include #include +#include +#include +#include +#include +#include #include +#include #include +#include +#include +#include + namespace tvm { namespace tirx { +namespace { + +ffi::Function MakeAnalyzerModule(std::shared_ptr analyzer) { + using ffi::Function; + using ffi::TypedFunction; + auto f = [analyzer](std::string name) -> ffi::Function { + if (name == "const_int_bound") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + *ret = analyzer->const_int_bound(args[0].cast()); + }); + } else if (name == "modular_set") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + *ret = analyzer->modular_set(args[0].cast()); + }); + } else if (name == "const_int_bound_update") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + analyzer->const_int_bound.Update(args[0].cast(), args[1].cast(), + args[2].cast()); + }); + } else if (name == "const_int_bound_is_bound") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + *ret = analyzer->const_int_bound.IsBound(args[0].cast()); + }); + } else if (name == "Simplify") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + if (args.size() == 1) { + *ret = analyzer->Simplify(args[0].cast()); + } else if (args.size() == 2) { + *ret = analyzer->Simplify(args[0].cast(), args[1].cast()); + } else { + TVM_FFI_THROW(InternalError) << "Invalid size of argument (" << args.size() << ")"; + } + }); + } else if (name == "rewrite_simplify") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + *ret = analyzer->rewrite_simplify(args[0].cast()); + }); + } else if (name == "get_rewrite_simplify_stats") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + *ret = analyzer->rewrite_simplify.GetStatsCounters(); + }); + } else if (name == "reset_rewrite_simplify_stats") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + analyzer->rewrite_simplify.ResetStatsCounters(); + }); + } else if (name == "canonical_simplify") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + *ret = analyzer->canonical_simplify(args[0].cast()); + }); + } else if (name == "int_set") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + *ret = analyzer->int_set(args[0].cast(), + args[1].cast>()); + }); + } else if (name == "bind") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + bool allow_override = args.size() >= 3 && args[2].cast(); + if (auto opt_range = args[1].try_cast()) { + analyzer->Bind(args[0].cast(), opt_range.value(), allow_override); + } else { + analyzer->Bind(args[0].cast(), args[1].cast(), allow_override); + } + }); + } else if (name == "can_prove") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + int strength = args[1].cast(); + *ret = analyzer->CanProve(args[0].cast(), + static_cast(strength)); + }); + } else if (name == "enter_constraint_context") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + auto ctx = std::make_shared>(analyzer.get(), + args[0].cast()); + auto fexit = [ctx](ffi::PackedArgs, ffi::Any*) mutable { ctx.reset(); }; + *ret = ffi::Function::FromPacked(fexit); + }); + } else if (name == "can_prove_equal") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + *ret = analyzer->CanProveEqual(args[0].cast(), args[1].cast()); + }); + } else if (name == "get_enabled_extensions") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + *ret = static_cast(analyzer->rewrite_simplify.GetEnabledExtensions()); + }); + } else if (name == "set_enabled_extensions") { + return ffi::Function([analyzer](ffi::PackedArgs args, ffi::Any* ret) { + int64_t flags = args[0].cast(); + analyzer->rewrite_simplify.SetEnabledExtensions( + static_cast(flags)); + }); + } + return ffi::Function(); + }; + return ffi::TypedFunction(f); +} + +PrimExpr ExtractRealCondition(PrimExpr condition) { + if (auto call = condition.as()) { + if (call->op.same_as(builtin::likely())) { + return call->args[0]; + } + } + return condition; +} + +} // namespace + // ================================================ // Helper Macros // ================================================ @@ -800,6 +918,757 @@ class PyStmtExprMutator : public ffi::ObjectRef { PyStmtExprMutatorNode); }; +#define PY_STMT_EXPR_FUNCTOR_CALLBACKS(V) \ + V(f_visit_stmt) \ + V(f_visit_expr) \ + V(f_visit_bind) \ + V(f_visit_attr_stmt) \ + V(f_visit_if_then_else) \ + V(f_visit_for) \ + V(f_visit_while) \ + V(f_visit_alloc_buffer) \ + V(f_visit_decl_buffer) \ + V(f_visit_buffer_store) \ + V(f_visit_assert_stmt) \ + V(f_visit_seq_stmt) \ + V(f_visit_evaluate) \ + V(f_visit_block) \ + V(f_visit_sblock_realize) \ + V(f_visit_var) \ + V(f_visit_size_var) \ + V(f_visit_buffer_load) \ + V(f_visit_producer_load) \ + V(f_visit_let) \ + V(f_visit_call) \ + V(f_visit_add) \ + V(f_visit_sub) \ + V(f_visit_mul) \ + V(f_visit_div) \ + V(f_visit_mod) \ + V(f_visit_floor_div) \ + V(f_visit_floor_mod) \ + V(f_visit_min) \ + V(f_visit_max) \ + V(f_visit_eq) \ + V(f_visit_ne) \ + V(f_visit_lt) \ + V(f_visit_le) \ + V(f_visit_gt) \ + V(f_visit_ge) \ + V(f_visit_and) \ + V(f_visit_or) \ + V(f_visit_reduce) \ + V(f_visit_cast) \ + V(f_visit_not) \ + V(f_visit_select) \ + V(f_visit_ramp) \ + V(f_visit_broadcast) \ + V(f_visit_shuffle) \ + V(f_visit_int_imm) \ + V(f_visit_float_imm) \ + V(f_visit_string_imm) + +template +void SetStmtExprFunctorCallbacks(TNode* node, + const ffi::Array>& callbacks) { + int index = 0; +#define SET_CALLBACK(FIELD) node->FIELD = callbacks[index++].value_or(ffi::Function(nullptr)); + PY_STMT_EXPR_FUNCTOR_CALLBACKS(SET_CALLBACK) +#undef SET_CALLBACK + TVM_FFI_ICHECK_EQ(index, callbacks.size()); +} + +#define PY_ANALYZER_STMT_VISITOR_DEFAULT_DISPATCH(OP, METHOD) \ + vtable.template set_dispatch([](const ffi::ObjectRef& n, TSelf* self) { \ + self->METHOD(static_cast(n.get())); \ + }); + +#define PY_ANALYZER_EXPR_VISITOR_DEFAULT_DISPATCH(OP, METHOD) \ + vtable.template set_dispatch([](const ffi::ObjectRef& n, TSelf* self) { \ + self->METHOD(static_cast(n.get())); \ + }); + +#define PY_ANALYZER_STMT_VISITOR_BASE_DISPATCH(OP) \ + vtable.template set_dispatch([](const ffi::ObjectRef& n, TSelf* self) { \ + self->StmtExprVisitor::VisitStmt_(static_cast(n.get())); \ + }); + +#define PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(OP) \ + vtable.template set_dispatch([](const ffi::ObjectRef& n, TSelf* self) { \ + self->StmtExprVisitor::VisitExpr_(static_cast(n.get())); \ + }); + +#define PY_ANALYZER_STMT_MUTATOR_DEFAULT_DISPATCH(OP, METHOD) \ + vtable.template set_dispatch([](const ffi::ObjectRef& n, TSelf* self) { \ + return self->METHOD(static_cast(n.get())); \ + }); + +#define PY_ANALYZER_EXPR_MUTATOR_DEFAULT_DISPATCH(OP, METHOD) \ + vtable.template set_dispatch([](const ffi::ObjectRef& n, TSelf* self) { \ + return self->METHOD(static_cast(n.get())); \ + }); + +#define PY_ANALYZER_STMT_MUTATOR_BASE_DISPATCH(OP) \ + vtable.template set_dispatch([](const ffi::ObjectRef& n, TSelf* self) { \ + return self->StmtExprMutator::VisitStmt_(static_cast(n.get())); \ + }); + +#define PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(OP) \ + vtable.template set_dispatch([](const ffi::ObjectRef& n, TSelf* self) { \ + return self->StmtExprMutator::VisitExpr_(static_cast(n.get())); \ + }); + +class PyStmtExprVisitorWithAnalyzerNode : public PyStmtExprVisitorNode { + private: + using TSelf = PyStmtExprVisitorWithAnalyzerNode; + using FExprType = tvm::NodeFunctor; + using FStmtType = tvm::NodeFunctor; + + public: + PyStmtExprVisitorWithAnalyzerNode() = default; + PyStmtExprVisitorWithAnalyzerNode(const PyStmtExprVisitorWithAnalyzerNode& other) + : PyStmtExprVisitorNode(other) {} + + ffi::Function GetAnalyzer() { return MakeAnalyzerModule(analyzer_); } + + void DefaultVisitExprWithAnalyzer(const PrimExpr& expr) { + static FExprType vtable = InitExprVTable(); + vtable(expr, this); + } + + void DefaultVisitStmtWithAnalyzer(const Stmt& stmt) { + static FStmtType vtable = InitStmtVTable(); + vtable(stmt, this); + } + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("tirx.PyStmtExprVisitorWithAnalyzer", + PyStmtExprVisitorWithAnalyzerNode, PyStmtExprVisitorNode); + + private: + std::shared_ptr analyzer_{std::make_shared()}; + ScopeStack> constraint_scope_; + + void VisitStmt_(const ForNode* op) override { + if (f_visit_for != nullptr) { + f_visit_for(op); + } else { + DefaultVisitFor(op); + } + } + + void VisitStmt_(const SBlockNode* op) override { + if (f_visit_block != nullptr) { + f_visit_block(op); + } else { + DefaultVisitSBlock(op); + } + } + + void VisitStmt_(const BindNode* op) override { + if (f_visit_bind != nullptr) { + f_visit_bind(op); + } else { + DefaultVisitBind(op); + } + } + + void VisitStmt_(const IfThenElseNode* op) override { + if (f_visit_if_then_else != nullptr) { + f_visit_if_then_else(op); + } else { + DefaultVisitIfThenElse(op); + } + } + + void VisitStmt_(const AttrStmtNode* op) override { + if (f_visit_attr_stmt != nullptr) { + f_visit_attr_stmt(op); + } else { + DefaultVisitAttrStmt(op); + } + } + + void VisitStmt_(const AssertStmtNode* op) override { + if (f_visit_assert_stmt != nullptr) { + f_visit_assert_stmt(op); + } else { + DefaultVisitAssertStmt(op); + } + } + + void VisitStmt_(const SeqStmtNode* op) override { + if (f_visit_seq_stmt != nullptr) { + f_visit_seq_stmt(op); + } else { + DefaultVisitSeqStmt(op); + } + } + + void VisitExpr_(const CallNode* op) override { + if (f_visit_call != nullptr) { + f_visit_call(op); + } else { + DefaultVisitCall(op); + } + } + + void VisitExpr_(const LetNode* op) override { + if (f_visit_let != nullptr) { + f_visit_let(op); + } else { + DefaultVisitLet(op); + } + } + + void VisitExpr_(const ReduceNode* op) override { + if (f_visit_reduce != nullptr) { + f_visit_reduce(op); + } else { + DefaultVisitReduce(op); + } + } + + void DefaultVisitFor(const ForNode* op) { + constraint_scope_.WithNewScope([&]() { + analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); + StmtExprVisitor::VisitStmt_(op); + }); + } + + void DefaultVisitSBlock(const SBlockNode* op) { + constraint_scope_.WithNewScope([&]() { + for (const IterVar& iter_var : op->iter_vars) { + analyzer_->Bind(iter_var->var, iter_var->dom); + } + StmtExprVisitor::VisitStmt_(op); + }); + } + + void DefaultVisitBind(const BindNode* op) { + this->VisitExpr(op->value); + if (SideEffect(op->value) <= CallEffectKind::kPure) { + analyzer_->Bind(op->var, op->value); + } + } + + void DefaultVisitIfThenElse(const IfThenElseNode* op) { + constraint_scope_.WithNewScope([&]() { + this->VisitExpr(op->condition); + PrimExpr real_condition = ExtractRealCondition(op->condition); + constraint_scope_.WithNewScope([&]() { + constraint_scope_.Current().Emplace(analyzer_.get(), real_condition); + this->VisitStmt(op->then_case); + }); + if (op->else_case) { + constraint_scope_.WithNewScope([&]() { + constraint_scope_.Current().Emplace(analyzer_.get(), + analyzer_->rewrite_simplify(Not(real_condition))); + this->VisitStmt(op->else_case.value()); + }); + } + }); + } + + void DefaultVisitAttrStmt(const AttrStmtNode* op) { + constraint_scope_.WithNewScope([&]() { + if (op->attr_key == tirx::attr::thread_extent || + op->attr_key == s_tir::attr::virtual_thread) { + IterVar iv = Downcast(op->node); + TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U); + analyzer_->Bind(iv->var, Range::FromMinExtent(IntImm(op->value->dtype, 0), op->value)); + } + StmtExprVisitor::VisitStmt_(op); + }); + } + + void DefaultVisitAssertStmt(const AssertStmtNode* op) { + this->VisitExpr(op->condition); + constraint_scope_.Current().Emplace(analyzer_.get(), op->condition); + this->VisitExpr(op->error_kind); + for (const StringImm& message : op->message_parts) { + this->VisitExpr(message); + } + } + + void DefaultVisitSeqStmt(const SeqStmtNode* op) { StmtExprVisitor::VisitStmt_(op); } + + void DefaultVisitCall(const CallNode* op) { + static auto op_if_then_else = Op::Get("tirx.if_then_else"); + if (op->op.same_as(op_if_then_else)) { + PrimExpr cond = op->args[0]; + this->VisitExpr(op->args[0]); + constraint_scope_.WithNewScope([&]() { + constraint_scope_.Current().Emplace(analyzer_.get(), cond); + this->VisitExpr(op->args[1]); + }); + constraint_scope_.WithNewScope([&]() { + constraint_scope_.Current().Emplace(analyzer_.get(), + analyzer_->rewrite_simplify(Not(cond))); + this->VisitExpr(op->args[2]); + }); + } else { + StmtExprVisitor::VisitExpr_(op); + } + } + + void DefaultVisitLet(const LetNode* op) { + this->VisitExpr(op->value); + if (SideEffect(op->value) <= CallEffectKind::kPure) { + analyzer_->Bind(op->var, op->value); + } + this->VisitExpr(op->body); + } + + void DefaultVisitReduce(const ReduceNode* op) { + for (const IterVar& iv : op->axis) { + analyzer_->Bind(iv->var, iv->dom); + } + StmtExprVisitor::VisitExpr_(op); + } + + static FStmtType InitStmtVTable() { + FStmtType vtable; + PY_ANALYZER_STMT_VISITOR_DEFAULT_DISPATCH(BindNode, DefaultVisitBind); + PY_ANALYZER_STMT_VISITOR_DEFAULT_DISPATCH(AttrStmtNode, DefaultVisitAttrStmt); + PY_ANALYZER_STMT_VISITOR_DEFAULT_DISPATCH(IfThenElseNode, DefaultVisitIfThenElse); + PY_ANALYZER_STMT_VISITOR_DEFAULT_DISPATCH(ForNode, DefaultVisitFor); + PY_ANALYZER_STMT_VISITOR_BASE_DISPATCH(WhileNode); + PY_ANALYZER_STMT_VISITOR_BASE_DISPATCH(AllocBufferNode); + PY_ANALYZER_STMT_VISITOR_BASE_DISPATCH(DeclBufferNode); + PY_ANALYZER_STMT_VISITOR_BASE_DISPATCH(BufferStoreNode); + PY_ANALYZER_STMT_VISITOR_DEFAULT_DISPATCH(AssertStmtNode, DefaultVisitAssertStmt); + PY_ANALYZER_STMT_VISITOR_DEFAULT_DISPATCH(SeqStmtNode, DefaultVisitSeqStmt); + PY_ANALYZER_STMT_VISITOR_BASE_DISPATCH(EvaluateNode); + PY_ANALYZER_STMT_VISITOR_DEFAULT_DISPATCH(SBlockNode, DefaultVisitSBlock); + PY_ANALYZER_STMT_VISITOR_BASE_DISPATCH(SBlockRealizeNode); + vtable.Finalize(); + return vtable; + } + + static FExprType InitExprVTable() { + FExprType vtable; + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(VarNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(SizeVarNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(BufferLoadNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(ProducerLoadNode); + PY_ANALYZER_EXPR_VISITOR_DEFAULT_DISPATCH(LetNode, DefaultVisitLet); + PY_ANALYZER_EXPR_VISITOR_DEFAULT_DISPATCH(CallNode, DefaultVisitCall); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(AddNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(SubNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(MulNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(DivNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(ModNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(FloorDivNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(FloorModNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(MinNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(MaxNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(EQNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(NENode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(LTNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(LENode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(GTNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(GENode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(AndNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(OrNode); + PY_ANALYZER_EXPR_VISITOR_DEFAULT_DISPATCH(ReduceNode, DefaultVisitReduce); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(CastNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(NotNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(SelectNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(RampNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(ShuffleNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(BroadcastNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(IntImmNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(FloatImmNode); + PY_ANALYZER_EXPR_VISITOR_BASE_DISPATCH(StringImmNode); + vtable.Finalize(); + return vtable; + } +}; + +class PyStmtExprVisitorWithAnalyzer : public ffi::ObjectRef { + public: + explicit PyStmtExprVisitorWithAnalyzer(ffi::ObjectPtr data) + : ffi::ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } + + TVM_DLL static PyStmtExprVisitorWithAnalyzer MakePyStmtExprVisitorWithAnalyzer( + ffi::Array> callbacks) { + ffi::ObjectPtr n = + ffi::make_object(); + SetStmtExprFunctorCallbacks(n.get(), callbacks); + return PyStmtExprVisitorWithAnalyzer(n); + } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PyStmtExprVisitorWithAnalyzer, ffi::ObjectRef, + PyStmtExprVisitorWithAnalyzerNode); +}; + +class PyStmtExprMutatorWithAnalyzerNode : public PyStmtExprMutatorNode { + private: + using TSelf = PyStmtExprMutatorWithAnalyzerNode; + using FExprType = tvm::NodeFunctor; + using FStmtType = tvm::NodeFunctor; + + public: + PyStmtExprMutatorWithAnalyzerNode() = default; + PyStmtExprMutatorWithAnalyzerNode(const PyStmtExprMutatorWithAnalyzerNode& other) + : PyStmtExprMutatorNode(other) {} + + ffi::Function GetAnalyzer() { return MakeAnalyzerModule(analyzer_); } + + PrimExpr DefaultVisitExprWithAnalyzer(const PrimExpr& expr) { + static FExprType vtable = InitExprVTable(); + return vtable(expr, this); + } + + Stmt DefaultVisitStmtWithAnalyzer(const Stmt& stmt) { + static FStmtType vtable = InitStmtVTable(); + return vtable(stmt, this); + } + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("tirx.PyStmtExprMutatorWithAnalyzer", + PyStmtExprMutatorWithAnalyzerNode, PyStmtExprMutatorNode); + + private: + std::shared_ptr analyzer_{std::make_shared()}; + ScopeStack> constraint_scope_; + + Stmt VisitStmt_(const ForNode* op) override { + if (f_visit_for != nullptr) { + return f_visit_for(op).cast(); + } + return DefaultVisitFor(op); + } + + Stmt VisitStmt_(const SBlockNode* op) override { + if (f_visit_block != nullptr) { + return f_visit_block(op).cast(); + } + return DefaultVisitSBlock(op); + } + + Stmt VisitStmt_(const BindNode* op) override { + if (f_visit_bind != nullptr) { + return f_visit_bind(op).cast(); + } + return DefaultVisitBind(op); + } + + Stmt VisitStmt_(const IfThenElseNode* op) override { + if (f_visit_if_then_else != nullptr) { + return f_visit_if_then_else(op).cast(); + } + return DefaultVisitIfThenElse(op); + } + + Stmt VisitStmt_(const AttrStmtNode* op) override { + if (f_visit_attr_stmt != nullptr) { + return f_visit_attr_stmt(op).cast(); + } + return DefaultVisitAttrStmt(op); + } + + Stmt VisitStmt_(const AssertStmtNode* op) override { + if (f_visit_assert_stmt != nullptr) { + return f_visit_assert_stmt(op).cast(); + } + return DefaultVisitAssertStmt(op); + } + + Stmt VisitStmt_(const SeqStmtNode* op) override { + if (f_visit_seq_stmt != nullptr) { + return f_visit_seq_stmt(op).cast(); + } + return DefaultVisitSeqStmt(op); + } + + PrimExpr VisitExpr_(const CallNode* op) override { + if (f_visit_call != nullptr) { + return f_visit_call(op).cast(); + } + return DefaultVisitCall(op); + } + + PrimExpr VisitExpr_(const LetNode* op) override { + if (f_visit_let != nullptr) { + return f_visit_let(op).cast(); + } + return DefaultVisitLet(op); + } + + PrimExpr VisitExpr_(const SelectNode* op) override { + if (f_visit_select != nullptr) { + return f_visit_select(op).cast(); + } + return DefaultVisitSelect(op); + } + + PrimExpr VisitExpr_(const ReduceNode* op) override { + if (f_visit_reduce != nullptr) { + return f_visit_reduce(op).cast(); + } + return DefaultVisitReduce(op); + } + + Stmt DefaultVisitFor(const ForNode* op) { + return constraint_scope_.WithNewScope([&]() -> Stmt { + Range dom = Range::FromMinExtent(op->min, op->extent); + analyzer_->Bind(op->loop_var, dom); + return StmtExprMutator::VisitStmt_(op); + }); + } + + Stmt DefaultVisitSBlock(const SBlockNode* op) { + return constraint_scope_.WithNewScope([&]() -> Stmt { + for (const IterVar& iter_var : op->iter_vars) { + analyzer_->Bind(iter_var->var, iter_var->dom); + } + return StmtExprMutator::VisitStmt_(op); + }); + } + + Stmt DefaultVisitBind(const BindNode* op) { + PrimExpr value = this->VisitExpr(op->value); + if (SideEffect(value) <= CallEffectKind::kPure) { + analyzer_->Bind(op->var, value); + } + if (value.same_as(op->value)) { + return ffi::GetRef(op); + } + auto n = this->CopyOnWrite(op); + n->value = std::move(value); + return Stmt(n); + } + + Stmt DefaultVisitIfThenElse(const IfThenElseNode* op) { + return constraint_scope_.WithNewScope([&]() -> Stmt { + PrimExpr condition = this->VisitExpr(op->condition); + PrimExpr real_condition = ExtractRealCondition(condition); + Stmt then_case; + ffi::Optional else_case; + constraint_scope_.WithNewScope([&]() { + constraint_scope_.Current().Emplace(analyzer_.get(), real_condition); + then_case = this->VisitStmt(op->then_case); + }); + if (op->else_case) { + PrimExpr neg_condition = analyzer_->rewrite_simplify(Not(real_condition)); + constraint_scope_.WithNewScope([&]() { + constraint_scope_.Current().Emplace(analyzer_.get(), neg_condition); + else_case = this->VisitStmt(op->else_case.value()); + }); + } + if (is_one(real_condition)) return then_case; + if (is_zero(real_condition)) return else_case.value_or(Evaluate(0)); + if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && + else_case.same_as(op->else_case)) { + return ffi::GetRef(op); + } + auto n = this->CopyOnWrite(op); + n->condition = std::move(condition); + n->then_case = std::move(then_case); + n->else_case = std::move(else_case); + return Stmt(n); + }); + } + + Stmt DefaultVisitAttrStmt(const AttrStmtNode* op) { + return constraint_scope_.WithNewScope([&]() -> Stmt { + if (op->attr_key == tirx::attr::thread_extent || + op->attr_key == s_tir::attr::virtual_thread) { + IterVar iv = Downcast(op->node); + TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U); + analyzer_->Bind(iv->var, Range::FromMinExtent(make_zero(op->value.dtype()), op->value)); + } + return StmtExprMutator::VisitStmt_(op); + }); + } + + Stmt DefaultVisitAssertStmt(const AssertStmtNode* op) { + PrimExpr condition = this->VisitExpr(op->condition); + constraint_scope_.Current().Emplace(analyzer_.get(), condition); + PrimExpr error_kind = this->VisitExpr(op->error_kind); + ffi::Array message_parts; + bool message_parts_same = true; + for (const StringImm& message : op->message_parts) { + StringImm new_message = Downcast(this->VisitExpr(message)); + if (!new_message.same_as(message)) { + message_parts_same = false; + } + message_parts.push_back(std::move(new_message)); + } + if (condition.same_as(op->condition) && error_kind.same_as(op->error_kind) && + message_parts_same) { + return ffi::GetRef(op); + } + auto n = this->CopyOnWrite(op); + n->condition = std::move(condition); + n->error_kind = Downcast(std::move(error_kind)); + if (!message_parts_same) { + n->message_parts = std::move(message_parts); + } + return Stmt(n); + } + + Stmt DefaultVisitSeqStmt(const SeqStmtNode* op) { return StmtExprMutator::VisitStmt_(op); } + + PrimExpr DefaultVisitCall(const CallNode* op) { + static auto op_if_then_else = Op::Get("tirx.if_then_else"); + if (op->op.same_as(op_if_then_else)) { + PrimExpr cond = this->VisitExpr(op->args[0]); + PrimExpr true_value; + PrimExpr false_value; + constraint_scope_.WithNewScope([&]() { + constraint_scope_.Current().Emplace(analyzer_.get(), cond); + true_value = this->VisitExpr(op->args[1]); + }); + constraint_scope_.WithNewScope([&]() { + PrimExpr not_cond = analyzer_->rewrite_simplify(Not(cond)); + constraint_scope_.Current().Emplace(analyzer_.get(), not_cond); + false_value = this->VisitExpr(op->args[2]); + }); + if (is_zero(cond)) return false_value; + if (is_one(cond)) return true_value; + if (cond.same_as(op->args[0]) && true_value.same_as(op->args[1]) && + false_value.same_as(op->args[2])) { + return ffi::GetRef(op); + } + return Call(op->dtype, op->op, {cond, true_value, false_value}, op->annotations, op->span); + } + return StmtExprMutator::VisitExpr_(op); + } + + PrimExpr DefaultVisitLet(const LetNode* op) { + PrimExpr value = this->VisitExpr(op->value); + if (SideEffect(value) <= CallEffectKind::kPure) { + analyzer_->Bind(op->var, value); + } + PrimExpr body = this->VisitExpr(op->body); + if (value.same_as(op->value) && body.same_as(op->body)) { + return ffi::GetRef(op); + } + return Let(op->var, value, body); + } + + PrimExpr DefaultVisitSelect(const SelectNode* op) { + PrimExpr cond = this->VisitExpr(op->condition); + PrimExpr true_value; + PrimExpr false_value; + constraint_scope_.WithNewScope([&]() { + constraint_scope_.Current().Emplace(analyzer_.get(), cond); + true_value = this->VisitExpr(op->true_value); + }); + constraint_scope_.WithNewScope([&]() { + PrimExpr neg_cond = analyzer_->rewrite_simplify(Not(cond)); + constraint_scope_.Current().Emplace(analyzer_.get(), neg_cond); + false_value = this->VisitExpr(op->false_value); + }); + if (is_zero(cond)) return false_value; + if (is_one(cond)) return true_value; + if (cond.same_as(op->condition) && true_value.same_as(op->true_value) && + false_value.same_as(op->false_value)) { + return ffi::GetRef(op); + } + return Select(cond, true_value, false_value); + } + + PrimExpr DefaultVisitReduce(const ReduceNode* op) { + for (const IterVar& iv : op->axis) { + analyzer_->Bind(iv->var, iv->dom); + } + return StmtExprMutator::VisitExpr_(op); + } + + static FStmtType InitStmtVTable() { + FStmtType vtable; + PY_ANALYZER_STMT_MUTATOR_DEFAULT_DISPATCH(BindNode, DefaultVisitBind); + PY_ANALYZER_STMT_MUTATOR_DEFAULT_DISPATCH(AttrStmtNode, DefaultVisitAttrStmt); + PY_ANALYZER_STMT_MUTATOR_DEFAULT_DISPATCH(IfThenElseNode, DefaultVisitIfThenElse); + PY_ANALYZER_STMT_MUTATOR_DEFAULT_DISPATCH(ForNode, DefaultVisitFor); + PY_ANALYZER_STMT_MUTATOR_BASE_DISPATCH(WhileNode); + PY_ANALYZER_STMT_MUTATOR_BASE_DISPATCH(AllocBufferNode); + PY_ANALYZER_STMT_MUTATOR_BASE_DISPATCH(DeclBufferNode); + PY_ANALYZER_STMT_MUTATOR_BASE_DISPATCH(BufferStoreNode); + PY_ANALYZER_STMT_MUTATOR_DEFAULT_DISPATCH(AssertStmtNode, DefaultVisitAssertStmt); + PY_ANALYZER_STMT_MUTATOR_DEFAULT_DISPATCH(SeqStmtNode, DefaultVisitSeqStmt); + PY_ANALYZER_STMT_MUTATOR_BASE_DISPATCH(EvaluateNode); + PY_ANALYZER_STMT_MUTATOR_DEFAULT_DISPATCH(SBlockNode, DefaultVisitSBlock); + PY_ANALYZER_STMT_MUTATOR_BASE_DISPATCH(SBlockRealizeNode); + vtable.Finalize(); + return vtable; + } + + static FExprType InitExprVTable() { + FExprType vtable; + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(VarNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(SizeVarNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(BufferLoadNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(ProducerLoadNode); + PY_ANALYZER_EXPR_MUTATOR_DEFAULT_DISPATCH(LetNode, DefaultVisitLet); + PY_ANALYZER_EXPR_MUTATOR_DEFAULT_DISPATCH(CallNode, DefaultVisitCall); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(AddNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(SubNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(MulNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(DivNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(ModNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(FloorDivNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(FloorModNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(MinNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(MaxNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(EQNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(NENode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(LTNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(LENode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(GTNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(GENode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(AndNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(OrNode); + PY_ANALYZER_EXPR_MUTATOR_DEFAULT_DISPATCH(ReduceNode, DefaultVisitReduce); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(CastNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(NotNode); + PY_ANALYZER_EXPR_MUTATOR_DEFAULT_DISPATCH(SelectNode, DefaultVisitSelect); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(RampNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(ShuffleNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(BroadcastNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(IntImmNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(FloatImmNode); + PY_ANALYZER_EXPR_MUTATOR_BASE_DISPATCH(StringImmNode); + vtable.Finalize(); + return vtable; + } +}; + +class PyStmtExprMutatorWithAnalyzer : public ffi::ObjectRef { + public: + explicit PyStmtExprMutatorWithAnalyzer(ffi::ObjectPtr data) + : ffi::ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } + + TVM_DLL static PyStmtExprMutatorWithAnalyzer MakePyStmtExprMutatorWithAnalyzer( + ffi::Array> callbacks) { + ffi::ObjectPtr n = + ffi::make_object(); + SetStmtExprFunctorCallbacks(n.get(), callbacks); + return PyStmtExprMutatorWithAnalyzer(n); + } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PyStmtExprMutatorWithAnalyzer, ffi::ObjectRef, + PyStmtExprMutatorWithAnalyzerNode); +}; + // ================================================ // TVM Register // ================================================ @@ -807,13 +1676,19 @@ class PyStmtExprMutator : public ffi::ObjectRef { TVM_FFI_STATIC_INIT_BLOCK() { PyStmtExprVisitorNode::RegisterReflection(); PyStmtExprMutatorNode::RegisterReflection(); + PyStmtExprVisitorWithAnalyzerNode::RegisterReflection(); + PyStmtExprMutatorWithAnalyzerNode::RegisterReflection(); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tirx.MakePyStmtExprVisitor", PyStmtExprVisitor::MakePyStmtExprVisitor) - .def("tirx.MakePyStmtExprMutator", PyStmtExprMutator::MakePyStmtExprMutator); + .def("tirx.MakePyStmtExprMutator", PyStmtExprMutator::MakePyStmtExprMutator) + .def("tirx.MakePyStmtExprVisitorWithAnalyzer", + PyStmtExprVisitorWithAnalyzer::MakePyStmtExprVisitorWithAnalyzer) + .def("tirx.MakePyStmtExprMutatorWithAnalyzer", + PyStmtExprMutatorWithAnalyzer::MakePyStmtExprMutatorWithAnalyzer); } // StmtExprVisitor @@ -830,6 +1705,29 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](PyStmtExprVisitor visitor, const PrimExpr& expr) { visitor->VisitExpr(expr); }); } +// StmtExprVisitorWithAnalyzer +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tirx.PyStmtExprVisitorWithAnalyzerDefaultVisitExpr", + [](PyStmtExprVisitorWithAnalyzer visitor, const PrimExpr& expr) { + visitor->DefaultVisitExprWithAnalyzer(expr); + }) + .def("tirx.PyStmtExprVisitorWithAnalyzerDefaultVisitStmt", + [](PyStmtExprVisitorWithAnalyzer visitor, const Stmt& stmt) { + visitor->DefaultVisitStmtWithAnalyzer(stmt); + }) + .def( + "tirx.PyStmtExprVisitorWithAnalyzerVisitStmt", + [](PyStmtExprVisitorWithAnalyzer visitor, const Stmt& stmt) { visitor->VisitStmt(stmt); }) + .def("tirx.PyStmtExprVisitorWithAnalyzerVisitExpr", + [](PyStmtExprVisitorWithAnalyzer visitor, const PrimExpr& expr) { + visitor->VisitExpr(expr); + }) + .def("tirx.PyStmtExprVisitorWithAnalyzerGetAnalyzer", + [](PyStmtExprVisitorWithAnalyzer visitor) { return visitor->GetAnalyzer(); }); +} + // StmtExprMutator TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; @@ -848,5 +1746,29 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](PyStmtExprMutator mutator, const Stmt& stmt) { return mutator->VisitStmt(stmt); }); } +// StmtExprMutatorWithAnalyzer +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tirx.PyStmtExprMutatorWithAnalyzerDefaultVisitExpr", + [](PyStmtExprMutatorWithAnalyzer mutator, const PrimExpr& expr) { + return mutator->DefaultVisitExprWithAnalyzer(expr); + }) + .def("tirx.PyStmtExprMutatorWithAnalyzerDefaultVisitStmt", + [](PyStmtExprMutatorWithAnalyzer mutator, const Stmt& stmt) { + return mutator->DefaultVisitStmtWithAnalyzer(stmt); + }) + .def("tirx.PyStmtExprMutatorWithAnalyzerVisitExpr", + [](PyStmtExprMutatorWithAnalyzer mutator, const PrimExpr& expr) { + return mutator->VisitExpr(expr); + }) + .def("tirx.PyStmtExprMutatorWithAnalyzerVisitStmt", + [](PyStmtExprMutatorWithAnalyzer mutator, const Stmt& stmt) { + return mutator->VisitStmt(stmt); + }) + .def("tirx.PyStmtExprMutatorWithAnalyzerGetAnalyzer", + [](PyStmtExprMutatorWithAnalyzer mutator) { return mutator->GetAnalyzer(); }); +} + } // namespace tirx } // namespace tvm diff --git a/tests/python/tirx-transform/test_tir_functor.py b/tests/python/tirx-transform/test_tir_functor.py index 021acd8fb60b..c9d2d0324e91 100644 --- a/tests/python/tirx-transform/test_tir_functor.py +++ b/tests/python/tirx-transform/test_tir_functor.py @@ -21,8 +21,10 @@ from tvm import tirx from tvm.tirx import ( EQ, + GE, LT, Add, + AssertStmt, Cast, Evaluate, FloatImm, @@ -33,7 +35,11 @@ Min, Mul, PyStmtExprMutator, + PyStmtExprMutatorWithAnalyzer, PyStmtExprVisitor, + PyStmtExprVisitorWithAnalyzer, + Select, + SeqStmt, StringImm, Sub, Var, @@ -202,6 +208,87 @@ def visit_add_(self, op: Add): return Add(Mul(a, IntImm("int32", 2)), b) +@tirx.functor.visitor +class AnalyzerAwareVisitor(PyStmtExprVisitorWithAnalyzer): + """Record analyzer facts visible from Python visitor callbacks.""" + + def __init__(self, var): + super().__init__() + self.var = var + self.facts = [] + + def visit_evaluate_(self, op: Evaluate): + if op.value.same_as(self.var): + self.facts.append(self.analyzer.can_prove(GE(self.var, IntImm("int32", 0)))) + self.facts.append(self.analyzer.can_prove(LT(self.var, IntImm("int32", 10)))) + super().visit_evaluate_(op) + + +@tirx.functor.mutator +class AnalyzerAwareMutator(PyStmtExprMutatorWithAnalyzer): + """Use branch constraints from the analyzer to rewrite proven predicates.""" + + def _rewrite_if_proven(self, value): + if value.dtype == "bool" and self.analyzer.can_prove(value): + return IntImm("bool", True) + return value + + def visit_lt_(self, op: LT): + a = self.visit_expr(op.a) + b = self.visit_expr(op.b) + value = op if a.same_as(op.a) and b.same_as(op.b) else LT(a, b, op.span) + return self._rewrite_if_proven(value) + + def visit_ge_(self, op: GE): + a = self.visit_expr(op.a) + b = self.visit_expr(op.b) + value = op if a.same_as(op.a) and b.same_as(op.b) else GE(a, b, op.span) + return self._rewrite_if_proven(value) + + def visit_evaluate_(self, op: Evaluate): + value = self.visit_expr(op.value) + value = self._rewrite_if_proven(value) + if value.same_as(op.value): + return op + return Evaluate(value, op.span) + + +@tirx.functor.visitor +class PredicateVisitor(PyStmtExprVisitorWithAnalyzer): + """Record whether boolean Evaluate nodes are provable in analyzer context.""" + + def __init__(self): + super().__init__() + self.facts = [] + + def visit_evaluate_(self, op: Evaluate): + if op.value.dtype == "bool": + self.facts.append(self.analyzer.can_prove(op.value)) + super().visit_evaluate_(op) + + +@tirx.functor.visitor +class AssertMessageVisitor(PyStmtExprVisitorWithAnalyzer): + """Record string immediates reached through analyzer-aware assert traversal.""" + + def __init__(self): + super().__init__() + self.strings = [] + + def visit_string_imm_(self, op: StringImm): + self.strings.append(op.value) + + +@tirx.functor.mutator +class AssertMessageMutator(PyStmtExprMutatorWithAnalyzer): + """Rewrite assert message strings through analyzer-aware traversal.""" + + def visit_string_imm_(self, op: StringImm): + if op.value == "bad": + return StringImm("rewritten") + return op + + def test_basic_visitor(): """Test the basic AST printer visitor""" expr = Add(Var("x", dtype="int32"), Var("y", dtype="int32")) @@ -330,6 +417,161 @@ def test_complex_mutator(): assert isinstance(modified_expr.a, Mul) # First operand should be multiplied by 2 +def test_analyzer_aware_visitor_loop_context(): + """Test that analyzer-aware visitors expose loop bounds to Python callbacks.""" + i = Var("i", dtype="int32") + stmt = For( + i, + IntImm("int32", 0), + IntImm("int32", 10), + tirx.ForKind.SERIAL, + Evaluate(i), + ) + + visitor = AnalyzerAwareVisitor(i) + visitor.visit_stmt(stmt) + + assert visitor.facts == [True, True] + + +def test_analyzer_aware_mutator_branch_context(): + """Test that analyzer-aware mutators expose branch predicates to Python callbacks.""" + x = Var("x", dtype="int32") + stmt = IfThenElse( + LT(x, IntImm("int32", 4)), + Evaluate(LT(x, IntImm("int32", 8))), + Evaluate(GE(x, IntImm("int32", 4))), + ) + + result = AnalyzerAwareMutator().visit_stmt(stmt) + + assert isinstance(result, IfThenElse) + assert isinstance(result.then_case.value, IntImm) + assert isinstance(result.else_case.value, IntImm) + assert result.then_case.value.value == 1 + assert result.else_case.value.value == 1 + + +def test_analyzer_aware_visitor_assert_context(): + """Test that assert constraints are visible to later statements in the same sequence.""" + x = Var("x", dtype="int32") + stmt = SeqStmt( + [ + AssertStmt(LT(x, IntImm("int32", 4)), StringImm("ValueError")), + Evaluate(LT(x, IntImm("int32", 8))), + ] + ) + + visitor = PredicateVisitor() + visitor.visit_stmt(stmt) + + assert visitor.facts == [True] + + +def test_analyzer_aware_assert_visits_error_kind_and_message(): + """Test that analyzer-aware assert traversal covers error kind and message parts.""" + x = Var("x", dtype="int32") + stmt = AssertStmt(LT(x, IntImm("int32", 4)), StringImm("ValueError"), [StringImm("bad")]) + + visitor = AssertMessageVisitor() + visitor.visit_stmt(stmt) + + assert visitor.strings == ["ValueError", "bad"] + + result = AssertMessageMutator().visit_stmt(stmt) + + assert result.error_kind.value == "ValueError" + assert [part.value for part in result.message_parts] == ["rewritten"] + + +def test_analyzer_aware_visitor_pure_bind_context(): + """Test that pure Bind values are visible to later statements in the same sequence.""" + x = Var("x", dtype="int32") + stmt = SeqStmt( + [ + tirx.Bind(x, IntImm("int32", 4)), + Evaluate(GE(x, IntImm("int32", 4))), + ] + ) + + visitor = PredicateVisitor() + visitor.visit_stmt(stmt) + + assert visitor.facts == [True] + + +def test_analyzer_aware_mutator_skips_opaque_bind_context(): + """Test that opaque Bind values are not inserted into the analyzer context.""" + h = Var("h", dtype="handle") + stmt = SeqStmt( + [ + tirx.Bind(h, tirx.tvm_stack_alloca("tvm_ffi_any", 1)), + Evaluate(IntImm("int32", 0)), + ] + ) + + result = AnalyzerAwareMutator().visit_stmt(stmt) + + assert isinstance(result, tirx.Bind) + assert result.var.same_as(h) + + +def test_analyzer_aware_visitor_branch_assert_does_not_leak(): + """Test that assert constraints inside a branch do not leak to following statements.""" + x = Var("x", dtype="int32") + stmt = SeqStmt( + [ + IfThenElse( + LT(x, IntImm("int32", 4)), + AssertStmt(LT(x, IntImm("int32", 1)), StringImm("ValueError")), + None, + ), + Evaluate(LT(x, IntImm("int32", 1))), + ] + ) + + visitor = PredicateVisitor() + visitor.visit_stmt(stmt) + + assert visitor.facts == [False] + + +def test_analyzer_aware_mutator_select_context(): + """Test that analyzer-aware mutators expose Select branch predicates.""" + x = Var("x", dtype="int32") + stmt = Evaluate( + Select( + LT(x, IntImm("int32", 4)), + LT(x, IntImm("int32", 8)), + GE(x, IntImm("int32", 4)), + ) + ) + + result = AnalyzerAwareMutator().visit_stmt(stmt) + + assert isinstance(result.value, IntImm) + assert result.value.value == 1 + + +def test_analyzer_aware_mutator_if_then_else_call_context(): + """Test that analyzer-aware mutators expose tirx.if_then_else expression predicates.""" + x = Var("x", dtype="int32") + stmt = Evaluate( + tirx.if_then_else( + LT(x, IntImm("int32", 4)), + LT(x, IntImm("int32", 8)), + GE(x, IntImm("int32", 4)), + ) + ) + + result = AnalyzerAwareMutator().visit_stmt(stmt) + + assert isinstance(result.value.args[1], IntImm) + assert isinstance(result.value.args[2], IntImm) + assert result.value.args[1].value == 1 + assert result.value.args[2].value == 1 + + def test_different_expr_types(): """Test visitor with various expression types""" x = Var("x", dtype="int32")