From 21958f1f8930f84114269e2d734becff408c58ca Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 5 Apr 2026 15:14:34 -0700 Subject: [PATCH] refactor(ffi): always build with bundled tvm-ffi and align include paths Switch to always using the bundled tvm-ffi submodule for builds. Update internal include paths and Python imports to match the tvm-ffi submodule restructure: headers moved from `tvm/ffi/ir/text/` to `tvm/ffi/extra/` and Python modules from `tvm_ffi.ir.text` to `tvm_ffi.pyast`. No public API changes. All 99 test suites (19,800 tests) pass. --- .gitignore | 7 + 3rdparty/tvm-ffi | 2 +- CMakeLists.txt | 1 + include/tvm/ir/attrs.h | 6 +- include/tvm/ir/expr.h | 20 +- include/tvm/ir/global_info.h | 9 +- include/tvm/ir/op.h | 5 +- include/tvm/ir/type.h | 18 +- include/tvm/relax/distributed/struct_info.h | 6 +- include/tvm/relax/expr.h | 75 +- include/tvm/relax/struct_info.h | 19 +- include/tvm/relax/type.h | 18 +- include/tvm/target/target.h | 5 +- include/tvm/tirx/buffer.h | 6 +- include/tvm/tirx/expr.h | 203 ++++- include/tvm/tirx/function.h | 5 +- include/tvm/tirx/stmt.h | 52 +- include/tvm/tirx/var.h | 14 +- pyproject.toml | 3 + python/tvm/runtime/script_printer.py | 13 + python/tvm/script/ir_builder/tirx/ir.py | 2 + python/tvm/tirx/op.py | 13 + src/ir/attrs.cc | 2 + src/ir/expr.cc | 26 + src/ir/global_info.cc | 15 + src/ir/module.cc | 156 ++++ src/ir/printer_utils.h | 63 ++ src/ir/type.cc | 19 + src/relax/distributed/struct_info.cc | 108 +++ src/relax/ir/expr.cc | 799 ++++++++++++++++++++ src/relax/ir/script_print_utils.h | 424 +++++++++++ src/relax/ir/struct_info.cc | 196 ++++- src/target/target.cc | 9 +- src/tirx/ir/expr.cc | 354 +++++++++ src/tirx/ir/function.cc | 357 +++++++++ src/tirx/ir/script_print_utils.h | 348 +++++++++ src/tirx/ir/stmt.cc | 419 +++++++++- 37 files changed, 3724 insertions(+), 73 deletions(-) create mode 100644 src/ir/printer_utils.h create mode 100644 src/relax/ir/script_print_utils.h create mode 100644 src/tirx/ir/script_print_utils.h diff --git a/.gitignore b/.gitignore index 93f584104748..004f0f6a4121 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,10 @@ +.claude/ +ir_testsuite.jsonl +ir_testsuite_filtered.jsonl +main.py +main2.py +testsuite/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/3rdparty/tvm-ffi b/3rdparty/tvm-ffi index 63224e3f1e46..97f0132a185d 160000 --- a/3rdparty/tvm-ffi +++ b/3rdparty/tvm-ffi @@ -1 +1 @@ -Subproject commit 63224e3f1e464cc62307223787926a48fc8df8c0 +Subproject commit 97f0132a185d9198f32067e5a26c2bc12766fbc8 diff --git a/CMakeLists.txt b/CMakeLists.txt index 0950db7b0ba4..dc4d904991c7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -277,6 +277,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS src/driver/*.cc src/support/*.cc src/script/*.cc + src/script_v2/*.cc src/relax/ir/*.cc src/relax/op/*.cc src/relax/analysis/*.cc diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index cf52ec32ea1f..0ee5bff7c0a1 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -32,6 +32,7 @@ #include #include #include +#include #include #include #include @@ -145,7 +146,10 @@ class DictAttrsNode : public BaseAttrsNode { static void RegisterReflection() { namespace rfl = ffi::reflection; - rfl::ObjectDef().def_ro("__dict__", &DictAttrsNode::dict); + namespace tr = tvm::ffi::ir_traits; + rfl::ObjectDef() + .def_ro("__dict__", &DictAttrsNode::dict) + .def_ir_traits("$field:__dict__"); } void InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unknown) final; diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 7c4b8e7cb2bd..90087ee5645b 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -24,6 +24,7 @@ #ifndef TVM_IR_EXPR_H_ #define TVM_IR_EXPR_H_ +#include #include #include #include @@ -459,7 +460,10 @@ class GlobalVarNode : public RelaxExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("name_hint", &GlobalVarNode::name_hint); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("name_hint", &GlobalVarNode::name_hint) + .def_ir_traits("I.GlobalVar", "$field:name_hint"); } bool SEqual(const GlobalVarNode* other, @@ -498,7 +502,10 @@ class IntImmNode : public PrimExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("value", &IntImmNode::value); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("value", &IntImmNode::value) + .def_ir_traits("$field:value", "int"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.IntImm", IntImmNode, PrimExprNode); }; @@ -533,7 +540,10 @@ class FloatImmNode : public PrimExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("value", &FloatImmNode::value); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("value", &FloatImmNode::value) + .def_ir_traits("$field:value", "float"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.FloatImm", FloatImmNode, PrimExprNode); }; @@ -675,10 +685,12 @@ class RangeNode : public Object { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("min", &RangeNode::min) .def_ro("extent", &RangeNode::extent) - .def_ro("span", &RangeNode::span, refl::AttachFieldFlag::SEqHashIgnore()); + .def_ro("span", &RangeNode::span, refl::AttachFieldFlag::SEqHashIgnore()) + .def_ir_traits("I.Range", "$global:ir._range_args"); } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; diff --git a/include/tvm/ir/global_info.h b/include/tvm/ir/global_info.h index 892bba4da694..34924f8c2e4b 100644 --- a/include/tvm/ir/global_info.h +++ b/include/tvm/ir/global_info.h @@ -25,6 +25,7 @@ #ifndef TVM_IR_GLOBAL_INFO_H_ #define TVM_IR_GLOBAL_INFO_H_ +#include #include #include #include @@ -71,10 +72,12 @@ class VDeviceNode : public GlobalInfoNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("target", &VDeviceNode::target) .def_ro("vdevice_id", &VDeviceNode::vdevice_id) - .def_ro("memory_scope", &VDeviceNode::memory_scope); + .def_ro("memory_scope", &VDeviceNode::memory_scope) + .def_ir_traits("I.vdevice", "$global:ir._vdevice_args"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.VDevice", VDeviceNode, GlobalInfoNode); @@ -97,7 +100,9 @@ class DummyGlobalInfoNode : public GlobalInfoNode { public: static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef(); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ir_traits("I.dummy_global_info", "$global:ir._dummy_args"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.DummyGlobalInfo", DummyGlobalInfoNode, GlobalInfoNode); diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index 9171a9e6d2df..ce5e7ec505c9 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -26,6 +26,7 @@ #define TVM_IR_OP_H_ #include +#include #include #include #include @@ -93,6 +94,7 @@ class OpNode : public RelaxExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("name", &OpNode::name) .def_ro("op_type", &OpNode::op_type, refl::AttachFieldFlag::SEqHashIgnore()) @@ -100,7 +102,8 @@ class OpNode : public RelaxExprNode { .def_ro("arguments", &OpNode::arguments, refl::AttachFieldFlag::SEqHashIgnore()) .def_ro("attrs_type_key", &OpNode::attrs_type_key, refl::AttachFieldFlag::SEqHashIgnore()) .def_ro("num_inputs", &OpNode::num_inputs, refl::AttachFieldFlag::SEqHashIgnore()) - .def_ro("support_level", &OpNode::support_level, refl::AttachFieldFlag::SEqHashIgnore()); + .def_ro("support_level", &OpNode::support_level, refl::AttachFieldFlag::SEqHashIgnore()) + .def_ir_traits("I.Op", "$field:name"); } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindUniqueInstance; diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index 902778c3db02..ac7316140c9c 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -50,6 +50,7 @@ #define TVM_IR_TYPE_H_ #include +#include #include #include #include @@ -118,7 +119,10 @@ class PrimTypeNode : public TypeNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("dtype", &PrimTypeNode::dtype); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("dtype", &PrimTypeNode::dtype) + .def_ir_traits("$field:dtype"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.PrimType", PrimTypeNode, TypeNode); }; @@ -162,9 +166,11 @@ class PointerTypeNode : public TypeNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("element_type", &PointerTypeNode::element_type) - .def_ro("storage_scope", &PointerTypeNode::storage_scope); + .def_ro("storage_scope", &PointerTypeNode::storage_scope) + .def_ir_traits("T.handle", "$global:ir._handle_args"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.PointerType", PointerTypeNode, TypeNode); }; @@ -198,9 +204,11 @@ class TupleTypeNode : public TypeNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("fields", &TupleTypeNode::fields) - .def_ro("span", &TupleTypeNode::span); + .def_ro("span", &TupleTypeNode::span) + .def_ir_traits("$field:fields"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.TupleType", TupleTypeNode, TypeNode); }; @@ -258,10 +266,12 @@ class FuncTypeNode : public TypeNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("arg_types", &FuncTypeNode::arg_types) .def_ro("ret_type", &FuncTypeNode::ret_type) - .def_ro("span", &FuncTypeNode::span); + .def_ro("span", &FuncTypeNode::span) + .def_ir_traits("$field:arg_types", "$field:ret_type"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.FuncType", FuncTypeNode, TypeNode); }; diff --git a/include/tvm/relax/distributed/struct_info.h b/include/tvm/relax/distributed/struct_info.h index 9ca3b1513828..69ab0cd15ec7 100644 --- a/include/tvm/relax/distributed/struct_info.h +++ b/include/tvm/relax/distributed/struct_info.h @@ -25,6 +25,7 @@ #ifndef TVM_RELAX_DISTRIBUTED_STRUCT_INFO_H_ #define TVM_RELAX_DISTRIBUTED_STRUCT_INFO_H_ +#include #include #include namespace tvm { @@ -91,7 +92,10 @@ class PlacementNode : public Object { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("dim_specs", &PlacementNode::dim_specs); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("dim_specs", &PlacementNode::dim_specs) + .def_ir_traits("$global:relax._placement_str"); } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode; diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 0407e3c604c5..d810be9ac0c8 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -214,7 +215,10 @@ class TupleNode : public ExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("fields", &TupleNode::fields); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("fields", &TupleNode::fields) + .def_ir_traits("Tuple", "$field:fields"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.Tuple", TupleNode, ExprNode); }; @@ -269,9 +273,11 @@ class TupleGetItemNode : public ExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("tuple_value", &TupleGetItemNode::tuple) - .def_ro("index", &TupleGetItemNode::index); + .def_ro("index", &TupleGetItemNode::index) + .def_ir_traits("$field:tuple_value", "$field:index"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.TupleGetItem", TupleGetItemNode, ExprNode); }; @@ -328,7 +334,10 @@ class ShapeExprNode : public LeafExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("values", &ShapeExprNode::values); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("values", &ShapeExprNode::values) + .def_ir_traits("R.shape", "$global:relax._shape_args"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.ShapeExpr", ShapeExprNode, LeafExprNode); }; @@ -352,7 +361,10 @@ class VarNode : public LeafExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("vid", &VarNode::vid); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("vid", &VarNode::vid) + .def_ir_traits("$field:vid"); // customize structural equal and hash to include struct_info_ refl::TypeAttrDef() .def("__s_equal__", &VarNode::SEqual) @@ -397,7 +409,9 @@ class DataflowVarNode : public VarNode { public: static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef(); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ir_traits("$field:vid"); } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode; @@ -435,7 +449,10 @@ class ConstantNode : public LeafExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("data", &ConstantNode::data); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("data", &ConstantNode::data) + .def_ir_traits("$field:data", "constant"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.Constant", ConstantNode, LeafExprNode); }; @@ -469,7 +486,10 @@ class PrimValueNode : public LeafExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("value", &PrimValueNode::value); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("value", &PrimValueNode::value) + .def_ir_traits("R.prim_value", "$field:value"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.PrimValue", PrimValueNode, LeafExprNode); }; @@ -509,7 +529,10 @@ class StringImmNode : public LeafExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("value", &StringImmNode::value); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("value", &StringImmNode::value) + .def_ir_traits("R.str", "$field:value"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.StringImm", StringImmNode, LeafExprNode); }; @@ -541,7 +564,10 @@ class DataTypeImmNode : public LeafExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("value", &DataTypeImmNode::value); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("value", &DataTypeImmNode::value) + .def_ir_traits("R.dtype", "$global:relax._dtype_str"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.DataTypeImm", DataTypeImmNode, LeafExprNode); }; @@ -611,9 +637,11 @@ class MatchCastNode : public BindingNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("value", &MatchCastNode::value) - .def_ro("struct_info", &MatchCastNode::struct_info, refl::AttachFieldFlag::SEqHashDef()); + .def_ro("struct_info", &MatchCastNode::struct_info, refl::AttachFieldFlag::SEqHashDef()) + .def_ir_traits("$field:var", "$field:value"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.MatchCast", MatchCastNode, BindingNode); }; @@ -637,7 +665,8 @@ class VarBindingNode : public BindingNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("value", &VarBindingNode::value); + refl::ObjectDef() + .def_ro("value", &VarBindingNode::value); // customize the SEqual and SHash methods for better error messages refl::TypeAttrDef() .def("__s_equal__", &VarBindingNode::SEqual) @@ -664,10 +693,14 @@ class BindingBlockNode : public Object { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("bindings", &BindingBlockNode::bindings) .def_ro("span", &BindingBlockNode::span, refl::AttachFieldFlag::SEqHashIgnore(), - refl::DefaultValue(Span())); + refl::DefaultValue(Span())) + .def_ir_traits(tr::RegionTraits("$field:bindings"), + nullptr, nullptr, nullptr, nullptr, nullptr, + ffi::Optional(true)); } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; @@ -686,7 +719,11 @@ class DataflowBlockNode : public BindingBlockNode { public: static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef(); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ir_traits(tr::RegionTraits("$field:bindings"), + nullptr, nullptr, "R.dataflow", nullptr, + "$global:relax._dataflow_outputs"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.DataflowBlock", DataflowBlockNode, BindingBlockNode); @@ -710,9 +747,14 @@ class SeqExprNode : public ExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("blocks", &SeqExprNode::blocks) - .def_ro("body", &SeqExprNode::body); + .def_ro("body", &SeqExprNode::body) + .def_ir_traits(tr::RegionTraits("$field:blocks", nullptr, nullptr, + "$field:body"), + nullptr, nullptr, nullptr, nullptr, nullptr, + ffi::Optional(true)); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.SeqExpr", SeqExprNode, ExprNode); }; @@ -907,7 +949,10 @@ class ExternFuncNode : public BaseFuncNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("global_symbol", &ExternFuncNode::global_symbol); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("global_symbol", &ExternFuncNode::global_symbol) + .def_ir_traits("R.ExternFunc", "$field:global_symbol"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.ExternFunc", ExternFuncNode, BaseFuncNode); }; diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h index 0d9658d8cffc..9959cc5ed5d7 100644 --- a/include/tvm/relax/struct_info.h +++ b/include/tvm/relax/struct_info.h @@ -19,6 +19,7 @@ #ifndef TVM_RELAX_STRUCT_INFO_H_ #define TVM_RELAX_STRUCT_INFO_H_ +#include #include #include #include @@ -40,7 +41,9 @@ class ObjectStructInfoNode : public StructInfoNode { public: static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef(); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ir_traits("R.Object", "$global:relax._empty_array"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.ObjectStructInfo", ObjectStructInfoNode, StructInfoNode); }; @@ -175,11 +178,13 @@ class TensorStructInfoNode : public StructInfoNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("shape", &TensorStructInfoNode::shape) .def_ro("dtype", &TensorStructInfoNode::dtype) .def_ro("vdevice", &TensorStructInfoNode::vdevice) - .def_ro("ndim", &TensorStructInfoNode::ndim); + .def_ro("ndim", &TensorStructInfoNode::ndim) + .def_ir_traits("$field:shape", "$field:dtype", "$field:vdevice"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.TensorStructInfo", TensorStructInfoNode, StructInfoNode); }; @@ -225,7 +230,10 @@ class TupleStructInfoNode : public StructInfoNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("fields", &TupleStructInfoNode::fields); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("fields", &TupleStructInfoNode::fields) + .def_ir_traits("R.Tuple", "$field:fields"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.TupleStructInfo", TupleStructInfoNode, StructInfoNode); }; @@ -293,11 +301,14 @@ class FuncStructInfoNode : public StructInfoNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("params", &FuncStructInfoNode::params, refl::AttachFieldFlag::SEqHashDef()) .def_ro("ret", &FuncStructInfoNode::ret) .def_ro("derive_func", &FuncStructInfoNode::derive_func) - .def_ro("purity", &FuncStructInfoNode::purity); + .def_ro("purity", &FuncStructInfoNode::purity) + .def_ir_traits("R.Callable", "$global:relax._func_si_args", + nullptr, "$global:relax._func_si_kwargs"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.FuncStructInfo", FuncStructInfoNode, StructInfoNode); }; diff --git a/include/tvm/relax/type.h b/include/tvm/relax/type.h index b70a2756b71f..5715fffdc812 100644 --- a/include/tvm/relax/type.h +++ b/include/tvm/relax/type.h @@ -25,6 +25,7 @@ #define TVM_RELAX_TYPE_H_ #include +#include #include #include #include @@ -46,7 +47,10 @@ class ShapeTypeNode : public TypeNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("ndim", &ShapeTypeNode::ndim); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("ndim", &ShapeTypeNode::ndim) + .def_ir_traits(nullptr, "$field:ndim"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.ShapeType", ShapeTypeNode, TypeNode); }; @@ -75,9 +79,11 @@ class TensorTypeNode : public TypeNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("ndim", &TensorTypeNode::ndim) - .def_ro("dtype", &TensorTypeNode::dtype); + .def_ro("dtype", &TensorTypeNode::dtype) + .def_ir_traits(nullptr, "$field:dtype"); } inline bool IsUnknownNdim() const { return ndim == kUnknownNDim; } @@ -115,7 +121,9 @@ class ObjectTypeNode : public TypeNode { public: static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef(); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ir_traits(); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.ObjectType", ObjectTypeNode, TypeNode); }; @@ -131,7 +139,9 @@ class PackedFuncTypeNode : public TypeNode { public: static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef(); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ir_traits(); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.PackedFuncType", PackedFuncTypeNode, TypeNode); }; diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index b71a4952b530..7451fc3c5f02 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -24,6 +24,7 @@ #ifndef TVM_TARGET_TARGET_H_ #define TVM_TARGET_TARGET_H_ +#include #include #include #include @@ -80,12 +81,14 @@ class TargetNode : public Object { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("kind", &TargetNode::kind) .def_ro("tag", &TargetNode::tag) .def_ro("keys", &TargetNode::keys) .def_ro("attrs", &TargetNode::attrs) - .def_ro("host", &TargetNode::host); + .def_ro("host", &TargetNode::host) + .def_ir_traits("T.target", "$global:target._config"); } /*! diff --git a/include/tvm/tirx/buffer.h b/include/tvm/tirx/buffer.h index 8f5c916a5c11..a6f46520cc85 100644 --- a/include/tvm/tirx/buffer.h +++ b/include/tvm/tirx/buffer.h @@ -25,6 +25,7 @@ #define TVM_TIR_BUFFER_H_ #include +#include #include #include #include @@ -114,6 +115,7 @@ class BufferNode : public Object { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("data", &BufferNode::data, refl::AttachFieldFlag::SEqHashDef()) .def_ro("dtype", &BufferNode::dtype) @@ -126,7 +128,9 @@ class BufferNode : public Object { .def_ro("data_alignment", &BufferNode::data_alignment) .def_ro("offset_factor", &BufferNode::offset_factor) .def_ro("buffer_type", &BufferNode::buffer_type) - .def_ro("span", &BufferNode::span, refl::AttachFieldFlag::SEqHashIgnore()); + .def_ro("span", &BufferNode::span, refl::AttachFieldFlag::SEqHashIgnore()) + .def_ir_traits("$field:shape", "$field:dtype", + "$field:strides", "$field:elem_offset", "$field:name"); } /*! \return preferred index type for this buffer node */ diff --git a/include/tvm/tirx/expr.h b/include/tvm/tirx/expr.h index ebd318d82288..12deb1af25d6 100644 --- a/include/tvm/tirx/expr.h +++ b/include/tvm/tirx/expr.h @@ -27,6 +27,7 @@ #include #include +#include #include #include #include @@ -56,7 +57,10 @@ class StringImmNode : public PrimExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("value", &StringImmNode::value); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("value", &StringImmNode::value) + .def_ir_traits("$field:value", "string"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.StringImm", StringImmNode, PrimExprNode); }; @@ -83,7 +87,10 @@ class CastNode : public PrimExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("value", &CastNode::value); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("value", &CastNode::value) + .def_ir_traits("T.Cast", "$global:tirx._cast_args"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Cast", CastNode, PrimExprNode); }; @@ -125,6 +132,15 @@ class BinaryOpNode : public PrimExprNode { class AddNode : public BinaryOpNode { public: static constexpr const char* _type_key = "tirx.Add"; + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("a", &AddNode::a) + .def_ro("b", &AddNode::b) + .def_ir_traits("$field:a", "$field:b", "+", + "$global:tirx._add_sugar", "Add"); + } }; /*! @@ -142,6 +158,15 @@ class Add : public PrimExpr { class SubNode : public BinaryOpNode { public: static constexpr const char* _type_key = "tirx.Sub"; + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("a", &SubNode::a) + .def_ro("b", &SubNode::b) + .def_ir_traits("$field:a", "$field:b", "-", + "$global:tirx._sub_sugar", "Sub"); + } }; /*! @@ -160,6 +185,15 @@ class Sub : public PrimExpr { class MulNode : public BinaryOpNode { public: static constexpr const char* _type_key = "tirx.Mul"; + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("a", &MulNode::a) + .def_ro("b", &MulNode::b) + .def_ir_traits("$field:a", "$field:b", "*", + "$global:tirx._mul_sugar", "Mul"); + } }; /*! @@ -180,6 +214,15 @@ class Mul : public PrimExpr { class DivNode : public BinaryOpNode { public: static constexpr const char* _type_key = "tirx.Div"; + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("a", &DivNode::a) + .def_ro("b", &DivNode::b) + .def_ir_traits("$field:a", "$field:b", "/", + "$global:tirx._div_sugar", "Div"); + } }; /*! @@ -200,6 +243,14 @@ class Div : public PrimExpr { class ModNode : public BinaryOpNode { public: static constexpr const char* _type_key = "tirx.Mod"; + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("a", &ModNode::a) + .def_ro("b", &ModNode::b) + .def_ir_traits("$field:a", "$field:b", "truncmod"); + } }; /*! @@ -217,6 +268,15 @@ class Mod : public PrimExpr { class FloorDivNode : public BinaryOpNode { public: static constexpr const char* _type_key = "tirx.FloorDiv"; + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("a", &FloorDivNode::a) + .def_ro("b", &FloorDivNode::b) + .def_ir_traits("$field:a", "$field:b", "//", + "$global:tirx._floordiv_sugar", "FloorDiv"); + } }; /*! @@ -234,6 +294,15 @@ class FloorDiv : public PrimExpr { class FloorModNode : public BinaryOpNode { public: static constexpr const char* _type_key = "tirx.FloorMod"; + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("a", &FloorModNode::a) + .def_ro("b", &FloorModNode::b) + .def_ir_traits("$field:a", "$field:b", "%", + "$global:tirx._floormod_sugar", "FloorMod"); + } }; /*! @@ -251,6 +320,14 @@ class FloorMod : public PrimExpr { class MinNode : public BinaryOpNode { public: static constexpr const char* _type_key = "tirx.Min"; + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("a", &MinNode::a) + .def_ro("b", &MinNode::b) + .def_ir_traits("$field:a", "$field:b", "min"); + } }; /*! @@ -268,6 +345,14 @@ class Min : public PrimExpr { class MaxNode : public BinaryOpNode { public: static constexpr const char* _type_key = "tirx.Max"; + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("a", &MaxNode::a) + .def_ro("b", &MaxNode::b) + .def_ir_traits("$field:a", "$field:b", "max"); + } }; /*! @@ -307,6 +392,15 @@ class CmpOpNode : public PrimExprNode { class EQNode : public CmpOpNode { public: static constexpr const char* _type_key = "tirx.EQ"; + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("a", &EQNode::a) + .def_ro("b", &EQNode::b) + .def_ir_traits("$field:a", "$field:b", "==", + "$global:tirx._eq_sugar", "EQ"); + } }; /*! @@ -324,6 +418,15 @@ class EQ : public PrimExpr { class NENode : public CmpOpNode { public: static constexpr const char* _type_key = "tirx.NE"; + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("a", &NENode::a) + .def_ro("b", &NENode::b) + .def_ir_traits("$field:a", "$field:b", "!=", + "$global:tirx._ne_sugar", "NE"); + } }; /*! @@ -341,6 +444,15 @@ class NE : public PrimExpr { class LTNode : public CmpOpNode { public: static constexpr const char* _type_key = "tirx.LT"; + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("a", <Node::a) + .def_ro("b", <Node::b) + .def_ir_traits("$field:a", "$field:b", "<", + "$global:tirx._lt_sugar", "LT"); + } }; /*! @@ -358,6 +470,15 @@ class LT : public PrimExpr { struct LENode : public CmpOpNode { public: static constexpr const char* _type_key = "tirx.LE"; + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("a", &LENode::a) + .def_ro("b", &LENode::b) + .def_ir_traits("$field:a", "$field:b", "<=", + "$global:tirx._le_sugar", "LE"); + } }; /*! @@ -375,6 +496,15 @@ class LE : public PrimExpr { class GTNode : public CmpOpNode { public: static constexpr const char* _type_key = "tirx.GT"; + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("a", >Node::a) + .def_ro("b", >Node::b) + .def_ir_traits("$field:a", "$field:b", ">", + "$global:tirx._gt_sugar", "GT"); + } }; /*! @@ -392,6 +522,15 @@ class GT : public PrimExpr { class GENode : public CmpOpNode { public: static constexpr const char* _type_key = "tirx.GE"; + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("a", &GENode::a) + .def_ro("b", &GENode::b) + .def_ir_traits("$field:a", "$field:b", ">=", + "$global:tirx._ge_sugar", "GE"); + } }; /*! @@ -415,7 +554,12 @@ class AndNode : public PrimExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("a", &AndNode::a).def_ro("b", &AndNode::b); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("a", &AndNode::a) + .def_ro("b", &AndNode::b) + .def_ir_traits("$field:a", "$field:b", "and", + "$global:tirx._and_sugar", "And"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.And", AndNode, PrimExprNode); }; @@ -441,7 +585,12 @@ class OrNode : public PrimExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("a", &OrNode::a).def_ro("b", &OrNode::b); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("a", &OrNode::a) + .def_ro("b", &OrNode::b) + .def_ir_traits("$field:a", "$field:b", "or", + "$global:tirx._or_sugar", "Or"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Or", OrNode, PrimExprNode); }; @@ -465,7 +614,10 @@ class NotNode : public PrimExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("a", &NotNode::a); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("a", &NotNode::a) + .def_ir_traits("$field:a", "not"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Not", NotNode, PrimExprNode); }; @@ -499,10 +651,12 @@ class SelectNode : public PrimExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("condition", &SelectNode::condition) .def_ro("true_value", &SelectNode::true_value) - .def_ro("false_value", &SelectNode::false_value); + .def_ro("false_value", &SelectNode::false_value) + .def_ir_traits("T.Select", "$global:tirx._select_args"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Select", SelectNode, PrimExprNode); }; @@ -540,10 +694,13 @@ class BufferLoadNode : public PrimExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("buffer", &BufferLoadNode::buffer) .def_ro("indices", &BufferLoadNode::indices) - .def_ro("predicate", &BufferLoadNode::predicate); + .def_ro("predicate", &BufferLoadNode::predicate) + .def_ir_traits("$field:buffer", "$global:tirx._load_indices", + "$field:predicate"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.BufferLoad", BufferLoadNode, PrimExprNode); @@ -594,9 +751,11 @@ class ProducerLoadNode : public PrimExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("producer", &ProducerLoadNode::producer) - .def_ro("indices", &ProducerLoadNode::indices); + .def_ro("indices", &ProducerLoadNode::indices) + .def_ir_traits("$field:producer", "$field:indices"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.ProducerLoad", ProducerLoadNode, PrimExprNode); }; @@ -634,10 +793,12 @@ class RampNode : public PrimExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("base", &RampNode::base) .def_ro("stride", &RampNode::stride) - .def_ro("lanes", &RampNode::lanes); + .def_ro("lanes", &RampNode::lanes) + .def_ir_traits("T.Ramp", "$global:tirx._ramp_args"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Ramp", RampNode, PrimExprNode); }; @@ -663,9 +824,11 @@ class BroadcastNode : public PrimExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("value", &BroadcastNode::value) - .def_ro("lanes", &BroadcastNode::lanes); + .def_ro("lanes", &BroadcastNode::lanes) + .def_ir_traits("T.Broadcast", "$global:tirx._broadcast_args"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Broadcast", BroadcastNode, PrimExprNode); }; @@ -732,7 +895,12 @@ class CallNode : public PrimExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("op", &CallNode::op).def_ro("args", &CallNode::args); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("op", &CallNode::op) + .def_ro("args", &CallNode::args) + .def_ir_traits("$global:tirx._call_callee", "$global:tirx._call_args", + nullptr, nullptr, "$global:tirx._tir_call_callee"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Call", CallNode, PrimExprNode); }; @@ -762,9 +930,11 @@ class ShuffleNode : public PrimExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("vectors", &ShuffleNode::vectors) - .def_ro("indices", &ShuffleNode::indices); + .def_ro("indices", &ShuffleNode::indices) + .def_ir_traits("T.Shuffle", "$global:tirx._shuffle_args"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Shuffle", ShuffleNode, PrimExprNode); }; @@ -812,12 +982,14 @@ class CommReducerNode : public Object { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("lhs", &CommReducerNode::lhs, refl::AttachFieldFlag::SEqHashDef()) .def_ro("rhs", &CommReducerNode::rhs, refl::AttachFieldFlag::SEqHashDef()) .def_ro("result", &CommReducerNode::result) .def_ro("identity_element", &CommReducerNode::identity_element) - .def_ro("span", &CommReducerNode::span, refl::AttachFieldFlag::SEqHashIgnore()); + .def_ro("span", &CommReducerNode::span, refl::AttachFieldFlag::SEqHashIgnore()) + .def_ir_traits("comm_reducer", "$field:lhs"); } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; @@ -857,13 +1029,16 @@ class ReduceNode : public PrimExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("combiner", &ReduceNode::combiner) .def_ro("source", &ReduceNode::source) .def_ro("init", &ReduceNode::init) .def_ro("axis", &ReduceNode::axis) .def_ro("condition", &ReduceNode::condition) - .def_ro("value_index", &ReduceNode::value_index); + .def_ro("value_index", &ReduceNode::value_index) + .def_ir_traits("T.reduce", "$global:tirx._reduce_positional", + nullptr, "$global:tirx._reduce_kwargs"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Reduce", ReduceNode, PrimExprNode); }; diff --git a/include/tvm/tirx/function.h b/include/tvm/tirx/function.h index 0c98deb8b309..e1bb14c69497 100644 --- a/include/tvm/tirx/function.h +++ b/include/tvm/tirx/function.h @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -167,9 +168,11 @@ class TensorIntrinNode : public Object { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("desc", &TensorIntrinNode::desc) - .def_ro("impl", &TensorIntrinNode::impl); + .def_ro("impl", &TensorIntrinNode::impl) + .def_ir_traits("TensorIntrin", "$field:desc"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.TensorIntrin", TensorIntrinNode, Object); }; diff --git a/include/tvm/tirx/stmt.h b/include/tvm/tirx/stmt.h index c191c4e6bf1f..ce587173edc3 100644 --- a/include/tvm/tirx/stmt.h +++ b/include/tvm/tirx/stmt.h @@ -24,6 +24,7 @@ #ifndef TVM_TIR_STMT_H_ #define TVM_TIR_STMT_H_ +#include #include #include #include @@ -83,9 +84,11 @@ class BindNode : public StmtNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("var", &BindNode::var, refl::AttachFieldFlag::SEqHashDef()) - .def_ro("value", &BindNode::value); + .def_ro("value", &BindNode::value) + .def_ir_traits("$field:var", "$field:value"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Bind", BindNode, StmtNode); }; @@ -167,10 +170,12 @@ class AssertStmtNode : public StmtNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("condition", &AssertStmtNode::condition) .def_ro("error_kind", &AssertStmtNode::error_kind) - .def_ro("message_parts", &AssertStmtNode::message_parts); + .def_ro("message_parts", &AssertStmtNode::message_parts) + .def_ir_traits("$field:condition", "$global:tirx._structured_msg"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.AssertStmt", AssertStmtNode, StmtNode); }; @@ -211,11 +216,14 @@ class BufferStoreNode : public StmtNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("buffer", &BufferStoreNode::buffer) .def_ro("value", &BufferStoreNode::value) .def_ro("indices", &BufferStoreNode::indices) - .def_ro("predicate", &BufferStoreNode::predicate); + .def_ro("predicate", &BufferStoreNode::predicate) + .def_ir_traits("$field:buffer", "$field:value", + "$global:tirx._store_indices", "$field:predicate"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.BufferStore", BufferStoreNode, StmtNode); }; @@ -242,7 +250,8 @@ class DeclBufferNode : public StmtNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("buffer", &DeclBufferNode::buffer); + refl::ObjectDef() + .def_ro("buffer", &DeclBufferNode::buffer); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.DeclBuffer", DeclBufferNode, StmtNode); }; @@ -322,7 +331,10 @@ class SeqStmtNode : public StmtNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("seq", &SeqStmtNode::seq); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("seq", &SeqStmtNode::seq) + .def_ir_traits(tr::RegionTraits("$field:seq")); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.SeqStmt", SeqStmtNode, StmtNode); }; @@ -340,7 +352,13 @@ class EvaluateNode : public StmtNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("value", &EvaluateNode::value); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ro("value", &EvaluateNode::value) + .def_ir_traits(nullptr, "$global:tirx._evaluate_expr", + nullptr, nullptr, + "$global:tirx._evaluate_kind", + "$global:tirx._evaluate_is_return"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Evaluate", EvaluateNode, StmtNode); }; @@ -524,10 +542,14 @@ class IfThenElseNode : public StmtNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("condition", &IfThenElseNode::condition) .def_ro("then_case", &IfThenElseNode::then_case) - .def_ro("else_case", &IfThenElseNode::else_case); + .def_ro("else_case", &IfThenElseNode::else_case) + .def_ir_traits("$field:condition", + tr::RegionTraits("$field:then_case"), + ffi::Optional(tr::RegionTraits("$field:else_case"))); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.IfThenElse", IfThenElseNode, StmtNode); }; @@ -667,9 +689,11 @@ class WhileNode : public StmtNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("condition", &WhileNode::condition) - .def_ro("body", &WhileNode::body); + .def_ro("body", &WhileNode::body) + .def_ir_traits("$field:condition", tr::RegionTraits("$field:body")); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.While", WhileNode, StmtNode); }; @@ -698,9 +722,11 @@ class BufferRegionNode : public PrimExprConvertibleNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("buffer", &BufferRegionNode::buffer) - .def_ro("region", &BufferRegionNode::region); + .def_ro("region", &BufferRegionNode::region) + .def_ir_traits("$field:buffer", "$global:tirx._buf_region_indices"); } TVM_DLL PrimExpr ToPrimExpr() const final; @@ -825,6 +851,7 @@ class SBlockNode : public StmtNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("iter_vars", &SBlockNode::iter_vars, refl::AttachFieldFlag::SEqHashDef()) .def_ro("reads", &SBlockNode::reads) @@ -834,7 +861,8 @@ class SBlockNode : public StmtNode { .def_ro("match_buffers", &SBlockNode::match_buffers) .def_ro("annotations", &SBlockNode::annotations) .def_ro("init", &SBlockNode::init) - .def_ro("body", &SBlockNode::body); + .def_ro("body", &SBlockNode::body) + .def_ir_traits(tr::RegionTraits("$field:body", "$field:iter_vars")); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.SBlock", SBlockNode, StmtNode); }; @@ -875,10 +903,12 @@ class SBlockRealizeNode : public StmtNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("iter_values", &SBlockRealizeNode::iter_values) .def_ro("predicate", &SBlockRealizeNode::predicate) - .def_ro("block", &SBlockRealizeNode::block); + .def_ro("block", &SBlockRealizeNode::block) + .def_ir_traits(tr::RegionTraits("$field:block")); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.SBlockRealize", SBlockRealizeNode, StmtNode); }; diff --git a/include/tvm/tirx/var.h b/include/tvm/tirx/var.h index 7da5cb31c152..07e28155827e 100644 --- a/include/tvm/tirx/var.h +++ b/include/tvm/tirx/var.h @@ -62,9 +62,12 @@ class VarNode : public PrimExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("name", &VarNode::name_hint, refl::AttachFieldFlag::SEqHashIgnore()) - .def_ro("type_annotation", &VarNode::type_annotation); + .def_ro("type_annotation", &VarNode::type_annotation) + .def_ir_traits("$field:name", "$field:type_annotation", + "$global:tirx._var_type_or_null"); } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindFreeVar; @@ -133,7 +136,10 @@ class SizeVarNode : public VarNode { public: static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef(); + namespace tr = tvm::ffi::ir_traits; + refl::ObjectDef() + .def_ir_traits("$field:name", "$field:type_annotation", + "$global:tirx._sizevar_type_or_null"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.SizeVar", SizeVarNode, VarNode); }; @@ -276,11 +282,13 @@ class IterVarNode : public PrimExprConvertibleNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; + namespace tr = tvm::ffi::ir_traits; refl::ObjectDef() .def_ro("dom", &IterVarNode::dom) .def_ro("var", &IterVarNode::var, refl::AttachFieldFlag::SEqHashDef()) .def_ro("iter_type", &IterVarNode::iter_type) - .def_ro("thread_tag", &IterVarNode::thread_tag); + .def_ro("thread_tag", &IterVarNode::thread_tag) + .def_ir_traits("T.iter_var", "$global:tirx._iter_var_args"); } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; diff --git a/pyproject.toml b/pyproject.toml index 062f357bfa8b..c9c1179716ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -269,5 +269,8 @@ ignore_errors = true module = ["python.tvm.runtime.*"] ignore_errors = true +[tool.uv.sources] +apache-tvm-ffi = { path = "3rdparty/tvm-ffi", editable = true } + [dependency-groups] lint = ["pre-commit"] diff --git a/python/tvm/runtime/script_printer.py b/python/tvm/runtime/script_printer.py index 6ba9abb032af..3d43736198eb 100644 --- a/python/tvm/runtime/script_printer.py +++ b/python/tvm/runtime/script_printer.py @@ -224,6 +224,19 @@ def script( ), ) + def script_v2(self, *, indent_spaces: int = 4) -> str: + """Print using the traits-based V2 printer. + + Returns + ------- + script : str + The V2 TVM Script of the given TVM IR + """ + from tvm_ffi.pyast import to_python, PrinterConfig + + cfg = PrinterConfig(indent_spaces=indent_spaces) + return to_python(self, cfg) + def _relax_script( self, *, diff --git a/python/tvm/script/ir_builder/tirx/ir.py b/python/tvm/script/ir_builder/tirx/ir.py index ce88c563156f..4113d4c07e45 100644 --- a/python/tvm/script/ir_builder/tirx/ir.py +++ b/python/tvm/script/ir_builder/tirx/ir.py @@ -1928,6 +1928,7 @@ def wrapped(*args, **kwargs) -> T: truncmod = _op_wrapper(_tir_op.truncmod) tvm_access_ptr = _op_wrapper(_tir_op.tvm_access_ptr) tvm_throw_last_error = _op_wrapper(_tir_op.tvm_throw_last_error) +tvm_static_handle = _op_wrapper(_tir_op.tvm_static_handle) tvm_stack_alloca = _op_wrapper(_tir_op.tvm_stack_alloca) tvm_stack_make_shape = _op_wrapper(_tir_op.tvm_stack_make_shape) tvm_stack_make_array = _op_wrapper(_tir_op.tvm_stack_make_array) @@ -2208,6 +2209,7 @@ def wrapped(*args, **kwargs): "truncmod", "tvm_access_ptr", "tvm_throw_last_error", + "tvm_static_handle", "tvm_stack_alloca", "tvm_stack_make_shape", "tvm_stack_make_array", diff --git a/python/tvm/tirx/op.py b/python/tvm/tirx/op.py index 6b4a636f3061..b8d7244dfa4c 100644 --- a/python/tvm/tirx/op.py +++ b/python/tvm/tirx/op.py @@ -801,6 +801,19 @@ def tvm_throw_last_error(): return call_intrin("handle", "tirx.tvm_throw_last_error") +def tvm_static_handle(): + """Create a function-local static handle initialized to nullptr. + + Can be used to cache function-local static resources. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("handle", "tirx.tvm_static_handle") + + def tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout): """TVM intrinsic for tensor core load operators diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index 748f4bf5c93f..1badb128cc3b 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -21,6 +21,7 @@ * \file attrs.cc */ #include +#include #include #include @@ -29,6 +30,7 @@ namespace tvm { TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = ::tvm::ffi::reflection; AttrFieldInfoNode::RegisterReflection(); DictAttrsNode::RegisterReflection(); } diff --git a/src/ir/expr.cc b/src/ir/expr.cc index f9d6e0fc6080..d6ca966a57d0 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -23,6 +23,8 @@ */ #include #include +#include +#include #include #include #include @@ -34,6 +36,7 @@ namespace tvm { TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = ::tvm::ffi::reflection; BaseExprNode::RegisterReflection(); PrimExprNode::RegisterReflection(); RelaxExprNode::RegisterReflection(); @@ -42,6 +45,27 @@ TVM_FFI_STATIC_INIT_BLOCK() { IntImmNode::RegisterReflection(); FloatImmNode::RegisterReflection(); RangeNode::RegisterReflection(); + refl::GlobalDef().def("ir._range_args", [](ffi::AnyView /*ctx*/, Range node) -> ffi::Array { + return {node->min, node->min + node->extent}; + }); +} + +// GlobalVar: check printer var table first for module-bound references. +// Falls through to I.GlobalVar("name") for unbound. +TVM_FFI_STATIC_INIT_BLOCK() { + namespace text = ::tvm::ffi::pyast; + namespace refl = ::tvm::ffi::reflection; + refl::TypeAttrDef().def( + "__ffi_text_print__", + [](GlobalVar gv, text::IRPrinter printer, refl::AccessPath path) -> text::NodeAST { + if (auto doc = printer->VarGet(gv)) { + return doc.value(); + } + text::ExprAST callee = text::ExprAttr(text::IdAST("I"), "GlobalVar"); + ffi::List args; + args.push_back(text::LiteralAST::Str(gv->name_hint)); + return text::CallAST(callee, std::move(args), {}, {}); + }); } PrimExpr::PrimExpr(int32_t value) : PrimExpr(IntImm(DataType::Int(32), value)) {} @@ -231,4 +255,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { }); } +// range_args method is now registered via ObjectDef().def() above. + } // namespace tvm diff --git a/src/ir/global_info.cc b/src/ir/global_info.cc index 151387d3c25a..da88173c0d9c 100644 --- a/src/ir/global_info.cc +++ b/src/ir/global_info.cc @@ -22,13 +22,27 @@ * \brief Module global info. */ +#include +#include #include #include +#include + namespace tvm { TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = ::tvm::ffi::reflection; VDeviceNode::RegisterReflection(); DummyGlobalInfoNode::RegisterReflection(); + refl::GlobalDef() + .def("ir._vdevice_args", [](ffi::AnyView /*ctx*/, VDevice node) -> ffi::List { + ffi::List result; + result.push_back(node->target->ToConfig()); + result.push_back(static_cast(node->vdevice_id)); + result.push_back(node->memory_scope); + return result; + }) + .def("ir._dummy_args", [](ffi::AnyView /*ctx*/, DummyGlobalInfo) -> ffi::Array { return {}; }); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -53,4 +67,5 @@ TVM_FFI_STATIC_INIT_BLOCK() { return VDevice(tgt, dev_id, mem_scope); }); } + } // namespace tvm diff --git a/src/ir/module.cc b/src/ir/module.cc index 935d9e0ccdb4..3b367b4151d8 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -34,6 +35,8 @@ #include #include +#include "printer_utils.h" + namespace tvm { TVM_FFI_STATIC_INIT_BLOCK() { IRModuleNode::RegisterReflection(); } @@ -315,4 +318,157 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](IRModule mod, ffi::String key) -> ObjectRef { return mod->GetAttr(key); }); } +// --------------------------------------------------------------------------- +// __ffi_text_print__ override +// --------------------------------------------------------------------------- + +namespace { + +struct SortableFunction { + int priority; + GlobalVar gv; + BaseFunc func; + + explicit SortableFunction(const std::pair& obj) + : priority(0), gv(obj.first), func(obj.second) { + if (gv->name_hint == "main") { + priority = 1000; + } else if (obj.second->GetTypeKey() == "tirx.PrimFunc") { + priority = 1; + } else if (obj.second->GetTypeKey() == "relax.expr.ExternFunc") { + priority = 2; + } else if (obj.second->GetTypeKey() == "relax.expr.Function") { + priority = 3; + } else { + priority = 0; + } + } + + bool operator<(const SortableFunction& other) const { + if (this->priority != other.priority) { + return this->priority < other.priority; + } + return this->gv->name_hint < other.gv->name_hint; + } +}; + +} // namespace + +TVM_FFI_STATIC_INIT_BLOCK() { + using namespace printer; + + refl::TypeAttrDef().def( + "__ffi_text_print__", + [](IRModule mod, text::IRPrinter printer, text::AccessPath path) -> text::NodeAST { + using namespace printer; + // Sort functions by priority + std::vector functions; + for (const auto& kv : mod->functions) { + functions.push_back(SortableFunction(kv)); + } + std::sort(functions.begin(), functions.end()); + + text::IdAST module_doc = text::IdAST("Module"); + ffi::List decorators; + decorators.push_back(IR("ir_module")); + + text::DefaultFrame frame; + printer->FramePush(frame); + + // Define GlobalVars + for (const auto& entry : functions) { + const GlobalVar& gv = entry.gv; + ffi::String name = gv->name_hint; + ffi::Function creator = ffi::Function::FromTyped( + [name]() -> text::ExprAST { return text::ExprAttr(text::IdAST("Module"), std::string(name)); }); + printer->VarDefNoName(creator, gv, ffi::Optional(frame)); + } + + // Print attrs prologue + if (mod->attrs.defined() && !mod->attrs->dict.empty()) { + frame->stmts.push_back( + text::ExprStmtAST(text::ExprCall(IR("module_attrs"), + {Print(printer, mod->attrs, path->Attr("attrs"))}))); + } + + // Print global_infos prologue + if (mod->global_infos.defined() && !mod->global_infos.empty()) { + frame->stmts.push_back( + text::ExprStmtAST(text::ExprCall(IR("module_global_infos"), + {Print(printer, mod->global_infos, + path->Attr("global_infos"))}))); + if (auto opt_vdevices = mod->global_infos.Get("vdevice")) { + ffi::Array vdevices = opt_vdevices.value(); + for (int i = 0; i < static_cast(vdevices.size()); ++i) { + if (const auto* vd = vdevices[i].as()) { + VDevice vdev = ffi::GetRef(vd); + if (!vdev->target.defined()) continue; + std::string dev_kind = vdev->target->kind->name; + int kind_index = 0; + for (int j = 0; j < i; ++j) { + if (const auto* prev = vdevices[j].as()) { + if (prev->target.defined() && prev->target->kind->name == dev_kind) { + kind_index++; + } + } + } + std::string vdev_str = dev_kind + ":" + std::to_string(kind_index) + ":" + + std::string(vdev->memory_scope); + ffi::Function creator = ffi::Function::FromTyped( + [vdev_str]() -> text::ExprAST { return text::LiteralAST::Str(vdev_str); }); + printer->VarDefNoName(creator, vdev, ffi::Optional(frame)); + } + } + } + } + + // Print each function + ffi::List body; + for (const auto& s : frame->stmts) body.push_back(s); + + for (const auto& entry : functions) { + const GlobalVar& gv = entry.gv; + const BaseFunc& base_func = entry.func; + text::AccessPath func_path = path->Attr("functions")->MapItem(gv); + printer->VarDef(gv->name_hint, base_func, frame); + text::NodeAST doc = printer->operator()(ffi::Any(base_func), func_path).cast(); + if (auto* block = doc.as()) { + body.push_back(block->stmts.back()); + } else if (doc->IsInstance()) { + body.push_back(Downcast(doc)); + } else if (doc->IsInstance()) { + text::ExprAST lhs = text::IdAST(gv->name_hint); + body.push_back(text::AssignAST(lhs, Downcast(doc), + ffi::Optional())); + } + } + + printer->FramePop(); + + text::ClassAST class_ast(module_doc, {}, decorators, body); + + // Add import header comments + bool has_relax = false; + for (const auto& entry : functions) { + std::string type_key = entry.func->GetTypeKey(); + if (type_key == "relax.expr.Function" || type_key == "relax.expr.ExternFunc") { + has_relax = true; + break; + } + } + ffi::List result; + result.push_back(text::CommentAST( + ffi::Optional(ffi::String("from tvm.script import ir as I")))); + result.push_back(text::CommentAST( + ffi::Optional(ffi::String("from tvm.script import tirx as T")))); + if (has_relax) { + result.push_back(text::CommentAST( + ffi::Optional(ffi::String("from tvm.script import relax as R")))); + } + result.push_back(text::CommentAST(ffi::Optional())); + result.push_back(class_ast); + return text::StmtBlockAST(result); + }); +} + } // namespace tvm diff --git a/src/ir/printer_utils.h b/src/ir/printer_utils.h new file mode 100644 index 000000000000..bd778cea7d1b --- /dev/null +++ b/src/ir/printer_utils.h @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_IR_PRINTER_UTILS_H_ +#define TVM_IR_PRINTER_UTILS_H_ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace printer { + +namespace text = ::tvm::ffi::pyast; +namespace tr = ::tvm::ffi::ir_traits; +namespace refl = ::tvm::ffi::reflection; + +/*! \brief Convert a DataType to its string representation. */ +inline std::string DType2Str(const runtime::DataType& dtype) { + return dtype.is_void() ? "void" : runtime::DLDataTypeToString(dtype); +} + +/*! \brief Build `prefix.attr` as an IdAST with dot notation. */ +inline text::ExprAST PrefixedId(const std::string& prefix, const std::string& attr) { + return text::ExprAttr(text::IdAST(prefix), attr); +} + +/*! \brief Build `T.attr` */ +inline text::ExprAST TIR(const std::string& attr) { return PrefixedId("T", attr); } + +/*! \brief Build `R.attr` */ +inline text::ExprAST Relax(const std::string& attr) { return PrefixedId("R", attr); } + +/*! \brief Build `I.attr` */ +inline text::ExprAST IR(const std::string& attr) { return PrefixedId("I", attr); } + +/*! \brief Print an object through the printer dispatch. */ +inline text::ExprAST Print(const text::IRPrinter& printer, ffi::Any obj, text::AccessPath path) { + return printer->operator()(std::move(obj), std::move(path)).cast(); +} + + +} // namespace printer +} // namespace tvm + +#endif // TVM_IR_PRINTER_UTILS_H_ diff --git a/src/ir/type.cc b/src/ir/type.cc index b28e20a78f89..2f669ec8932e 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -21,15 +21,34 @@ * \file src/ir/type.cc * \brief Common type system AST nodes throughout the IR. */ +#include #include +#include #include #include +#include namespace tvm { TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = ::tvm::ffi::reflection; TypeNode::RegisterReflection(); PrimTypeNode::RegisterReflection(); PointerTypeNode::RegisterReflection(); + refl::GlobalDef().def("ir._handle_args", [](ffi::AnyView /*ctx*/, PointerType node) -> ffi::Array { + ffi::Array args; + if (const auto* prim = node->element_type.as()) { + runtime::DataType dt = prim->dtype; + ffi::String dtype_str = dt.is_void() ? ffi::String("void") + : ffi::DLDataTypeToString(static_cast(dt)); + args.push_back(tirx::StringImm(dtype_str)); + } else { + args.push_back(node->element_type); + } + if (!node->storage_scope.empty()) { + args.push_back(tirx::StringImm(node->storage_scope)); + } + return args; + }); TupleTypeNode::RegisterReflection(); FuncTypeNode::RegisterReflection(); TensorMapTypeNode::RegisterReflection(); diff --git a/src/relax/distributed/struct_info.cc b/src/relax/distributed/struct_info.cc index 731564825015..c430d5494f28 100644 --- a/src/relax/distributed/struct_info.cc +++ b/src/relax/distributed/struct_info.cc @@ -22,16 +22,24 @@ * \brief Relax dtensor struct info. */ +#include #include #include + +#include "../ir/script_print_utils.h" + namespace tvm { namespace relax { namespace distributed { TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = ::tvm::ffi::reflection; DTensorStructInfoNode::RegisterReflection(); PlacementNode::RegisterReflection(); PlacementSpecNode::RegisterReflection(); + refl::GlobalDef().def("relax._placement_str", [](ffi::AnyView /*ctx*/, Placement node) -> ffi::String { + return node->ToString(); + }); } PlacementSpec PlacementSpec::Sharding(int axis) { @@ -144,6 +152,106 @@ TVM_FFI_STATIC_INIT_BLOCK() { }); } +// ---- __ffi_text_print__ overrides ---- + +TVM_FFI_STATIC_INIT_BLOCK() { + using namespace printer; + namespace refl = ::tvm::ffi::reflection; + namespace text = ::tvm::ffi::pyast; + // DeviceMesh: R.device_mesh((dim0, dim1, ...), I.Range(start, end)) + refl::TypeAttrDef().def( + "__ffi_text_print__", + [](DeviceMesh node, text::IRPrinter printer, text::AccessPath path) -> text::NodeAST { + // Build shape tuple from ffi::Shape + ffi::List shape_elts; + for (size_t i = 0; i < node->shape.size(); ++i) { + shape_elts.push_back(text::LiteralAST::Int(node->shape[i])); + } + text::ExprAST shape_doc = text::TupleAST({}, std::move(shape_elts)); + // Build second arg: I.Range(start, end) or device_ids list + ffi::List args; + args.push_back(shape_doc); + if (node->device_range.defined()) { + Range r = node->device_range.value(); + text::ExprAST range_begin = Print(printer, r->min, path->Attr("device_range")->Attr("min")); + // Range stores (min, extent); V1 prints I.Range(min, min + extent) + PrimExpr end_expr = r->min + r->extent; + text::ExprAST range_end = Print(printer, end_expr, path->Attr("device_range")->Attr("extent")); + args.push_back(text::ExprCall(IR("Range"), {range_begin, range_end})); + } else { + args.push_back(Print(printer, node->device_ids, path->Attr("device_ids"))); + } + return text::ExprCall(Relax("device_mesh"), std::move(args)); + }); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + using namespace printer; + namespace refl = ::tvm::ffi::reflection; + namespace text = ::tvm::ffi::pyast; + // DTensorStructInfo: R.DTensor(shape, dtype, "mesh[i]", "placement_str") + refl::TypeAttrDef().def( + "__ffi_text_print__", + [](DTensorStructInfo n, text::IRPrinter printer, + text::AccessPath path) -> text::NodeAST { + ffi::List args; + ffi::List kwargs_keys; + ffi::List kwargs_values; + bool require_kwargs = false; + // Shape from tensor_sinfo + if (n->tensor_sinfo->shape.defined()) { + if (const auto* shape = n->tensor_sinfo->shape.value().as()) { + auto shape_expr = ffi::GetRef(shape); + text::AccessPath shape_p = path->Attr("tensor_sinfo")->Attr("shape")->Attr("values"); + ffi::List shape_docs; + for (int i = 0, ndim = shape_expr->values.size(); i < ndim; ++i) { + shape_docs.push_back(PrintShapeValue(shape_expr->values[i], + shape_p->ArrayItem(i), printer, + /*stringify_vars=*/false)); + } + args.push_back(text::TupleAST({}, std::move(shape_docs))); + } else { + args.push_back(Print(printer, n->tensor_sinfo->shape.value(), + path->Attr("tensor_sinfo")->Attr("shape"))); + } + } else { + require_kwargs = true; + } + // dtype + if (!n->tensor_sinfo->IsUnknownDtype()) { + if (!require_kwargs) { + args.push_back(text::LiteralAST::Str(DType2Str(n->tensor_sinfo->dtype))); + } else { + kwargs_keys.push_back(ffi::String("dtype")); + kwargs_values.push_back(text::LiteralAST::Str(DType2Str(n->tensor_sinfo->dtype))); + } + } else { + require_kwargs = true; + } + // device_mesh: print as string reference "mesh[i]" or inline + if (!require_kwargs) { + args.push_back(Print(printer, n->device_mesh, path->Attr("device_mesh"))); + } else { + kwargs_keys.push_back(ffi::String("device_mesh")); + kwargs_values.push_back(Print(printer, n->device_mesh, path->Attr("device_mesh"))); + } + // placement + if (!require_kwargs) { + args.push_back(Print(printer, n->placement, path->Attr("placement"))); + } else { + kwargs_keys.push_back(ffi::String("placement")); + kwargs_values.push_back(Print(printer, n->placement, path->Attr("placement"))); + } + // ndim when shape is not defined + if (!n->tensor_sinfo->shape.defined() && !n->tensor_sinfo->IsUnknownNdim()) { + kwargs_keys.push_back(ffi::String("ndim")); + kwargs_values.push_back(text::LiteralAST::Int(n->tensor_sinfo->ndim)); + } + return text::ExprCallKw(Relax("DTensor"), std::move(args), + std::move(kwargs_keys), std::move(kwargs_values)); + }); +} + } // namespace distributed } // namespace relax } // namespace tvm diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 2fbd573a5f7f..8b4bacf9490f 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include #include @@ -24,6 +25,8 @@ #include +#include "script_print_utils.h" + namespace tvm { namespace relax { @@ -52,6 +55,43 @@ TVM_FFI_STATIC_INIT_BLOCK() { ExternFuncNode::RegisterReflection(); } +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = ::tvm::ffi::reflection; + namespace text = ::tvm::ffi::pyast; + refl::GlobalDef() + .def("relax._shape_args", [](ffi::AnyView /*ctx*/, ShapeExpr node) -> ffi::Array { + ffi::Array wrapper; + wrapper.push_back(node->values); + return wrapper; + }) + .def("relax._dtype_str", [](ffi::AnyView /*ctx*/, DataTypeImm node) -> ffi::String { + DataType dt = node->value; + return dt.is_void() ? ffi::String("void") + : ffi::DLDataTypeToString(static_cast(dt)); + }) + .def("relax._dataflow_outputs", + [](ffi::AnyView /*ctx*/, DataflowBlock block) -> ::tvm::ffi::Function { + return ::tvm::ffi::Function::FromTyped( + [](::tvm::ffi::ObjectRef obj, text::IRPrinter printer, text::DefaultFrame frame) { + DataflowBlock blk = ::tvm::Downcast(obj); + ffi::List outputs; + for (const auto& b : blk->bindings) { + if (!b->var->IsInstance()) { + ffi::Optional var_expr = printer->VarGet(b->var); + if (var_expr.has_value()) { + outputs.push_back(var_expr.value()); + } + } + } + if (!outputs.empty()) { + text::ExprAST callee = text::ExprAttr(text::IdAST("R"), "output"); + frame->stmts.push_back( + text::ExprStmtAST(text::ExprCall(callee, std::move(outputs)))); + } + }); + }); +} + Id::Id(ffi::String name_hint) { ObjectPtr n = ffi::make_object(); n->name_hint = std::move(name_hint); @@ -754,5 +794,764 @@ TVM_FFI_STATIC_INIT_BLOCK() { }); } +// ---- __ffi_text_print__ overrides ---- + +TVM_FFI_STATIC_INIT_BLOCK() { + using namespace printer; + namespace refl = ::tvm::ffi::reflection; + namespace text = ::tvm::ffi::pyast; + // DataflowBlock: with R.dataflow(): bindings + R.output(...) + // Must use __ffi_text_print__ (not trait) to preserve output vars in parent scope. + // The With trait's FramePop removes all vars, causing the SeqExpr body return to fail. + refl::TypeAttrDef().def( + "__ffi_text_print__", + [](DataflowBlock node, text::IRPrinter printer, text::AccessPath path) -> text::NodeAST { + text::DefaultFrame frame; + printer->FramePush(frame); + ffi::List body; + for (int i = 0; i < static_cast(node->bindings.size()); ++i) { + text::NodeAST s = printer->operator()(ffi::Any(node->bindings[i]), + path->Attr("bindings")->ArrayItem(i)) + .cast(); + if (auto* block = s.as()) { + for (const auto& st : block->stmts) body.push_back(st); + } else if (s->IsInstance()) { + body.push_back(Downcast(s)); + } else if (s->IsInstance()) { + body.push_back(text::ExprStmtAST(Downcast(s))); + } + } + // R.output() for non-DataflowVar bindings + ffi::List outputs; + for (const auto& b : node->bindings) { + if (!b->var->IsInstance()) { + if (auto var_doc = printer->VarGet(b->var)) { + outputs.push_back(var_doc.value()); + } + } + } + if (!outputs.empty()) { + body.push_back(text::ExprStmtAST(text::ExprCall(Relax("output"), std::move(outputs)))); + } + // Pop frame BUT re-register output vars in the PARENT frame so they're + // accessible to the enclosing SeqExpr's return statement. + ffi::List output_vars; + for (const auto& b : node->bindings) { + if (!b->var->IsInstance()) { + output_vars.push_back(b->var); + } + } + // Save the var info before popping + ffi::Map saved_info; + for (const auto& ov : output_vars) { + auto it = printer->obj2info.find(ov); + if (it != printer->obj2info.end()) { + saved_info.Set(ov, (*it).second); + } + } + printer->FramePop(); + // Re-register output vars in parent frame + if (!printer->frames.empty()) { + ffi::ObjectRef parent_frame = printer->frames.back().cast(); + for (const auto& kv : saved_info) { + printer->obj2info.Set(kv.first, kv.second); + // Add to parent frame's var list + auto frame_vars = printer->frame_vars[parent_frame].cast>(); + frame_vars.push_back(kv.first); + } + } + + text::ExprAST ctx = text::ExprCall(Relax("dataflow"), {}); + return text::WithAST(ffi::Optional(), ctx, body); + }); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + using namespace printer; + namespace refl = ::tvm::ffi::reflection; + namespace text = ::tvm::ffi::pyast; + // relax.Function: @R.function with struct_info annotations and prologue + refl::TypeAttrDef().def( + "__ffi_text_print__", + [](Function func, text::IRPrinter printer, text::AccessPath path) -> text::NodeAST { + // Step 1. Determine function name + ffi::String func_name = FindFuncName(func, printer, path); + text::IdAST name = text::IdAST(func_name); + bool at_top_level = AtTopLevelInModule(path); + bool has_global_symbol = func->attrs.defined() && + func->attrs->dict.count("global_symbol"); + + // Step 2. Build decorator: @R.function with optional kwargs + ffi::List dec_keys; + ffi::List dec_values; + // pure=False when function is impure + if (!func->is_pure) { + dec_keys.push_back(ffi::String("pure")); + dec_values.push_back(text::LiteralAST::Bool(false)); + } + // private=True when at top level (or standalone) without global_symbol + // V1 marks standalone functions and top-level module functions as private + // when they lack a global_symbol attribute. + bool is_standalone = !path->step.defined() || path->depth <= 1; + if ((at_top_level || is_standalone) && !has_global_symbol) { + dec_keys.push_back(ffi::String("private")); + dec_values.push_back(text::LiteralAST::Bool(true)); + } + ffi::List decorators; + if (dec_keys.size() > 0) { + decorators.push_back( + text::ExprCallKw(Relax("function"), {}, std::move(dec_keys), std::move(dec_values))); + } else { + decorators.push_back(Relax("function")); + } + + // Step 3. Push frame, define symbolic TIR vars, then print params. + text::DefaultFrame frame; + printer->FramePush(frame); + int n = func->params.size(); + + // Step 3a. Collect and define symbolic TIR vars from ALL struct_info + // in the function: params, match_cast bindings, and var bindings. + // This ensures all TIR vars are defined at function level (matching V1). + // Without this, TIR vars first seen inside a DataflowBlock's match_cast + // would be scoped inside the dataflow block, causing undefined-var errors + // when they're referenced outside (e.g., in the return statement). + { + std::vector tir_vars; + std::unordered_set seen; + // Collect from params + for (int i = 0; i < n; ++i) { + Var var = func->params[i]; + if (var->struct_info_.defined()) { + StructInfo si = Downcast(var->struct_info_.value()); + CollectTIRVarsFromStructInfo(si, &tir_vars, &seen); + } + } + // Collect from body bindings (match_cast struct_info and var struct_info). + // Include FuncStructInfo vars: inner functions may reference TIR vars + // from the outer scope (e.g., symbolic shape dims shared between outer + // and inner function params). The `seen` set prevents duplicates. + SeqExpr body_seq = Downcast(func->body); + for (const auto& block : body_seq->blocks) { + for (const auto& binding : block->bindings) { + if (const auto* mc = binding.as()) { + CollectTIRVarsFromStructInfo(mc->struct_info, &tir_vars, &seen); + } + if (binding->var->struct_info_.defined()) { + StructInfo si = Downcast(binding->var->struct_info_.value()); + CollectTIRVarsFromStructInfo(si, &tir_vars, &seen); + } + } + } + // Also collect from the return expression's struct_info if applicable + if (body_seq->body->struct_info_.defined()) { + StructInfo si = Downcast(body_seq->body->struct_info_.value()); + CollectTIRVarsFromStructInfo(si, &tir_vars, &seen); + } + for (const auto& tir_var : tir_vars) { + // Skip vars with empty name_hint (synthetic vars from FuncStructInfo etc.) + if (tir_var->name_hint.empty()) continue; + if (!printer->VarGet(tir_var).has_value()) { + printer->VarDef(tir_var->name_hint, tir_var, frame); + text::ExprAST var_id = printer->VarGet(tir_var).value(); + std::string dtype_str = DType2Str(tir_var->dtype); + // Match V1's PrintVarCreation: add is_size_var=True kwarg for SizeVar + if (tir_var->IsInstance()) { + frame->stmts.push_back(text::AssignAST( + var_id, + text::ExprCallKw(TIR(dtype_str), {}, + {ffi::String("is_size_var")}, {text::LiteralAST::Bool(true)}), + ffi::Optional())); + } else { + frame->stmts.push_back( + text::AssignAST(var_id, text::ExprCall(TIR(dtype_str), {}), ffi::Optional())); + } + } + } + } + + // Step 3b. Print params (tirx::Vars in struct_info shapes are now defined) + // Enable stringify_vars for param annotations ONLY for top-level/standalone + // functions. Inner (nested) functions should use resolved var references + // since TIR vars are already in scope from the outer function (matching V1). + bool should_stringify = at_top_level || is_standalone; + g_printing_func_annotation = should_stringify; + ffi::List params; + for (int i = 0; i < n; ++i) { + Var var = func->params[i]; + text::AccessPath var_p = path->Attr("params")->ArrayItem(i); + printer->VarDef(var->vid->name_hint, var, frame); + text::ExprAST var_id = printer->VarGet(var).value(); + ffi::Optional annotation; + if (var->struct_info_.defined()) { + annotation = Print(printer, var->struct_info_.value(), + var_p->Attr("struct_info_")); + } + params.push_back(text::AssignAST(var_id, ffi::Optional(), annotation)); + } + g_printing_func_annotation = false; + + // Step 4. Print attributes (filter global_symbol when it matches func name) + // V1 filters global_symbol for top-level functions (both in-module and standalone) + // when the symbol matches the function name being used. + bool should_filter_global_symbol = has_global_symbol && + func->attrs->dict.at("global_symbol").cast() == func_name && + (at_top_level || is_standalone); + if (func->attrs.defined() && !func->attrs->dict.empty()) { + if (should_filter_global_symbol) { + // global_symbol matches func name: filter it out + ffi::Map filtered; + for (const auto& kv : func->attrs->dict) { + if (kv.first != "global_symbol") { + filtered.Set(kv.first, kv.second); + } + } + if (!filtered.empty()) { + frame->stmts.push_back( + text::ExprStmtAST(text::ExprCall(Relax("func_attr"), + {Print(printer, DictAttrs(filtered), + path->Attr("attrs"))}))); + } + } else { + frame->stmts.push_back( + text::ExprStmtAST(text::ExprCall(Relax("func_attr"), + {Print(printer, func->attrs, path->Attr("attrs"))}))); + } + } + + // Step 5. Print body: inline SeqExpr handling. + // We must NOT delegate to the SeqExpr Seq trait because the trait's + // PrintBody/PrintSeq pushes its own frame for the blocks. When a + // DataflowBlock's With trait FramePop removes vars defined inside + // the block, the subsequent `ret` resolution can no longer find the + // body Var, producing "Undefined variable: v". + // Instead, we print each block inline and resolve the return var + // while DataflowBlock vars are still in scope. + ffi::List body; + SeqExpr seq = Downcast(func->body); + for (int i = 0; i < static_cast(seq->blocks.size()); ++i) { + text::NodeAST block_ast = printer->operator()( + ffi::Any(seq->blocks[i]), + path->Attr("body")->Attr("blocks")->ArrayItem(i)).cast(); + if (auto* sb = block_ast.as()) { + for (const auto& s : sb->stmts) body.push_back(s); + } else if (block_ast->IsInstance()) { + body.push_back(Downcast(block_ast)); + } + } + // Resolve the return var BEFORE FramePop, while all block vars + // (including DataflowBlock vars) are still accessible. + text::ExprAST ret_expr = Print(printer, seq->body, path->Attr("body")->Attr("body")); + body.push_back(text::ReturnAST(ret_expr)); + + ffi::List all_body; + for (const auto& s : frame->stmts) all_body.push_back(s); + for (const auto& s : body) all_body.push_back(s); + + // Step 6. Print return type from FuncStructInfo->ret (matching V1) + // Must be done BEFORE FramePop so that TIR Vars defined in the + // function scope (e.g. N, N_1) are still accessible for name resolution. + // Only stringify vars for top-level/standalone functions (matching V1). + ffi::Optional ret_type; + if (const auto* func_sinfo = func->struct_info_.as()) { + g_printing_func_annotation = should_stringify; + ret_type = Print(printer, func_sinfo->ret, path->Attr("struct_info_")->Attr("ret")); + g_printing_func_annotation = false; + } + printer->FramePop(); + text::FunctionAST func_ast(name, params, decorators, ret_type, all_body); + + // When printing standalone (not in module context), add header comments + // matching V1's HeaderWrapper behavior for relax functions. + if (is_standalone && !at_top_level) { + ffi::List result; + result.push_back(text::CommentAST( + ffi::Optional(ffi::String("from tvm.script import tirx as T")))); + result.push_back(text::CommentAST( + ffi::Optional(ffi::String("from tvm.script import relax as R")))); + result.push_back(text::CommentAST(ffi::Optional())); + result.push_back(func_ast); + return text::StmtBlockAST(result); + } + return func_ast; + }); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + using namespace printer; + namespace refl = ::tvm::ffi::reflection; + namespace text = ::tvm::ffi::pyast; + // Tuple: (a, b, ...) or R.tuple() for empty + refl::TypeAttrDef().def( + "__ffi_text_print__", + [](Tuple node, text::IRPrinter printer, text::AccessPath path) -> text::NodeAST { + if (node->fields.empty()) { + return text::ExprCall(Relax("tuple"), {}); + } + ffi::List elts; + for (int i = 0; i < static_cast(node->fields.size()); ++i) { + elts.push_back(Print(printer, node->fields[i], path->Attr("fields")->ArrayItem(i))); + } + return text::TupleAST({}, std::move(elts)); + }); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + using namespace printer; + namespace refl = ::tvm::ffi::reflection; + namespace text = ::tvm::ffi::pyast; + // Constant: R.const(value, dtype) for scalars, metadata[...] placeholder for tensors + // For DTensorStructInfo: R.dist.const(value, struct_info_ann) + refl::TypeAttrDef().def( + "__ffi_text_print__", + [](Constant node, text::IRPrinter printer, text::AccessPath path) -> text::NodeAST { + if (ffi::Optional s = SpecialScalar(node->data, path->Attr("data"))) { + // Check if DTensorStructInfo -> use R.dist.const + if (node->struct_info_.defined() && + node->struct_info_.value() + .as()) { + text::ExprAST ann = Print(printer, node->struct_info_.value(), + path->Attr("struct_info_")); + return text::ExprCall(Relax("dist.const"), {s.value(), ann}); + } + return text::ExprCall(Relax("const"), + {s.value(), text::LiteralAST::Str(DType2Str(DataType(node->data->dtype)))}); + } + // Non-scalar: emit R.const(0, dtype) as a lossy placeholder. + // V2 does not have a metadata registry like V1, so we cannot + // faithfully round-trip non-scalar constants yet. + return text::ExprCall(Relax("const"), + {text::LiteralAST::Int(0), + text::LiteralAST::Str(DType2Str(DataType(node->data->dtype)))}); + }); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + using namespace printer; + namespace refl = ::tvm::ffi::reflection; + namespace text = ::tvm::ffi::pyast; + // relax.If: __ffi_text_print__ override + // Matches V1's PrintIfExpr: prints If with branches that use ExprStmt (not return). + // When used standalone (not inside VarBinding), the branches just have ExprStmts. + refl::TypeAttrDef().def( + "__ffi_text_print__", + [](If node, text::IRPrinter printer, text::AccessPath path) -> text::NodeAST { + text::ExprAST cond = Print(printer, node->cond, path->Attr("cond")); + ffi::List then_branch = PrintSeqExprBody( + node->true_branch, path->Attr("true_branch"), printer); + ffi::List else_branch = PrintSeqExprBody( + node->false_branch, path->Attr("false_branch"), printer); + return text::IfAST(cond, then_branch, else_branch); + }); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + using namespace printer; + namespace refl = ::tvm::ffi::reflection; + namespace text = ::tvm::ffi::pyast; + // relax.VarBinding: __ffi_text_print__ override + // For If values in VarBinding, the Assign trait drops the assignment (returns IfAST + // directly). To fix this, we register __ffi_text_print__ on VarBinding to intercept + // If values and produce IfAST with assignments in branches. For non-If values, + // we replicate the Assign trait behavior exactly. + refl::TypeAttrDef().def( + "__ffi_text_print__", + [](VarBinding node, text::IRPrinter printer, text::AccessPath path) -> text::NodeAST { + if (const auto* if_node = node->value.as()) { + // --- If case: produce IfAST with assignment in each branch --- + text::IdAST lhs = printer->VarDef(node->var->vid->name_hint, node->var, + ffi::Optional{}); + ffi::Optional ann; + if (node->var->struct_info_.defined()) { + ann = Print(printer, node->var->struct_info_.value(), + path->Attr("var")->Attr("struct_info_")); + } + text::ExprAST cond = Print(printer, if_node->cond, path->Attr("value")->Attr("cond")); + auto make_branch = [&](const SeqExpr& seq, const text::AccessPath& seq_path) + -> ffi::List { + ffi::List stmts; + for (int i = 0; i < static_cast(seq->blocks.size()); ++i) { + text::NodeAST block_ast = printer->operator()( + ffi::Any(seq->blocks[i]), + seq_path->Attr("blocks")->ArrayItem(i)).cast(); + if (auto* sb = block_ast.as()) { + for (const auto& s : sb->stmts) stmts.push_back(s); + } else if (block_ast->IsInstance()) { + stmts.push_back(Downcast(block_ast)); + } + } + text::ExprAST ret_expr = Print(printer, seq->body, seq_path->Attr("body")); + stmts.push_back(text::AssignAST(lhs, ret_expr, ann)); + return stmts; + }; + ffi::List then_branch = make_branch( + if_node->true_branch, path->Attr("value")->Attr("true_branch")); + ffi::List else_branch = make_branch( + if_node->false_branch, path->Attr("value")->Attr("false_branch")); + return text::IfAST(cond, then_branch, else_branch); + } + // --- Non-If case: replicate Assign trait behavior --- + // Define LHS variable + text::IdAST lhs = printer->VarDef(node->var->vid->name_hint, node->var, + ffi::Optional{}); + // Type annotation from var's struct_info_ (matching V1 behavior). + // V1 always emits struct_info annotations on intermediate relax vars. + ffi::Optional ann; + if (node->var->struct_info_.defined()) { + ann = Print(printer, node->var->struct_info_.value(), + path->Attr("var")->Attr("struct_info_")); + } + // Print RHS + ffi::Any rhs_result = printer->operator()(ffi::Any(node->value), path->Attr("value")); + text::NodeAST rhs_node = rhs_result.cast(); + // Handle Function RHS + if (auto* func = rhs_node.as()) { + return text::FunctionAST(lhs, func->args, func->decorators, func->return_type, func->body); + } + // Normal expression RHS: produce AssignAST with struct_info annotation + if (rhs_node->IsInstance()) { + return text::AssignAST(lhs, Downcast(rhs_node), ann); + } + // Statement-level RHS (StmtBlock, etc.): return directly + return rhs_node; + }); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + using namespace printer; + namespace refl = ::tvm::ffi::reflection; + namespace text = ::tvm::ffi::pyast; + // relax.MatchCast: __ffi_text_print__ override + // Prints: m = T.int64() ; n = T.int64() ; _: type_ann = R.match_cast(value, struct_info) + // This defines symbolic shape variables used later (e.g. n, m from R.Tensor((n,m), ...)). + refl::TypeAttrDef().def( + "__ffi_text_print__", + [](MatchCast node, text::IRPrinter printer, text::AccessPath path) -> text::NodeAST { + ffi::List stmts; + // Step 1. Collect TIR vars from the match struct_info and define them + { + std::vector tir_vars; + std::unordered_set seen; + CollectTIRVarsFromStructInfo(node->struct_info, &tir_vars, &seen); + for (const auto& tir_var : tir_vars) { + if (!printer->VarGet(tir_var).has_value()) { + printer->VarDef(tir_var->name_hint, tir_var, ffi::Optional{}); + text::ExprAST var_id = printer->VarGet(tir_var).value(); + std::string dtype_str = DType2Str(tir_var->dtype); + stmts.push_back( + text::AssignAST(var_id, text::ExprCall(TIR(dtype_str), {}), ffi::Optional())); + } + } + } + // Step 2. Build RHS: R.match_cast(value, struct_info) + text::ExprAST val = Print(printer, node->value, path->Attr("value")); + text::ExprAST si = Print(printer, node->struct_info, path->Attr("struct_info")); + text::ExprAST rhs = text::ExprCall(Relax("match_cast"), {val, si}); + // Step 3. Define LHS variable + text::IdAST lhs = printer->VarDef(node->var->vid->name_hint, node->var, + ffi::Optional{}); + // Type annotation from var's struct info + ffi::Optional ann; + if (node->var->struct_info_.defined()) { + ann = Print(printer, node->var->struct_info_.value(), + path->Attr("var")->Attr("struct_info_")); + } + stmts.push_back(text::AssignAST(lhs, rhs, ann)); + if (stmts.size() == 1) { + return stmts[0]; + } + return text::StmtBlockAST(std::move(stmts)); + }); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + using namespace printer; + namespace refl = ::tvm::ffi::reflection; + namespace text = ::tvm::ffi::pyast; + // relax.Call: __ffi_text_print__ override + // Maps Op-based calls to R.name(args...) syntax matching V1. + // Handles ExternFunc, Op, Var, and GlobalVar ops. + // Special cases: assert_op, hint_on_device, to_vdevice, call_tir family. + refl::TypeAttrDef().def( + "__ffi_text_print__", + [](Call call, text::IRPrinter printer, text::AccessPath path) -> text::NodeAST { + text::ExprAST prefix(ffi::UnsafeInit{}); + ffi::List args; + ffi::List kw_keys; + ffi::List kw_vals; + + // Determine the op name for special-case checks + std::string op_name; + if (const auto* op_node = call->op.as()) { + op_name = op_node->name; + } + + // Step 1. Determine callee prefix + if (const auto* ef = call->op.as()) { + prefix = Relax("call_packed"); + args.push_back(text::LiteralAST::Str(ef->global_symbol)); + } else if (const auto* op_node = call->op.as()) { + std::string name = op_node->name; + if (name.rfind("relax.", 0) == 0) { + prefix = Relax(name.substr(6)); + } else { + prefix = text::IdAST(name); + } + } else if (call->op->IsInstance()) { + prefix = Print(printer, call->op, path->Attr("op")); + } else if (call->op->IsInstance()) { + // Check if the GlobalVar has been registered (e.g. from module prologue) + if (auto bound = printer->VarGet(call->op)) { + prefix = bound.value(); + } else { + // Fallback: search by name_hint in obj2info (same as tirx Call printer) + GlobalVar op_gv = Downcast(call->op); + bool found = false; + for (const auto& kv : printer->obj2info) { + if (const auto* gv_node = kv.first.as()) { + if (gv_node->name_hint == op_gv->name_hint) { + prefix = kv.second->creator().cast(); + found = true; + break; + } + } + } + if (!found) { + prefix = Print(printer, call->op, path->Attr("op")); + } + } + } else { + prefix = Print(printer, call->op, path->Attr("op")); + } + + // ---- Special case: assert_op ---- + // V1 prints: R.assert_op(cond, *format_args, format=format_str) + // args[0]=cond, args[1]=format_str, args[2:]=format_args + if (op_name == "relax.assert_op" && call->args.size() >= 2) { + args.push_back(Print(printer, call->args[0], path->Attr("args")->ArrayItem(0))); + text::ExprAST format_str = Print(printer, call->args[1], path->Attr("args")->ArrayItem(1)); + for (int i = 2, n = call->args.size(); i < n; ++i) { + args.push_back(Print(printer, call->args[i], path->Attr("args")->ArrayItem(i))); + } + kw_keys.push_back(ffi::String("format")); + kw_vals.push_back(format_str); + return text::CallAST(prefix, std::move(args), std::move(kw_keys), std::move(kw_vals)); + } + + // ---- Special case: print ---- + // V1 prints: R.print(*format_args, format=format_str) + // args[0]=format_str, args[1:]=format_args + // The format string must be a keyword arg to avoid round-trip failure + // (otherwise it's interpreted as a positional arg and a new default + // format string is added). + if (op_name == "relax.print" && call->args.size() >= 1) { + text::ExprAST format_str = Print(printer, call->args[0], path->Attr("args")->ArrayItem(0)); + for (int i = 1, n = call->args.size(); i < n; ++i) { + args.push_back(Print(printer, call->args[i], path->Attr("args")->ArrayItem(i))); + } + kw_keys.push_back(ffi::String("format")); + kw_vals.push_back(format_str); + return text::CallAST(prefix, std::move(args), std::move(kw_keys), std::move(kw_vals)); + } + + // ---- Special case: hint_on_device ---- + // V1 prints: R.hint_on_device(expr, R.device(device_type=N, index=M), memory_scope) + if (op_name == "relax.hint_on_device") { + args.push_back(Print(printer, call->args[0], path->Attr("args")->ArrayItem(0))); + if (call->attrs.defined()) { + if (const auto* attrs = call->attrs.as()) { + ffi::List dev_keys; + ffi::List dev_vals; + dev_keys.push_back(ffi::String("device_type")); + dev_vals.push_back(text::LiteralAST::Int(attrs->device_type)); + dev_keys.push_back(ffi::String("index")); + dev_vals.push_back(text::LiteralAST::Int(attrs->index)); + args.push_back(text::ExprCallKw(Relax("device"), {}, std::move(dev_keys), + std::move(dev_vals))); + args.push_back(text::LiteralAST::Str(std::string(attrs->memory_scope))); + } + } + return text::CallAST(prefix, std::move(args), {}, {}); + } + + // ---- Special case: to_vdevice ---- + // V1 prints: R.to_vdevice(expr, dst_vdevice="kind:index:scope") + if (op_name == "relax.to_vdevice") { + args.push_back(Print(printer, call->args[0], path->Attr("args")->ArrayItem(0))); + if (call->attrs.defined()) { + if (const auto* attrs = call->attrs.as()) { + VDevice vdev = attrs->dst_vdevice; + kw_keys.push_back(ffi::String("dst_vdevice")); + // Use the pre-registered VDevice string from module.cc (kind:index:scope) + if (auto opt_str = printer->VarGet(vdev)) { + kw_vals.push_back(opt_str.value()); + } else { + // Fallback: compute from target info + std::string dev_kind = vdev->target.defined() + ? std::string(vdev->target->kind->name) + : "unknown"; + kw_vals.push_back(text::LiteralAST::Str( + dev_kind + ":" + std::to_string(vdev->vdevice_id) + ":" + + std::string(vdev->memory_scope))); + } + } + } + return text::CallAST(prefix, std::move(args), std::move(kw_keys), std::move(kw_vals)); + } + + // ---- Special case: call_tir family ---- + // V1 prints: R.call_tir(callee, input_tuple, out_sinfo=..., [tir_vars=...]) + // Only args[0] (callee) and args[1] (input tuple) are positional. + // args[2] (tir_vars) is a kwarg. out_sinfo comes from sinfo_args. + { + bool is_call_tir_family = (op_name == "relax.call_tir" || + op_name == "relax.call_tir_inplace" || + op_name == "relax.call_dps_packed" || + op_name == "relax.call_tir_with_grad" || + op_name == "relax.dist.call_tir_local_view"); + if (is_call_tir_family) { + // Positional: callee (args[0]) and input tuple (args[1]) + args.push_back(Print(printer, call->args[0], path->Attr("args")->ArrayItem(0))); + args.push_back(Print(printer, call->args[1], path->Attr("args")->ArrayItem(1))); + // out_sinfo from sinfo_args[0] + // Also detect if any out_sinfo is DTensorStructInfo to choose dist.call_tir + bool is_dtensor = false; + if (call->sinfo_args.size() > 0) { + StructInfo o_sinfo = Downcast(call->sinfo_args[0]); + text::AccessPath o_sinfo_p = path->Attr("sinfo_args")->ArrayItem(0); + kw_keys.push_back(ffi::String("out_sinfo")); + if (const auto* o = o_sinfo.as()) { + ffi::List fields; + text::AccessPath fields_p = o_sinfo_p->Attr("fields"); + for (int i = 0, l = o->fields.size(); i < l; ++i) { + if (o->fields[i].as()) { + is_dtensor = true; + } + fields.push_back(Print(printer, o->fields[i], fields_p->ArrayItem(i))); + } + kw_vals.push_back(text::ListAST({}, std::move(fields))); + } else { + if (o_sinfo.as()) { + is_dtensor = true; + } + kw_vals.push_back(Print(printer, o_sinfo, o_sinfo_p)); + } + } + // call_tir_inplace: inplace_indices kwarg + if (op_name == "relax.call_tir_inplace") { + if (const auto* attrs = call->attrs.as()) { + kw_keys.push_back(ffi::String("inplace_indices")); + ffi::List index_fields; + for (const auto& idx : attrs->inplace_indices) { + index_fields.push_back(text::LiteralAST::Int(idx.IntValue())); + } + kw_vals.push_back(text::ListAST({}, std::move(index_fields))); + } + } + // call_tir_with_grad: te_grad_name, te_grad_kwargs + if (op_name == "relax.call_tir_with_grad") { + if (const auto* attrs = call->attrs.as()) { + kw_keys.push_back(ffi::String("te_grad_name")); + kw_vals.push_back(text::LiteralAST::Str(std::string(attrs->te_grad_name))); + if (!attrs->te_grad_kwargs.empty()) { + kw_keys.push_back(ffi::String("te_grad_kwargs")); + kw_vals.push_back(Print(printer, attrs->te_grad_kwargs, + path->Attr("attrs")->Attr("te_grad_kwargs"))); + } + } + } + // tir_vars: args[2] as kwarg (not for call_dps_packed) + if (call->args.size() >= 3 && op_name != "relax.call_dps_packed") { + kw_keys.push_back(ffi::String("tir_vars")); + kw_vals.push_back(Print(printer, call->args[2], path->Attr("args")->ArrayItem(2))); + } + // Choose the right call variant: + // - dist.call_tir_local_view stays as is + // - call_tir_with_grad stays as is + // - call_dps_packed stays as is + // - call_tir_inplace stays as is + // - call_tir: if DTensor sinfo detected, use dist.call_tir instead + if (op_name == "relax.dist.call_tir_local_view") { + prefix = Relax("dist.call_tir_local_view"); + } else if (op_name == "relax.call_tir_with_grad") { + prefix = Relax("call_tir_with_grad"); + } else if (op_name == "relax.call_dps_packed") { + prefix = Relax("call_dps_packed"); + } else if (op_name == "relax.call_tir_inplace") { + prefix = Relax("call_tir_inplace"); + } else if (is_dtensor) { + prefix = Relax("dist.call_tir"); + } else { + prefix = Relax("call_tir"); + } + return text::CallAST(prefix, std::move(args), std::move(kw_keys), std::move(kw_vals)); + } + } + + // Step 2. Print args (non-special-case ops) + for (int i = 0, n = call->args.size(); i < n; ++i) { + args.push_back(Print(printer, call->args[i], path->Attr("args")->ArrayItem(i))); + } + + // Step 3. Print attrs as kwargs + if (call->attrs.defined()) { + if (call->op->IsInstance()) { + kw_keys.push_back(ffi::String("attrs_type_key")); + kw_vals.push_back(text::LiteralAST::Str(call->attrs->GetTypeKey())); + } + if (const auto* dict_attrs = call->attrs.as()) { + // Sort attrs by key for deterministic output + std::vector> sorted; + for (const auto& kv : dict_attrs->dict) { + sorted.push_back(kv); + } + std::sort(sorted.begin(), sorted.end(), + [](const auto& a, const auto& b) { return a.first < b.first; }); + for (const auto& kv : sorted) { + kw_keys.push_back(kv.first); + kw_vals.push_back( + Print(printer, kv.second, path->Attr("attrs")->Attr(kv.first))); + } + } else if (call->attrs.defined()) { + // Non-DictAttrs: use reflection to iterate fields + const TVMFFITypeInfo* info = TVMFFIGetTypeInfo(call->attrs->type_index()); + ffi::reflection::ForEachFieldInfo(info, [&](const TVMFFIFieldInfo* fi) { + ffi::String fname(fi->name.data, fi->name.size); + if (fname == "span") return; // skip span field + ffi::Any field_val = ffi::reflection::FieldGetter(fi)(call->attrs); + kw_keys.push_back(fname); + // Special-case DataType fields: the raw DLDataType value doesn't + // round-trip through the generic printer (it falls to "handle"). + // Convert to a string literal matching V1 behavior. + if (field_val.type_index() == ffi::TypeIndex::kTVMFFIDataType) { + DLDataType dt = field_val.cast(); + kw_vals.push_back(text::LiteralAST::Str(DType2Str(runtime::DataType(dt)))); + } else { + kw_vals.push_back( + Print(printer, std::move(field_val), path->Attr("attrs")->Attr(fname))); + } + }); + } + } + + // Step 4. Print sinfo_args + // (call_tir family already returned above, so no duplication here) + if (call->sinfo_args.size() > 0) { + text::AccessPath sinfo_p = path->Attr("sinfo_args"); + ffi::List sinfo_docs; + for (int i = 0, n = call->sinfo_args.size(); i < n; ++i) { + sinfo_docs.push_back(Print(printer, call->sinfo_args[i], sinfo_p->ArrayItem(i))); + } + kw_keys.push_back(ffi::String("sinfo_args")); + kw_vals.push_back(text::TupleAST({}, std::move(sinfo_docs))); + } + + if (!kw_keys.empty()) { + return text::CallAST(prefix, std::move(args), std::move(kw_keys), std::move(kw_vals)); + } + return text::CallAST(prefix, std::move(args), {}, {}); + }); +} + } // namespace relax } // namespace tvm diff --git a/src/relax/ir/script_print_utils.h b/src/relax/ir/script_print_utils.h new file mode 100644 index 000000000000..7e0bbab25bd3 --- /dev/null +++ b/src/relax/ir/script_print_utils.h @@ -0,0 +1,424 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_RELAX_IR_SCRIPT_PRINT_UTILS_H_ +#define TVM_RELAX_IR_SCRIPT_PRINT_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../../ir/printer_utils.h" + +namespace tvm { +namespace printer { + +using namespace relax; + +// Thread-local flag: true when printing function param/return annotations. +// Used by PrintShapeValue to decide whether to stringify symbolic TIR vars. +inline thread_local bool g_printing_func_annotation = false; + +/*! + * \brief Determine the function name from context. + * + * Priority: (1) GlobalVar from text::AccessPath MapItem key (module context), + * (2) global_symbol attribute, + * (3) fallback "main". + */ +inline ffi::String FindFuncName(const relax::Function& func, const text::IRPrinter& printer, + const text::AccessPath& path) { + // Priority 1: binding name from module context (set via VarDefNoName in IRModule prologue) + if (auto binding_expr = printer->VarGet(func)) { + if (const auto* id_node = binding_expr.value().as()) { + return id_node->name; + } + } + // Priority 2: GlobalVar from text::AccessPath MapItem key + if (path->step.defined()) { + const auto& step = path->step.value(); + if (step->kind == ffi::reflection::AccessKind::kMapItem) { + if (const auto* gv = step->key.as()) { + return gv->name_hint; + } + } + } + // Priority 3: global_symbol attribute + if (func->attrs.defined()) { + auto it = func->attrs->dict.find("global_symbol"); + if (it != func->attrs->dict.end()) { + return (*it).second.cast(); + } + } + return "main"; +} + +/*! + * \brief Check if this function is at top level in a module. + * + * A function is at the top level if it is reached via + * Root->Attr("functions")->MapItem(gv). + */ +inline bool AtTopLevelInModule(const text::AccessPath& path) { + // path should be MapItem(gv) + if (!path->step.defined()) return false; + if (path->step.value()->kind != ffi::reflection::AccessKind::kMapItem) return false; + // parent should be Attr("functions") + auto parent_opt = path->GetParent(); + if (!parent_opt.has_value()) return false; + text::AccessPath parent = parent_opt.value(); + if (!parent->step.defined()) return false; + if (parent->step.value()->kind != ffi::reflection::AccessKind::kAttr) return false; + return true; +} + +/*! + * \brief Collect all tirx::Var references from a PrimExpr (e.g. shape dims). + */ +inline void CollectTIRVarsFromPrimExpr(const PrimExpr& expr, + std::vector* out, + std::unordered_set* seen) { + tirx::PostOrderVisit(expr, [&](const ffi::ObjectRef& obj) { + if (const auto* tv = obj.as()) { + if (seen->insert(tv).second) { + out->push_back(ffi::GetRef(tv)); + } + } + }); +} + +/*! + * \brief Collect all tirx::Var references from a relax param's struct_info. + * + * Walks TensorStructInfo shapes, ShapeStructInfo values, and PrimStructInfo + * values to find symbolic dimension variables. + */ +inline void CollectTIRVarsFromStructInfo(const StructInfo& sinfo, + std::vector* out, + std::unordered_set* seen) { + if (const auto* tsi = sinfo.as()) { + if (tsi->shape.defined()) { + if (const auto* se = tsi->shape.value().as()) { + for (const auto& dim : se->values) { + CollectTIRVarsFromPrimExpr(dim, out, seen); + } + } + } + } else if (const auto* ssi = sinfo.as()) { + if (ssi->values.defined()) { + for (const auto& val : ssi->values.value()) { + CollectTIRVarsFromPrimExpr(val, out, seen); + } + } + } else if (const auto* psi = sinfo.as()) { + if (psi->value.defined()) { + CollectTIRVarsFromPrimExpr(psi->value.value(), out, seen); + } + } else if (const auto* tsi = sinfo.as()) { + for (const auto& field : tsi->fields) { + CollectTIRVarsFromStructInfo(field, out, seen); + } + } else if (const auto* fsi = sinfo.as()) { + if (fsi->params.defined()) { + for (const auto& param : fsi->params.value()) { + CollectTIRVarsFromStructInfo(param, out, seen); + } + } + CollectTIRVarsFromStructInfo(fsi->ret, out, seen); + } +} + +/*! + * \brief Print a PrimExpr for use in struct_info shape contexts. + * + * In V1, the "relax" dispatch for IntImm/FloatImm prints them as plain + * integer/float literals (not T.int64(10)). This matches that behavior. + */ +inline text::ExprAST PrintShapeValue(const PrimExpr& e, const text::AccessPath& e_p, const text::IRPrinter& printer, + bool stringify_vars = false) { + if (const auto* int_imm = e.as()) { + if (int_imm->dtype.is_bool()) { + return text::LiteralAST::Bool(int_imm->value != 0); + } + return text::LiteralAST::Int(int_imm->value); + } + if (const auto* float_imm = e.as()) { + return text::LiteralAST::Float(float_imm->value); + } + // For PrimExpr containing symbolic TIR Vars, stringify them (matching V1 PrintShapeVar). + // Only do this in param/return annotation contexts (g_printing_func_annotation). + if (stringify_vars || g_printing_func_annotation) { + bool has_tir_var = false; + tirx::PostOrderVisit(e, [&](const ffi::ObjectRef& obj) { + if (obj->IsInstance()) has_tir_var = true; + }); + if (has_tir_var) { + // Helper: get the defined name for a TIR var (uses VarGet for the + // printer-assigned name, which may differ from name_hint when there + // are naming collisions, e.g. two different Vars both named "N"). + auto get_var_name = [&](const tirx::VarNode* v) -> std::string { + tirx::Var var_ref = ffi::GetRef(v); + if (auto defined = printer->VarGet(var_ref)) { + if (const auto* id = defined.value().as()) { + return id->name; + } + } + return std::string(v->name_hint); + }; + // Simple Var: just use defined name as string + if (const auto* v = e.as()) { + return text::LiteralAST::Str(get_var_name(v)); + } + // Compound expressions (n * 2, etc.): build string from parts. + // Precedence-aware so (N+3)//4 renders correctly. + auto get_prec = [](const PrimExpr& expr) -> int { + if (expr.as()) return 4; + if (expr.as()) return 5; + if (expr.as()) return 6; + if (expr.as() || expr.as() || + expr.as() || expr.as() || + expr.as() || expr.as()) return 7; + if (expr.as() || expr.as()) return 12; + if (expr.as() || expr.as() || + expr.as()) return 13; + return 100; + }; + std::function stringify; + stringify = [&](const PrimExpr& expr, int parent_prec) -> std::string { + if (const auto* v = expr.as()) return get_var_name(v); + if (const auto* imm = expr.as()) return std::to_string(imm->value); + int my_prec = get_prec(expr); + std::string result; + if (const auto* add = expr.as()) { + result = stringify(add->a, my_prec) + " + " + stringify(add->b, my_prec + 1); + } else if (const auto* sub = expr.as()) { + result = stringify(sub->a, my_prec) + " - " + stringify(sub->b, my_prec + 1); + } else if (const auto* mul = expr.as()) { + result = stringify(mul->a, my_prec) + " * " + stringify(mul->b, my_prec + 1); + } else if (const auto* div = expr.as()) { + result = stringify(div->a, my_prec) + " // " + stringify(div->b, my_prec + 1); + } else if (const auto* mod = expr.as()) { + result = stringify(mod->a, my_prec) + " % " + stringify(mod->b, my_prec + 1); + } else if (const auto* mn = expr.as()) { + return "T.min(" + stringify(mn->a, 0) + ", " + stringify(mn->b, 0) + ")"; + } else if (const auto* mx = expr.as()) { + return "T.max(" + stringify(mx->a, 0) + ", " + stringify(mx->b, 0) + ")"; + } else if (const auto* cast = expr.as()) { + return stringify(cast->value, parent_prec); + } else { + std::ostringstream os; + os << expr; + return os.str(); + } + if (my_prec < parent_prec) { + return "(" + result + ")"; + } + return result; + }; + return text::LiteralAST::Str(stringify(e, 0)); + } + } + // Handle binary ops recursively to ensure child IntImm values print as + // plain literals (matching V1's "relax" dispatch behavior). + // Without this, children go through the generic trait printer which wraps + // int64 values in T.int64(). + using Op = text::OperationASTObj; + if (const auto* add = e.as()) { + return text::OperationAST(Op::kAdd, + {PrintShapeValue(add->a, e_p->Attr("a"), printer), + PrintShapeValue(add->b, e_p->Attr("b"), printer)}); + } + if (const auto* sub = e.as()) { + return text::OperationAST(Op::kSub, + {PrintShapeValue(sub->a, e_p->Attr("a"), printer), + PrintShapeValue(sub->b, e_p->Attr("b"), printer)}); + } + if (const auto* mul = e.as()) { + return text::OperationAST(Op::kMult, + {PrintShapeValue(mul->a, e_p->Attr("a"), printer), + PrintShapeValue(mul->b, e_p->Attr("b"), printer)}); + } + if (const auto* div = e.as()) { + return text::OperationAST(Op::kFloorDiv, + {PrintShapeValue(div->a, e_p->Attr("a"), printer), + PrintShapeValue(div->b, e_p->Attr("b"), printer)}); + } + if (const auto* mod = e.as()) { + return text::OperationAST(Op::kMod, + {PrintShapeValue(mod->a, e_p->Attr("a"), printer), + PrintShapeValue(mod->b, e_p->Attr("b"), printer)}); + } + if (const auto* mn = e.as()) { + return text::ExprCall(TIR("min"), + {PrintShapeValue(mn->a, e_p->Attr("a"), printer), + PrintShapeValue(mn->b, e_p->Attr("b"), printer)}); + } + if (const auto* mx = e.as()) { + return text::ExprCall(TIR("max"), + {PrintShapeValue(mx->a, e_p->Attr("a"), printer), + PrintShapeValue(mx->b, e_p->Attr("b"), printer)}); + } + // Comparison operators + if (const auto* eq = e.as()) { + return text::OperationAST(Op::kEq, + {PrintShapeValue(eq->a, e_p->Attr("a"), printer), + PrintShapeValue(eq->b, e_p->Attr("b"), printer)}); + } + if (const auto* ne = e.as()) { + return text::OperationAST(Op::kNotEq, + {PrintShapeValue(ne->a, e_p->Attr("a"), printer), + PrintShapeValue(ne->b, e_p->Attr("b"), printer)}); + } + if (const auto* lt = e.as()) { + return text::OperationAST(Op::kLt, + {PrintShapeValue(lt->a, e_p->Attr("a"), printer), + PrintShapeValue(lt->b, e_p->Attr("b"), printer)}); + } + if (const auto* le = e.as()) { + return text::OperationAST(Op::kLtE, + {PrintShapeValue(le->a, e_p->Attr("a"), printer), + PrintShapeValue(le->b, e_p->Attr("b"), printer)}); + } + if (const auto* gt = e.as()) { + return text::OperationAST(Op::kGt, + {PrintShapeValue(gt->a, e_p->Attr("a"), printer), + PrintShapeValue(gt->b, e_p->Attr("b"), printer)}); + } + if (const auto* ge = e.as()) { + return text::OperationAST(Op::kGtE, + {PrintShapeValue(ge->a, e_p->Attr("a"), printer), + PrintShapeValue(ge->b, e_p->Attr("b"), printer)}); + } + // Logical operators + if (const auto* and_n = e.as()) { + return text::OperationAST(Op::kAnd, + {PrintShapeValue(and_n->a, e_p->Attr("a"), printer), + PrintShapeValue(and_n->b, e_p->Attr("b"), printer)}); + } + if (const auto* or_n = e.as()) { + return text::OperationAST(Op::kOr, + {PrintShapeValue(or_n->a, e_p->Attr("a"), printer), + PrintShapeValue(or_n->b, e_p->Attr("b"), printer)}); + } + // Unary Not + if (const auto* not_n = e.as()) { + return text::OperationAST(Op::kNot, + {PrintShapeValue(not_n->a, e_p->Attr("a"), printer)}); + } + // For tirx::Var: print using the printer (which will resolve to the defined IdAST) + if (e->IsInstance()) { + return Print(printer, e, e_p); + } + // For other PrimExpr types, use the general printer + return Print(printer, e, e_p); +} + +/*! + * \brief Extract a scalar value from a 0-d CPU tensor as an text::ExprAST literal. + * + * Returns std::nullopt for non-scalar or non-CPU tensors, or unsupported dtypes. + * Matches the V1 SpecialScalar logic. + */ +inline ffi::Optional SpecialScalar(const runtime::Tensor& tensor, const text::AccessPath& p) { + DataType dtype(tensor->dtype); + const void* data = tensor->data; + if (tensor->ndim != 0 || tensor->device.device_type != kDLCPU) { + return std::nullopt; + } + if (dtype == DataType::Int(8)) { + return text::LiteralAST::Int(*reinterpret_cast(data)); + } else if (dtype == DataType::Int(16)) { + return text::LiteralAST::Int(*reinterpret_cast(data)); + } else if (dtype == DataType::Int(32)) { + return text::LiteralAST::Int(*reinterpret_cast(data)); + } else if (dtype == DataType::Int(64)) { + return text::LiteralAST::Int(*reinterpret_cast(data)); + } else if (dtype == DataType::Float(16)) { + uint16_t bits = *reinterpret_cast(data); + uint16_t sign_bit = (bits & 0x8000) >> 15; + uint16_t exponent = (bits & 0x7C00) >> 10; + uint16_t fraction = (bits & 0x03FF) >> 0; + double value; + if (exponent == 0x1F && fraction == 0) { + value = std::numeric_limits::infinity(); + } else if (exponent == 0x1F) { + value = std::numeric_limits::quiet_NaN(); + } else if (exponent == 0 && fraction == 0) { + value = 0.0; + } else if (exponent == 0) { + value = ::std::pow(2.0, -24) * static_cast(fraction); + } else { + value = ::std::pow(2.0, static_cast(exponent) - 25) * + static_cast(fraction | (1 << 10)); + } + if (sign_bit) { + value *= -1.0; + } + return text::LiteralAST::Float(value); + } else if (dtype == DataType::Float(32)) { + return text::LiteralAST::Float(*reinterpret_cast(data)); + } else if (dtype == DataType::Float(64)) { + return text::LiteralAST::Float(*reinterpret_cast(data)); + } else if (dtype == DataType::Bool()) { + return text::LiteralAST::Bool(*reinterpret_cast(data) != 0); + } else { + return std::nullopt; + } +} + +/*! + * \brief Print a SeqExpr body as a list of statements (no return). + * + * Used by the If and VarBinding printers to print SeqExpr branches + * matching V1's PrintSeqExpr(seq, path, d, use_ret=false). + */ +inline ffi::List PrintSeqExprBody(const relax::SeqExpr& seq, const text::AccessPath& seq_path, + const text::IRPrinter& printer) { + ffi::List stmts; + for (int i = 0; i < static_cast(seq->blocks.size()); ++i) { + text::NodeAST block_ast = printer->operator()( + ffi::Any(seq->blocks[i]), + seq_path->Attr("blocks")->ArrayItem(i)).cast(); + if (auto* sb = block_ast.as()) { + for (const auto& s : sb->stmts) stmts.push_back(s); + } else if (block_ast->IsInstance()) { + stmts.push_back(Downcast(block_ast)); + } + } + // Body is the last expression (printed as ExprStmt, not return) + text::ExprAST ret_expr = Print(printer, seq->body, seq_path->Attr("body")); + stmts.push_back(text::ExprStmtAST(ret_expr)); + return stmts; +} + +} // namespace printer +} // namespace tvm + +#endif // TVM_RELAX_IR_SCRIPT_PRINT_UTILS_H_ diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc index 434917bd0f94..4995d0246523 100644 --- a/src/relax/ir/struct_info.cc +++ b/src/relax/ir/struct_info.cc @@ -21,23 +21,109 @@ * \file src/relax/ir/struct_info.cc * \brief Relax struct info. */ +#include +#include +#include #include #include #include #include +#include #include +#include + +#include "script_print_utils.h" namespace tvm { namespace relax { TVM_FFI_STATIC_INIT_BLOCK() { StructInfoNode::RegisterReflection(); - ObjectStructInfoNode::RegisterReflection(); - PrimStructInfoNode::RegisterReflection(); - ShapeStructInfoNode::RegisterReflection(); TensorStructInfoNode::RegisterReflection(); TupleStructInfoNode::RegisterReflection(); FuncStructInfoNode::RegisterReflection(); + ObjectStructInfoNode::RegisterReflection(); + PrimStructInfoNode::RegisterReflection(); + ShapeStructInfoNode::RegisterReflection(); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = ::tvm::ffi::reflection; + refl::GlobalDef() + .def("relax._func_si_args", [](ffi::AnyView /*ctx*/, FuncStructInfo node) -> ffi::List { + ffi::List args; + if (!node->IsOpaque()) { + args.push_back(node->params.value()); + args.push_back(node->ret); + args.push_back(static_cast(node->purity)); + } + return args; + }) + .def("relax._func_si_kwargs", [](ffi::AnyView /*ctx*/, FuncStructInfo node) -> ffi::Dict { + ffi::Dict kwargs; + if (node->IsOpaque()) { + if (!node->ret->IsInstance()) { + kwargs.Set(ffi::String("ret"), node->ret); + } + if (node->purity) { + kwargs.Set(ffi::String("purity"), true); + } + } + return kwargs; + }) + .def("relax._empty_array", + [](ffi::AnyView /*ctx*/, ObjectStructInfo) -> ffi::Array { + return ffi::Array(); + }) + .def("relax._prim_si_args", [](ffi::AnyView /*ctx*/, PrimStructInfo node) -> ffi::List { + ffi::List args; + if (node->value.defined()) { + return args; + } + DataType dt = node->dtype; + if (dt.is_void()) { + return args; + } + ffi::String dtype_str = ffi::DLDataTypeToString(static_cast(dt)); + args.push_back(dtype_str); + return args; + }) + .def("relax._prim_si_kwargs", [](ffi::AnyView /*ctx*/, PrimStructInfo node) -> ffi::Dict { + ffi::Dict kwargs; + if (node->value.defined()) { + PrimExpr value = node->value.value(); + if (const auto* var = value.as()) { + kwargs.Set(ffi::String("value"), ffi::String(var->name_hint)); + } else { + kwargs.Set(ffi::String("value"), value); + } + } + return kwargs; + }) + .def("relax._shape_si_args", [](ffi::AnyView /*ctx*/, ShapeStructInfo node) -> ffi::List { + ffi::List args; + if (node->values.defined()) { + ffi::Array values = node->values.value(); + ffi::List shape_items; + for (int i = 0; i < static_cast(values.size()); ++i) { + PrimExpr v = values[i]; + if (const auto* int_imm = v.as()) { + shape_items.push_back(int_imm->value); + } else { + shape_items.push_back(v); + } + } + args.push_back(shape_items); + } + return args; + }) + .def("relax._shape_si_kwargs", [](ffi::AnyView /*ctx*/, ShapeStructInfo node) -> ffi::Dict { + ffi::Dict kwargs; + if (!node->values.defined()) { + kwargs.Set(ffi::String("ndim"), static_cast(node->ndim)); + } + return kwargs; + }); } ObjectStructInfo::ObjectStructInfo(Span span) { @@ -240,5 +326,109 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("ir.ExprStructInfo", [](Expr expr) { return GetStructInfo(expr); }); } +// ---- __ffi_text_print__ overrides ---- + +TVM_FFI_STATIC_INIT_BLOCK() { + using namespace printer; + namespace refl = ::tvm::ffi::reflection; + namespace text = ::tvm::ffi::pyast; + // PrimStructInfo: R.Prim(dtype) or R.Prim(value=...) + refl::TypeAttrDef().def( + "__ffi_text_print__", + [](PrimStructInfo n, text::IRPrinter printer, text::AccessPath path) -> text::NodeAST { + ffi::List args; + ffi::List kwargs_keys; + ffi::List kwargs_values; + if (n->value.defined()) { + kwargs_keys.push_back(ffi::String("value")); + kwargs_values.push_back(PrintShapeValue(n->value.value(), path->Attr("value"), printer, false)); + } else { + args.push_back(text::LiteralAST::Str(DType2Str(n->dtype))); + } + return text::ExprCallKw(Relax("Prim"), std::move(args), + std::move(kwargs_keys), std::move(kwargs_values)); + }); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + using namespace printer; + namespace refl = ::tvm::ffi::reflection; + namespace text = ::tvm::ffi::pyast; + // ShapeStructInfo: R.Shape([dims]) or R.Shape(ndim=N) + refl::TypeAttrDef().def( + "__ffi_text_print__", + [](ShapeStructInfo n, text::IRPrinter printer, text::AccessPath path) -> text::NodeAST { + if (n->values.defined()) { + ffi::Array shape = n->values.value(); + text::AccessPath shape_p = path->Attr("values"); + ffi::List shape_docs; + for (int i = 0, ndim = shape.size(); i < ndim; ++i) { + shape_docs.push_back(PrintShapeValue(shape[i], shape_p->ArrayItem(i), printer, false)); + } + return text::ExprCall(Relax("Shape"), {text::ListAST({}, std::move(shape_docs))}); + } + return text::ExprCallKw(Relax("Shape"), {}, + {ffi::String("ndim")}, {text::LiteralAST::Int(n->ndim)}); + }); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + using namespace printer; + namespace refl = ::tvm::ffi::reflection; + namespace text = ::tvm::ffi::pyast; + // TensorStructInfo: R.Tensor((shape,), dtype=...) etc. + refl::TypeAttrDef().def( + "__ffi_text_print__", + [](TensorStructInfo n, text::IRPrinter printer, text::AccessPath path) -> text::NodeAST { + ffi::List args; + ffi::List kwargs_keys; + ffi::List kwargs_values; + if (n->shape.defined()) { + // Dig into ShapeExpr to get individual dims + if (const auto* shape = n->shape.value().as()) { + auto shape_expr = ffi::GetRef(shape); + text::AccessPath shape_p = path->Attr("shape")->Attr("values"); + ffi::List shape_docs; + for (int i = 0, ndim = shape_expr->values.size(); i < ndim; ++i) { + shape_docs.push_back(PrintShapeValue(shape_expr->values[i], + shape_p->ArrayItem(i), printer, + /*stringify_vars=*/false)); + } + args.push_back(text::TupleAST({}, std::move(shape_docs))); + } else { + args.push_back(Print(printer, n->shape.value(), path->Attr("shape"))); + } + } + if (!n->IsUnknownDtype()) { + kwargs_keys.push_back(ffi::String("dtype")); + kwargs_values.push_back(text::LiteralAST::Str(DType2Str(n->dtype))); + } + if (!n->shape.defined() && !n->IsUnknownNdim()) { + kwargs_keys.push_back(ffi::String("ndim")); + kwargs_values.push_back(text::LiteralAST::Int(n->ndim)); + } + // vdevice (matching V1 logic) + if (n->vdevice.defined() && n->vdevice.value()->target.defined()) { + kwargs_keys.push_back(ffi::String("vdevice")); + VDevice vdev = n->vdevice.value(); + // Look up pre-computed "kind:kind_index:scope" from module.cc + if (auto opt = printer->VarGet(vdev)) { + kwargs_values.push_back(opt.value()); + } else { + // Fallback: use target kind name and vdevice_id + std::string dev_kind = vdev->target->kind->name; + kwargs_values.push_back(text::LiteralAST::Str( + dev_kind + ":" + std::to_string(vdev->vdevice_id) + ":" + + std::string(vdev->memory_scope))); + } + } + if (args.empty() && kwargs_keys.empty()) { + return Relax("Tensor"); + } + return text::ExprCallKw(Relax("Tensor"), std::move(args), + std::move(kwargs_keys), std::move(kwargs_values)); + }); +} + } // namespace relax } // namespace tvm diff --git a/src/target/target.cc b/src/target/target.cc index 2c093ee73fd1..068273521c3b 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -40,7 +40,14 @@ namespace tvm { -TVM_FFI_STATIC_INIT_BLOCK() { TargetNode::RegisterReflection(); } +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = ::tvm::ffi::reflection; + TargetNode::RegisterReflection(); + refl::GlobalDef().def("target._config", [](ffi::AnyView /*ctx*/, Target node) -> ffi::Array { + ffi::Map config = node->ToConfig(); + return {config}; + }); +} class TargetInternal { public: diff --git a/src/tirx/ir/expr.cc b/src/tirx/ir/expr.cc index f4130e70d6c7..d7e47d78f7d9 100644 --- a/src/tirx/ir/expr.cc +++ b/src/tirx/ir/expr.cc @@ -21,22 +21,31 @@ * \file expr.cc */ #include +#include +#include #include +#include #include #include +#include #include +#include #include #include #include "../../arith/scalable_expression.h" +#include "../../ir/printer_utils.h" #include "../../support/str_escape.h" #include "buffer_common.h" +#include "script_print_utils.h" namespace tvm { namespace tirx { TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = ::tvm::ffi::reflection; + VarNode::RegisterReflection(); SizeVarNode::RegisterReflection(); IterVarNode::RegisterReflection(); @@ -70,8 +79,264 @@ TVM_FFI_STATIC_INIT_BLOCK() { ShuffleNode::RegisterReflection(); CommReducerNode::RegisterReflection(); ReduceNode::RegisterReflection(); + + refl::GlobalDef() + .def("tirx._tir_call_callee", + [](ffi::pyast::IRPrinter printer, tirx::Call call) -> ffi::Any { + if (!call->op->IsInstance()) { + return ffi::Any(); + } + namespace text = ::tvm::ffi::pyast; + if (auto gv_doc = printer->VarGet(call->op)) { + return ffi::Any(gv_doc.value()); + } + GlobalVar op_gv = Downcast(call->op); + for (const auto& kv : printer->obj2info) { + if (const auto* gv_node = kv.first.as()) { + if (gv_node->name_hint == op_gv->name_hint) { + return ffi::Any(kv.second->creator().cast()); + } + } + } + return printer->operator()(ffi::Any(call->op), + text::AccessPath::Root()->Attr("op")); + }); + + // Global function definitions for all computed methods + refl::GlobalDef() + // VarNode / SizeVarNode type annotation helpers + .def("tirx._var_type_or_null", [](ffi::AnyView /*ctx*/, Var node) -> ffi::Optional { + if (!node->type_annotation.defined()) return ffi::Optional(); + if (const auto* tt = node->type_annotation.as()) { + if (tt->fields.empty()) return ffi::Optional(); + } + return ffi::Optional(node->type_annotation); + }) + .def("tirx._sizevar_type_or_null", [](ffi::AnyView /*ctx*/, SizeVar node) -> ffi::Optional { + if (!node->type_annotation.defined()) return ffi::Optional(); + if (const auto* tt = node->type_annotation.as()) { + if (tt->fields.empty()) return ffi::Optional(); + } + return ffi::Optional(node->type_annotation); + }) + // IterVar args + .def("tirx._iter_var_args", [](ffi::AnyView /*ctx*/, IterVar node) -> ffi::Array { + const char* type_str; + switch (static_cast(node->iter_type)) { + case kDataPar: type_str = "DataPar"; break; + case kThreadIndex: type_str = "ThreadIndex"; break; + case kCommReduce: type_str = "CommReduce"; break; + case kOrdered: type_str = "Ordered"; break; + case kOpaque: type_str = "DimInfo"; break; + default: type_str = "Unrolled"; break; + } + ffi::Array result; + result.push_back(node->var); + result.push_back(node->dom); + result.push_back(StringImm(ffi::String(type_str))); + if (!node->thread_tag.empty()) { + result.push_back(StringImm(ffi::String(std::string(node->thread_tag)))); + } + return result; + }) + // Cast args + .def("tirx._cast_args", [](ffi::AnyView /*ctx*/, Cast node) -> ffi::Array { + StringImm dtype_str(ffi::DLDataTypeToString(node->dtype)); + return {dtype_str, node->value}; + }) + // BinOp sugar checks: verify that re-constructing via the sugar function + // produces the same node (i.e. the sugar round-trips). +#define TVM_TIRX_BINOP_SUGAR(lower, NodeTy, sugar_fn) \ + .def("tirx._" #lower "_sugar", [](ffi::AnyView /*ctx*/, tirx::NodeTy node) -> bool { \ + PrimExpr ret = sugar_fn(node->a, node->b); \ + if (const auto* p = ret.as()) { \ + return p->a.same_as(node->a) && p->b.same_as(node->b); \ + } \ + return false; \ + }) + TVM_TIRX_BINOP_SUGAR(add, Add, tvm::add) + TVM_TIRX_BINOP_SUGAR(sub, Sub, tvm::sub) + TVM_TIRX_BINOP_SUGAR(mul, Mul, tvm::mul) + TVM_TIRX_BINOP_SUGAR(floordiv, FloorDiv, tvm::floordiv) + TVM_TIRX_BINOP_SUGAR(floormod, FloorMod, tvm::floormod) + TVM_TIRX_BINOP_SUGAR(eq, EQ, tvm::equal) + TVM_TIRX_BINOP_SUGAR(ne, NE, tvm::not_equal) + TVM_TIRX_BINOP_SUGAR(lt, LT, tvm::less) + TVM_TIRX_BINOP_SUGAR(le, LE, tvm::less_equal) + TVM_TIRX_BINOP_SUGAR(gt, GT, tvm::greater) + TVM_TIRX_BINOP_SUGAR(ge, GE, tvm::greater_equal) + TVM_TIRX_BINOP_SUGAR(and, And, tvm::logical_and) + TVM_TIRX_BINOP_SUGAR(or, Or, tvm::logical_or) +#undef TVM_TIRX_BINOP_SUGAR + // Div sugar is special: also rejects integer-typed operands + .def("tirx._div_sugar", [](ffi::AnyView /*ctx*/, tirx::Div node) -> bool { + PrimExpr ret = tvm::div(node->a, node->b); + if (!ret->IsInstance()) return false; + if ((node->a->dtype.is_int() || node->a->dtype.is_uint()) && + (node->b->dtype.is_int() || node->b->dtype.is_uint())) { + return false; + } + return true; + }) + // Select args + .def("tirx._select_args", [](ffi::AnyView /*ctx*/, Select node) -> ffi::Array { + return {node->condition, node->true_value, node->false_value}; + }) + // Ramp args + .def("tirx._ramp_args", [](ffi::AnyView /*ctx*/, Ramp node) -> ffi::Array { + return {node->base, node->stride, node->lanes}; + }) + // Broadcast args + .def("tirx._broadcast_args", [](ffi::AnyView /*ctx*/, Broadcast node) -> ffi::Array { + return {node->value, node->lanes}; + }) + // Call callee and args + .def("tirx._call_callee", [](ffi::AnyView /*ctx*/, tirx::Call node) -> ffi::String { + if (auto* op = node->op.as()) { + static const OpAttrMap op_names = + Op::GetAttrMap("TScriptPrinterName"); + Op op_ref = ffi::GetRef(op); + if (op_names.count(op_ref)) { + return ffi::String("T." + std::string(op_names[op_ref])); + } + std::string full(op->name); + auto pos = full.rfind('.'); + return ffi::String("T." + ((pos != std::string::npos) ? full.substr(pos + 1) : full)); + } + return ffi::String("T.call"); + }) + .def("tirx._call_args", [](ffi::AnyView /*ctx*/, tirx::Call node) -> ffi::Array { + ffi::Array result; + int print_location = static_cast(ScriptDtypePrintLocation::kNone); + if (auto* op = node->op.as()) { + static const OpAttrMap dtype_locations = + Op::GetAttrMap("TScriptDtypePrintLocation"); + Op op_ref = ffi::GetRef(op); + if (dtype_locations.count(op_ref)) { + print_location = dtype_locations[op_ref].IntValue(); + } + } + std::string dtype_str = node->dtype.is_void() ? "void" + : ffi::DLDataTypeToString(node->dtype); + bool is_llvm_intrin = false; + if (auto* op = node->op.as()) { + static const OpAttrMap op_names = + Op::GetAttrMap("TScriptPrinterName"); + Op op_ref = ffi::GetRef(op); + if (op_names.count(op_ref)) { + std::string name(op_names[op_ref]); + is_llvm_intrin = (name == "call_llvm_pure_intrin" || name == "call_llvm_intrin"); + } + } + if (print_location == static_cast(ScriptDtypePrintLocation::kFirst)) { + result.push_back(StringImm(ffi::String(dtype_str))); + } + for (int i = 0; i < static_cast(node->args.size()); ++i) { + if (i == 0 && is_llvm_intrin) { + auto f_lookup = ffi::Function::GetGlobal("target.llvm_get_intrinsic_name"); + if (f_lookup.has_value() && node->args[0].as()) { + int64_t id = node->args[0].as()->value; + ffi::Any ret; + ffi::AnyView args_view[1] = {ffi::AnyView(id)}; + f_lookup.value().CallPacked(args_view, 1, &ret); + ffi::String name = ret.cast(); + result.push_back(StringImm(name)); + } else { + result.push_back(node->args[i]); + } + } else { + result.push_back(node->args[i]); + } + } + if (print_location == static_cast(ScriptDtypePrintLocation::kLast)) { + result.push_back(StringImm(ffi::String(dtype_str))); + } + return result; + }) + // Shuffle args + .def("tirx._shuffle_args", [](ffi::AnyView /*ctx*/, Shuffle node) -> ffi::Array { + ffi::Array result; + result.push_back(node->vectors); + result.push_back(node->indices); + return result; + }) + // Reduce positional and kwargs + .def("tirx._reduce_positional", [](ffi::AnyView /*ctx*/, Reduce node) -> ffi::Array { + return {node->combiner}; + }) + .def("tirx._reduce_kwargs", [](ffi::AnyView /*ctx*/, Reduce node) -> ffi::Map { + ffi::Map result; + result.Set(ffi::String("source"), node->source); + result.Set(ffi::String("init"), node->init); + result.Set(ffi::String("axis"), node->axis); + result.Set(ffi::String("condition"), node->condition); + result.Set(ffi::String("value_index"), IntImm(DataType::Int(32), node->value_index)); + return result; + }) + // BufferLoad indices: convert Ramp(base, stride, lanes) → [start, stop, step?] + // Skip conversion when predicate is set (vload fallback needs raw indices) + .def("tirx._load_indices", [](ffi::AnyView /*ctx*/, BufferLoad node) -> ffi::Array { + if (node->predicate.defined()) return node->indices; + ffi::Array result; + for (const auto& idx : node->indices) { + if (const auto* ramp = idx.as()) { + if (ramp->stride.as()) { + ffi::Array slice; + slice.push_back(ramp->base); + slice.push_back(ramp->base + ramp->lanes * ramp->stride); + if (!is_one(ramp->stride)) { + slice.push_back(ramp->stride); + } + result.push_back(slice); + continue; + } + } + result.push_back(idx); + } + return result; + }) + // BufferStore indices: same Ramp→slice conversion + // Skip conversion when predicate is set (vstore fallback needs raw indices) + .def("tirx._store_indices", [](ffi::AnyView /*ctx*/, BufferStore node) -> ffi::Array { + if (node->predicate.defined()) return node->indices; + ffi::Array result; + for (const auto& idx : node->indices) { + if (const auto* ramp = idx.as()) { + if (ramp->stride.as()) { + ffi::Array slice; + slice.push_back(ramp->base); + slice.push_back(ramp->base + ramp->lanes * ramp->stride); + if (!is_one(ramp->stride)) { + slice.push_back(ramp->stride); + } + result.push_back(slice); + continue; + } + } + result.push_back(idx); + } + return result; + }) + // BufferRegion indices + .def("tirx._buf_region_indices", [](ffi::AnyView /*ctx*/, BufferRegion node) -> ffi::Array { + ffi::Array result; + for (const auto& r : node->region) { + if (is_one(r->extent)) { + // Single-point access: plain index instead of slice + result.push_back(r->min); + } else { + // Range access: [min, min + extent] → slice + ffi::Array pair; + pair.push_back(r->min); + pair.push_back(r->min + r->extent); + result.push_back(pair); + } + } + return result; + }); } + /* \brief Convert an object to a PrimExpr * * All conversions to a PrimExpr are performed as part of the FFI, @@ -899,5 +1164,94 @@ TVM_FFI_STATIC_INIT_BLOCK() { }); } +// --------------------------------------------------------------------------- +// __ffi_text_print__ overrides +// --------------------------------------------------------------------------- + +TVM_FFI_STATIC_INIT_BLOCK() { + using namespace printer; + + // CommReducer: lambda construction -- genuinely irreducible + refl::TypeAttrDef().def( + "__ffi_text_print__", + [](CommReducer node, text::IRPrinter printer, text::AccessPath path) -> text::NodeAST { + using namespace printer; + ffi::List lhs_vars, rhs_vars, results; + for (int i = 0; i < static_cast(node->lhs.size()); ++i) + lhs_vars.push_back(Print(printer, node->lhs[i], path->Attr("lhs")->ArrayItem(i))); + for (int i = 0; i < static_cast(node->rhs.size()); ++i) + rhs_vars.push_back(Print(printer, node->rhs[i], path->Attr("rhs")->ArrayItem(i))); + for (int i = 0; i < static_cast(node->result.size()); ++i) + results.push_back(Print(printer, node->result[i], path->Attr("result")->ArrayItem(i))); + ffi::List params; + params.insert(params.end(), lhs_vars.begin(), lhs_vars.end()); + params.insert(params.end(), rhs_vars.begin(), rhs_vars.end()); + text::ExprAST lambda_body = (results.size() == 1) ? results[0] : text::TupleAST({}, results); + text::LambdaAST lambda_ast({}, params, lambda_body); + return text::ExprCall(TIR("comm_reducer"), + {lambda_ast, printer->PrintList(node->identity_element, + path->Attr("identity_element"))}); + }); + + // IndexMap: T.index_map(lambda vars: (exprs...), [inverse_index_map=...]) + refl::TypeAttrDef().def( + "__ffi_text_print__", + [](IndexMap node, text::IRPrinter printer, text::AccessPath path) -> text::NodeAST { + using namespace printer; + ffi::List params; + for (int i = 0; i < static_cast(node->initial_indices.size()); ++i) { + params.push_back( + Print(printer, node->initial_indices[i], + path->Attr("initial_indices")->ArrayItem(i))); + } + ffi::List exprs; + for (int i = 0; i < static_cast(node->final_indices.size()); ++i) { + exprs.push_back( + Print(printer, node->final_indices[i], + path->Attr("final_indices")->ArrayItem(i))); + } + text::ExprAST body = (exprs.size() == 1) ? exprs[0] : text::TupleAST({}, std::move(exprs)); + text::LambdaAST lambda_ast({}, std::move(params), body); + if (node->inverse_index_map.defined()) { + IndexMap inv = Downcast(node->inverse_index_map); + ffi::List inv_params; + for (int i = 0; i < static_cast(inv->initial_indices.size()); ++i) { + inv_params.push_back( + Print(printer, inv->initial_indices[i], + path->Attr("inverse_index_map")->Attr("initial_indices")->ArrayItem(i))); + } + ffi::List inv_exprs; + for (int i = 0; i < static_cast(inv->final_indices.size()); ++i) { + inv_exprs.push_back( + Print(printer, inv->final_indices[i], + path->Attr("inverse_index_map")->Attr("final_indices")->ArrayItem(i))); + } + text::ExprAST inv_body = (inv_exprs.size() == 1) ? inv_exprs[0] + : text::TupleAST({}, std::move(inv_exprs)); + text::LambdaAST inv_lambda({}, std::move(inv_params), inv_body); + return text::ExprCallKw(TIR("index_map"), {lambda_ast}, + {ffi::String("inverse_index_map")}, {inv_lambda}); + } + return text::ExprCall(TIR("index_map"), {lambda_ast}); + }); + + // Let: T.Let(body, where={var: value}) + refl::TypeAttrDef().def( + "__ffi_text_print__", + [](Let node, text::IRPrinter printer, text::AccessPath path) -> text::NodeAST { + using namespace printer; + if (!printer->VarGet(node->var).has_value() && !printer->frames.empty()) { + text::DefaultFrame frame = printer->frames.back().cast(); + DefineNewTIRVar(node->var, printer, frame); + } + text::ExprAST body_doc = Print(printer, node->body, path->Attr("body")); + text::ExprAST var_doc = Print(printer, node->var, path->Attr("var")); + text::ExprAST val_doc = Print(printer, node->value, path->Attr("value")); + text::DictAST where_dict({var_doc}, {val_doc}); + return text::ExprCallKw(TIR("Let"), {body_doc}, + {ffi::String("where")}, {where_dict}); + }); +} + } // namespace tirx } // namespace tvm diff --git a/src/tirx/ir/function.cc b/src/tirx/ir/function.cc index b817ccfd7328..906738c16ed4 100644 --- a/src/tirx/ir/function.cc +++ b/src/tirx/ir/function.cc @@ -28,6 +28,12 @@ #include #include +#include +#include + +#include "../../ir/printer_utils.h" +#include "script_print_utils.h" + namespace tvm { namespace tirx { @@ -174,5 +180,356 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("tirx.TensorIntrinGet", TensorIntrin::Get); } +// --------------------------------------------------------------------------- +// __ffi_text_print__ override +// --------------------------------------------------------------------------- + +TVM_FFI_STATIC_INIT_BLOCK() { + using namespace printer; + + // PrimFunc -> @T.prim_func \n def name(params...): body + refl::TypeAttrDef().def( + "__ffi_text_print__", + [](PrimFunc func, text::IRPrinter printer, text::AccessPath path) -> text::NodeAST { + using namespace printer; + // Determine function name + ffi::String func_name = "main"; + bool in_module = printer->VarIsDefined(func); + if (in_module) { + if (auto binding_expr = printer->VarGet(func)) { + if (auto* id_node = binding_expr.value().as<::tvm::ffi::pyast::IdASTObj>()) { + func_name = id_node->name; + } + } + } else if (func->attrs.defined()) { + auto it = func->attrs->dict.find("global_symbol"); + if (it != func->attrs->dict.end()) { + func_name = (*it).second.cast(); + } + } + text::IdAST name = text::IdAST(func_name); + + // Build decorator + ffi::List decorators; + bool has_global_symbol = func->attrs.defined() && + func->attrs->dict.count("global_symbol"); + if (!has_global_symbol) { + decorators.push_back( + text::ExprCallKw(TIR("prim_func"), {}, {ffi::String("private")}, {text::LiteralAST::Bool(true)})); + } else { + decorators.push_back(TIR("prim_func")); + } + + // Push frame + text::DefaultFrame frame; + printer->FramePush(frame); + + // Pre-compute buffer_data_counter + int n_args = func->params.size(); + std::unordered_map buffer_data_counter; + for (const auto& pair : func->buffer_map) { + const tirx::VarNode* data_var = pair.second->data.get(); + if (!buffer_data_counter.count(data_var)) { + buffer_data_counter.insert({data_var, 0}); + } + ++buffer_data_counter.at(data_var); + } + + // Step 1. Handle params with buffer inlining + ffi::List params; + std::unordered_set buffer_inlined; + + for (int i = 0; i < n_args; ++i) { + Var var = func->params[i]; + text::AccessPath var_p = path->Attr("params")->ArrayItem(i); + + if (CountVarOccurrence(func, var) == 2 && func->buffer_map.count(var)) { + Buffer buffer = func->buffer_map[var]; + if (IsSimpleBuffer(buffer) && buffer_data_counter.at(buffer->data.get()) == 1) { + text::AccessPath buffer_p = path->Attr("buffer_map")->MapItem(var); + printer->VarDef(buffer->name, buffer, frame); + DefineBufferDataVar(buffer, printer); + text::ExprAST buf_id = printer->VarGet(buffer).value(); + text::ExprAST annotation = PrintBufferAnnotation(buffer, printer, buffer_p); + params.push_back(text::AssignAST(buf_id, ffi::Optional(), + ffi::Optional(annotation))); + buffer_inlined.insert(buffer.get()); + continue; + } + } + + text::ExprAST var_id = DefineVar(var, printer, var_p); + ffi::Optional annotation; + if (var->type_annotation.defined()) { + annotation = Print(printer, var->type_annotation, var_p->Attr("type_annotation")); + } + params.push_back( + text::AssignAST(var_id, ffi::Optional(), annotation)); + } + + // Step 2. Handle func->attrs + auto PrintAttrDict = [&](const ffi::Map& dict, + const text::AccessPath& dict_p) -> text::ExprAST { + ffi::List keys; + ffi::List vals; + for (const auto& kv : dict) { + keys.push_back(text::LiteralAST::Str(kv.first)); + vals.push_back(Print(printer, kv.second, dict_p)); + } + return text::DictAST(std::move(keys), std::move(vals)); + }; + + if (func->attrs.defined() && !func->attrs->dict.empty()) { + if (func->attrs->dict.count("global_symbol") && + func->attrs->dict.at("global_symbol").cast() == func_name) { + ffi::Map new_attrs; + for (const auto& kv : func->attrs->dict) { + if (kv.first != "global_symbol") { + new_attrs.Set(kv.first, kv.second); + } + } + if (!new_attrs.empty()) { + text::ExprAST attr_dict = PrintAttrDict(new_attrs, path->Attr("attrs")); + frame->stmts.push_back( + text::ExprStmtAST(text::ExprCall(TIR("func_attr"), {attr_dict}))); + } + } else { + text::ExprAST attr_dict = PrintAttrDict(func->attrs->dict, path->Attr("attrs")); + frame->stmts.push_back( + text::ExprStmtAST(text::ExprCall(TIR("func_attr"), {attr_dict}))); + } + } + + // Step 3. Handle buffer_map: non-inlined entries + for (int i = 0; i < n_args; ++i) { + Var param = func->params[i]; + if (func->buffer_map.count(param)) { + Buffer buffer = func->buffer_map[param]; + if (buffer_inlined.count(buffer.get())) continue; + DefineBufferVars(buffer, printer, frame); + text::AccessPath buffer_p = path->Attr("buffer_map")->MapItem(param); + printer->VarDef(buffer->name, buffer, frame); + text::ExprAST buf_id = printer->VarGet(buffer).value(); + text::ExprAST param_doc = params[i]->lhs; + ffi::List extra_args; + extra_args.push_back(param_doc); + text::ExprAST rhs = PrintBufferDecl(buffer, "match_buffer", std::move(extra_args), + printer, buffer_p); + DefineBufferDataVar(buffer, printer); + frame->stmts.push_back(text::AssignAST(buf_id, rhs, ffi::Optional())); + } + } + + // Step 3b. Emit declarations for undefined Vars + { + struct ThreadVarInfo { + std::vector iter_vars; + }; + std::unordered_map thread_var_info; + std::unordered_set thread_vars; + { + class ThreadVarCollector : public tirx::StmtVisitor { + public: + std::unordered_map* info; + void VisitStmt_(const tirx::AttrStmtNode* op) final { + if ((op->attr_key == "thread_extent" || op->attr_key == "virtual_thread") && + (op->node.type_index() >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) && + op->node.cast()->IsInstance()) { + IterVar iv = op->node.cast(); + std::string key = iv->thread_tag; + (*info)[key].iter_vars.push_back(iv); + } + tirx::StmtVisitor::VisitStmt_(op); + } + }; + ThreadVarCollector collector; + collector.info = &thread_var_info; + Stmt body_to_walk = func->body; + { + const SBlockRealizeNode* root_br = func->body.as(); + if (root_br && !root_br->iter_values.size() && is_one(root_br->predicate)) { + SBlock rb = root_br->block; + if (!rb->annotations.size() && !rb->match_buffers.size() && + !rb->reads.size() && !rb->writes.size() && !rb->init.defined()) { + const SBlockRealizeNode* inner_br = rb->body.as(); + if (rb->alloc_buffers.size() || + (inner_br && inner_br->block->iter_vars.size()) || + (!inner_br && ContainsNode(rb->body))) { + body_to_walk = rb->body; + } + } + } + } + collector(body_to_walk); + + for (const auto& kv : thread_var_info) { + const std::vector& ivs = kv.second.iter_vars; + std::unordered_map var_ptr_count; + for (const IterVar& iv : ivs) { + thread_vars.insert(iv->var.get()); + var_ptr_count[iv->var.get()]++; + } + for (const auto& vpc : var_ptr_count) { + if (vpc.second > 1) { + for (const IterVar& iv : ivs) { + if (iv->var.get() == vpc.first) { + DefineVar(iv->var, printer, text::AccessPath::Root()); + text::ExprAST var_id = printer->VarGet(iv->var).value(); + text::ExprAST rhs = text::ExprCall(TIR("env_thread"), + {text::LiteralAST::Str(iv->thread_tag)}); + frame->stmts.push_back(text::AssignAST(var_id, rhs, ffi::Optional())); + break; + } + } + } + } + } + } + + ffi::Array defined_vars; + for (const auto& param : func->params) { + defined_vars.push_back(param); + } + for (const auto& pair : func->buffer_map) { + Buffer buf = pair.second; + defined_vars.push_back(buf->data); + for (const PrimExpr& s : buf->shape) { + if (const auto* v = s.as()) { + defined_vars.push_back(ffi::GetRef(v)); + } + } + for (const PrimExpr& s : buf->strides) { + if (const auto* v = s.as()) { + defined_vars.push_back(ffi::GetRef(v)); + } + } + if (const auto* v = buf->elem_offset.as()) { + defined_vars.push_back(ffi::GetRef(v)); + } + } + { + class SBlockVarCollector : public tirx::StmtVisitor { + public: + ffi::Array* vars; + void VisitStmt_(const tirx::SBlockNode* op) final { + for (const IterVar& iv : op->iter_vars) { + vars->push_back(iv->var); + } + for (const tirx::Buffer& buf : op->alloc_buffers) { + vars->push_back(buf->data); + } + for (const tirx::MatchBufferRegion& mb : op->match_buffers) { + tirx::Buffer buf = mb->buffer; + vars->push_back(buf->data); + for (const PrimExpr& s : buf->shape) { + if (const auto* v = s.as()) { + vars->push_back(ffi::GetRef(v)); + } + } + for (const PrimExpr& s : buf->strides) { + if (const auto* v = s.as()) { + vars->push_back(ffi::GetRef(v)); + } + } + if (const auto* v = buf->elem_offset.as()) { + vars->push_back(ffi::GetRef(v)); + } + } + tirx::StmtVisitor::VisitStmt_(op); + } + }; + SBlockVarCollector collector; + collector.vars = &defined_vars; + collector(func->body); + } + Stmt body_to_scan = func->body; + ffi::Array undef = tirx::UndefinedVars(body_to_scan, defined_vars); + std::unordered_set seen; + for (const Var& v : undef) { + if (seen.count(v.get())) continue; + seen.insert(v.get()); + if (thread_vars.count(v.get())) continue; + if (!printer->VarGet(v).has_value()) { + DefineNewTIRVar(v, printer, frame); + } + } + } + + // Step 4. Handle func->body with implicit root block detection + ffi::Optional implicit_root_block; + { + const SBlockRealizeNode* root_block_realize = + func->body.as(); + if (root_block_realize && !root_block_realize->iter_values.size() && + is_one(root_block_realize->predicate)) { + SBlock root_block = root_block_realize->block; + if (!root_block->annotations.size() && !root_block->match_buffers.size() && + !root_block->reads.size() && !root_block->writes.size() && + !root_block->init.defined()) { + const SBlockRealizeNode* block_realize = + root_block->body.as(); + if (root_block->alloc_buffers.size() || + (block_realize && block_realize->block->iter_vars.size()) || + (!block_realize && + ContainsNode(root_block->body))) { + implicit_root_block = root_block; + } + } + } + } + + ffi::List body_stmts; + if (implicit_root_block.defined()) { + SBlock root_block = implicit_root_block.value(); + text::AccessPath root_block_p = path->Attr("body")->Attr("block"); + frame->stmts.push_back(text::CommentAST(ffi::Optional(ffi::String("with T.sblock(\"root\"):")))); + for (int i = 0, n = root_block->alloc_buffers.size(); i < n; ++i) { + Buffer buffer = root_block->alloc_buffers[i]; + text::AccessPath buffer_p = root_block_p->Attr("alloc_buffers")->ArrayItem(i); + std::string buf_name = buffer->name; + if (buf_name.empty()) buf_name = "buffer"; + printer->VarDef(buf_name, buffer, frame); + text::ExprAST buf_id = printer->VarGet(buffer).value(); + ffi::List no_extra; + text::ExprAST rhs = PrintBufferDecl(buffer, "sblock_alloc_buffer", std::move(no_extra), + printer, buffer_p); + DefineBufferDataVar(buffer, printer); + frame->stmts.push_back(text::AssignAST(buf_id, rhs, ffi::Optional())); + } + body_stmts = PrintBodyStmts(root_block->body, printer, root_block_p->Attr("body")); + } else { + body_stmts = PrintBodyStmts(func->body, printer, path->Attr("body")); + } + + // Merge frame stmts + body + ffi::List all_body; + for (const auto& s : frame->stmts) all_body.push_back(s); + for (const auto& s : body_stmts) all_body.push_back(s); + + printer->FramePop(); + + // Return type annotation + ffi::Optional ret_type; + if (func->ret_type.defined()) { + const auto* as_tuple = func->ret_type.as(); + if (!as_tuple || as_tuple->fields.size()) { + ret_type = Print(printer, func->ret_type, path->Attr("ret_type")); + } + } + + text::FunctionAST func_ast(name, params, decorators, ret_type, all_body); + + if (!in_module) { + ffi::List result; + result.push_back(text::CommentAST( + ffi::Optional(ffi::String("from tvm.script import tirx as T")))); + result.push_back(text::CommentAST(ffi::Optional())); + result.push_back(func_ast); + return text::StmtBlockAST(result); + } + return func_ast; + }); +} + } // namespace tirx } // namespace tvm diff --git a/src/tirx/ir/script_print_utils.h b/src/tirx/ir/script_print_utils.h new file mode 100644 index 000000000000..42f466fb6166 --- /dev/null +++ b/src/tirx/ir/script_print_utils.h @@ -0,0 +1,348 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_TIRX_IR_SCRIPT_PRINT_UTILS_H_ +#define TVM_TIRX_IR_SCRIPT_PRINT_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../../ir/printer_utils.h" + +namespace tvm { +namespace printer { + +using namespace tirx; + +/*! \brief Define a TIR variable in the printer and return its IdAST. */ +inline text::ExprAST DefineVar(const tirx::Var& var, const text::IRPrinter& printer, + const text::AccessPath& path) { + if (printer->VarGet(var).has_value()) { + return printer->VarGet(var).value(); + } + text::DefaultFrame frame = printer->frames.back().cast(); + printer->VarDef(var->name_hint, var, frame); + return printer->VarGet(var).value(); +} + +/*! \brief Define a TIR variable and return AssignAST with annotation if typed. */ +inline text::StmtAST DefineVarAssign(const tirx::Var& var, text::ExprAST rhs, + const text::IRPrinter& printer, const text::AccessPath& path) { + text::ExprAST id = DefineVar(var, printer, path); + ffi::Optional annotation; + if (var->type_annotation.defined()) { + annotation = Print(printer, var->type_annotation, path->Attr("type_annotation")); + } + return text::AssignAST(id, rhs, annotation); +} + +/*! \brief Print a statement body, flattening SeqStmt. Returns a List. */ +inline ffi::List PrintBodyStmts(const Stmt& stmt, const text::IRPrinter& printer, + const text::AccessPath& path) { + ffi::List result; + text::NodeAST ast = printer->operator()(ffi::Any(stmt), path).cast(); + if (auto* block = ast.as()) { + for (const auto& s : block->stmts) { + result.push_back(s); + } + } else if (ast->IsInstance()) { + result.push_back(Downcast(ast)); + } else if (ast->IsInstance()) { + result.push_back(text::ExprStmtAST(Downcast(ast))); + } + return result; +} + +/*! \brief A Var occurrence counter visitor (matches V1 OccurrenceCounter). */ +class OccurrenceCounter : public tirx::StmtExprVisitor { + public: + int count = 0; + const tirx::VarNode* v = nullptr; + + void VisitExpr_(const tirx::VarNode* op) final { + if (op == v) ++count; + tirx::StmtExprVisitor::VisitExpr_(op); + } + void VisitStmt_(const tirx::BufferStoreNode* op) final { + VisitBuffer(op->buffer.get()); + tirx::StmtExprVisitor::VisitStmt_(op); + } + void VisitExpr_(const tirx::BufferLoadNode* op) final { + VisitBuffer(op->buffer.get()); + tirx::StmtExprVisitor::VisitExpr_(op); + } + void VisitStmt_(const tirx::AllocBufferNode* op) final { + VisitBuffer(op->buffer.get()); + tirx::StmtExprVisitor::VisitStmt_(op); + } + void VisitStmt_(const tirx::DeclBufferNode* op) final { + VisitBuffer(op->buffer.get()); + tirx::StmtExprVisitor::VisitStmt_(op); + } + void VisitBuffer(const tirx::BufferNode* buffer) { + VisitExpr(buffer->data); + for (const PrimExpr& shape_i : buffer->shape) VisitExpr(shape_i); + for (const PrimExpr& stride_i : buffer->strides) VisitExpr(stride_i); + VisitExpr(buffer->elem_offset); + } + explicit OccurrenceCounter(const tirx::VarNode* var) { v = var; } +}; + +/*! \brief Count how many times a Var occurs in a PrimFunc (params + buffer_map + body). */ +inline int CountVarOccurrence(const tirx::PrimFunc& f, const tirx::Var& v) { + OccurrenceCounter counter(v.get()); + counter(f->body); + for (const tirx::Var& param : f->params) { + counter(param); + } + for (const auto& pair : f->buffer_map) { + counter(pair.first); + counter.VisitBuffer(pair.second.get()); + } + return counter.count; +} + +/*! \brief Check if a buffer is "simple" (can be inlined as param annotation). + * Matches V1 IsSimpleBuffer logic exactly. */ +inline bool IsSimpleBuffer(const tirx::Buffer& buf) { + if (!buf->strides.empty()) return false; + for (const PrimExpr& shp_i : buf->shape) { + if (!tirx::UndefinedVars(shp_i).empty()) return false; + } + if (!tirx::UndefinedVars(buf->elem_offset).empty()) { + return false; + } else if (buf->elem_offset->IsInstance()) { + IntImm elem_offset = Downcast(buf->elem_offset); + if (elem_offset->value != 0) return false; + } + return buf.scope() == "global" && + buf->data_alignment == runtime::kAllocAlignment && + buf->offset_factor == 1 && + buf->buffer_type == tirx::BufferType::kDefault && + buf->axis_separators.empty(); +} + +/*! \brief Print buffer as T.Buffer(shape, dtype) annotation (simple buffer). */ +inline text::ExprAST PrintBufferAnnotation(const tirx::Buffer& buf, const text::IRPrinter& printer, + const text::AccessPath& path) { + ffi::List args; + args.push_back(printer->PrintTuple(buf->shape, path->Attr("shape"))); + args.push_back(text::LiteralAST::Str(DType2Str(buf->dtype))); + return text::ExprCall(TIR("Buffer"), std::move(args)); +} + +/*! + * \brief Check if a PrimExpr is a new (undefined) Var in the printer. + */ +inline bool IsNewVar(const PrimExpr& e, const text::IRPrinter& printer) { + return e->IsInstance() && !printer->VarGet(e).has_value(); +} + +/*! + * \brief Define a new TIR Var if not already defined, and emit + * `var_name = T.()` into the frame. + * Returns the IdAST for the newly defined var. + */ +inline text::ExprAST DefineNewTIRVar(const tirx::Var& var, const text::IRPrinter& printer, + text::DefaultFrame& frame) { + text::ExprAST var_id = DefineVar(var, printer, text::AccessPath::Root()); + std::string dtype_str = DType2Str(var->dtype); + // Match V1's PrintVarCreation: add is_size_var=True kwarg for SizeVar + if (var->IsInstance()) { + text::ExprAST rhs = text::ExprCallKw(TIR(dtype_str), {}, + {ffi::String("is_size_var")}, {text::LiteralAST::Bool(true)}); + frame->stmts.push_back(text::AssignAST(var_id, rhs, ffi::Optional())); + } else { + text::ExprAST rhs = text::ExprCall(TIR(dtype_str), {}); + frame->stmts.push_back(text::AssignAST(var_id, rhs, ffi::Optional())); + } + return var_id; +} + +/*! + * \brief Define any new vars in a buffer's shape, strides, and elem_offset. + * Emits `var = T.int32()` etc. into the frame for each undefined var. + * Must be called BEFORE PrintBufferDecl so that all buffer vars are available. + * + * Uses PostOrderVisit to recurse into compound expressions (e.g. batch_size + 1) + * to find ALL nested Vars, not just top-level ones. + */ +inline void DefineBufferVars(const tirx::Buffer& buf, const text::IRPrinter& printer, + text::DefaultFrame& frame) { + auto visit_expr = [&](const PrimExpr& e) { + tirx::PostOrderVisit(e, [&](const ffi::ObjectRef& obj) { + if (const auto* var_node = obj.as()) { + tirx::Var var = ffi::GetRef(var_node); + if (!printer->VarGet(var).has_value()) { + DefineNewTIRVar(var, printer, frame); + } + } + }); + }; + for (const PrimExpr& e : buf->shape) { + visit_expr(e); + } + for (const PrimExpr& e : buf->strides) { + visit_expr(e); + } + // NOTE: Do NOT define elem_offset vars here. They are handled by + // PrintBufferDecl which decides whether to emit elem_offset=... + // kwarg or offset_factor=... with inline definition (matching V1 + // try_inline_def logic). +} + +/*! + * \brief Define the buffer's data variable as `buf_name.data` using VarDefNoName. + * This allows references to buf->data to render as `A.data` instead of `A_1`. + * Must be called AFTER the buffer itself has been defined via VarDef. + */ +inline void DefineBufferDataVar(const tirx::Buffer& buf, const text::IRPrinter& printer) { + if (!printer->VarGet(buf->data).has_value()) { + text::ExprAST buf_expr = printer->VarGet(buf).value(); + // Capture buf_expr in a Function that returns buf_name.data + ffi::Function creator = ffi::Function::FromTyped([buf_expr]() -> text::ExprAST { + return text::ExprAttr(buf_expr, "data"); + }); + printer->VarDefNoName(creator, buf->data, + ffi::Optional(printer->frames.back().cast())); + } +} + +/*! + * \brief Print a buffer declaration call: T.(extra_args..., shape, dtype, kwargs...) + * + * Matches V1's BufferDecl/BufferCall: positional args are (extra_args..., shape, dtype), + * then kwargs for non-default: data, strides, elem_offset, scope, align, offset_factor, + * buffer_type, axis_separators, annotations. + * + * NOTE: Call DefineBufferVars() first to define any new shape/stride vars. + * + * \param annotations Optional annotations map (from AllocBuffer). If non-empty, + * emits annotations={...} kwarg. + * \param annotations_path AccessPath for annotations (used only when annotations non-empty). + */ +inline text::ExprAST PrintBufferDecl(const tirx::Buffer& buf, const std::string& method, + ffi::List extra_args, + const text::IRPrinter& printer, const text::AccessPath& path, + const ffi::Map& annotations = {}, + const text::AccessPath& annotations_path = text::AccessPath::Root()) { + // Positional: extra_args, shape, dtype + ffi::List args; + for (const auto& a : extra_args) args.push_back(a); + args.push_back(printer->PrintTuple(buf->shape, path->Attr("shape"))); + if (DType2Str(buf->dtype) != "float32") { + args.push_back(text::LiteralAST::Str(DType2Str(buf->dtype))); + } + // Kwargs for non-default fields + ffi::List kw_keys; + ffi::List kw_vals; + // data: print for decl_buffer/match_buffer when the data pointer is shared with + // another already-defined buffer. Skip for alloc_buffer/sblock_alloc_buffer + // (they create their own data pointer, so data= would be self-referential). + if (method != "alloc_buffer" && method != "sblock_alloc_buffer" && + !IsNewVar(buf->data, printer)) { + kw_keys.push_back(ffi::String("data")); + kw_vals.push_back(Print(printer, buf->data, path->Attr("data"))); + } + // strides (skip for alloc_buffer — its parser doesn't accept strides; + // sblock_alloc_buffer does accept strides, so emit them) + if (!buf->strides.empty() && method != "alloc_buffer") { + kw_keys.push_back(ffi::String("strides")); + ffi::List stride_elts; + for (int i = 0; i < static_cast(buf->strides.size()); ++i) { + stride_elts.push_back(Print(printer, buf->strides[i], path->Attr("strides")->ArrayItem(i))); + } + kw_vals.push_back(text::TupleAST({}, std::move(stride_elts))); + } + // elem_offset + // V1 logic: if IntImm, print only if non-zero. + // If new var, DON'T print elem_offset kwarg (but set needs_print_factor). + // If existing var, print elem_offset kwarg. + bool needs_print_factor_for_elem_offset = false; + if (const auto* int_imm = buf->elem_offset.as()) { + if (int_imm->value != 0) { + kw_keys.push_back(ffi::String("elem_offset")); + kw_vals.push_back(Print(printer, buf->elem_offset, path->Attr("elem_offset"))); + } + } else if (IsNewVar(buf->elem_offset, printer)) { + // New var: don't print elem_offset kwarg, but force offset_factor printing. + // Inline-define the var as buf.elem_offset (matching V1's try_inline_def) + // so subsequent references resolve correctly. + needs_print_factor_for_elem_offset = true; + { + tirx::Var offset_var = Downcast(buf->elem_offset); + text::ExprAST buf_expr = printer->VarGet(buf).value(); + ffi::Function creator = ffi::Function::FromTyped([buf_expr]() -> text::ExprAST { + return text::ExprAttr(buf_expr, "elem_offset"); + }); + printer->VarDefNoName(creator, offset_var, + ffi::Optional(printer->frames.back().cast())); + } + } else { + // Existing var: print elem_offset kwarg + kw_keys.push_back(ffi::String("elem_offset")); + kw_vals.push_back(Print(printer, buf->elem_offset, path->Attr("elem_offset"))); + } + // scope + if (buf.scope() != "global") { + kw_keys.push_back(ffi::String("scope")); + kw_vals.push_back(text::LiteralAST::Str(buf.scope())); + } + // align (data_alignment) + if (buf->data_alignment != runtime::kAllocAlignment) { + kw_keys.push_back(ffi::String("align")); + kw_vals.push_back(text::LiteralAST::Int(buf->data_alignment)); + } + // offset_factor + if (needs_print_factor_for_elem_offset || buf->offset_factor != 1) { + kw_keys.push_back(ffi::String("offset_factor")); + kw_vals.push_back(text::LiteralAST::Int(buf->offset_factor)); + } + // buffer_type + if (buf->buffer_type != tirx::BufferType::kDefault) { + kw_keys.push_back(ffi::String("buffer_type")); + kw_vals.push_back(text::LiteralAST::Str("auto")); + } + // axis_separators (V1 prints for all buffer types except standalone alloc_buffer) + if (!buf->axis_separators.empty()) { + kw_keys.push_back(ffi::String("axis_separators")); + kw_vals.push_back(printer->PrintList(buf->axis_separators, path->Attr("axis_separators"))); + } + // annotations (from AllocBuffer node, not the Buffer itself) + if (!annotations.empty()) { + kw_keys.push_back(ffi::String("annotations")); + kw_vals.push_back(Print(printer, annotations, annotations_path)); + } + + if (!kw_keys.empty()) { + return text::ExprCallKw(TIR(method), std::move(args), std::move(kw_keys), std::move(kw_vals)); + } + return text::ExprCall(TIR(method), std::move(args)); +} + +} // namespace printer +} // namespace tvm + +#endif // TVM_TIRX_IR_SCRIPT_PRINT_UTILS_H_ diff --git a/src/tirx/ir/stmt.cc b/src/tirx/ir/stmt.cc index fb5b06954c30..948181081d4a 100644 --- a/src/tirx/ir/stmt.cc +++ b/src/tirx/ir/stmt.cc @@ -22,12 +22,18 @@ */ #include #include +#include #include +#include #include #include #include +#include +#include + #include "buffer_common.h" +#include "script_print_utils.h" namespace tvm { namespace tirx { @@ -35,7 +41,6 @@ namespace tirx { TVM_FFI_STATIC_INIT_BLOCK() { StmtNode::RegisterReflection(); BindNode::RegisterReflection(); - AttrStmtNode::RegisterReflection(); AssertStmtNode::RegisterReflection(); BufferStoreNode::RegisterReflection(); @@ -52,6 +57,42 @@ TVM_FFI_STATIC_INIT_BLOCK() { SBlockRealizeNode::RegisterReflection(); } +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = ::tvm::ffi::reflection; + refl::GlobalDef() + .def("tirx._structured_msg", [](ffi::AnyView /*ctx*/, AssertStmt node) -> ffi::Array { + ffi::Array parts_arr; + for (const StringImm& part : node->message_parts) { + parts_arr.push_back(part); + } + return {node->error_kind, parts_arr}; + }) + .def("tirx._evaluate_is_return", [](ffi::AnyView /*ctx*/, Evaluate self) -> bool { + if (auto* call = self->value.as()) { + return call->op.same_as(tirx::builtin::ret()) && call->args.size() == 1; + } + return false; + }) + .def("tirx._evaluate_expr", [](ffi::AnyView /*ctx*/, Evaluate self) -> PrimExpr { + if (auto* call = self->value.as()) { + if (call->op.same_as(tirx::builtin::ret()) && call->args.size() == 1) { + return call->args[0]; + } + } + return self->value; + }) + .def("tirx._evaluate_kind", [](ffi::AnyView /*ctx*/, Evaluate self) -> ffi::Optional { + // For ret() calls, the return check handles it, so kind doesn't matter + if (auto* call = self->value.as()) { + if (call->op.same_as(tirx::builtin::ret())) return {}; + } + // For other calls, no wrapper needed + if (self->value->IsInstance()) return {}; + // Non-call: wrap with T.evaluate + return ffi::String("T.evaluate"); + }); +} + // Bind Bind::Bind(Var var, PrimExpr value, Span span) { TVM_FFI_ICHECK(value.defined()); @@ -608,5 +649,381 @@ TVM_TIR_REGISTER_OP("type_annotation") .set_attr("TScriptDtypePrintLocation", Integer(ScriptDtypePrintLocation::kFirst)); + +// --------------------------------------------------------------------------- +// __ffi_text_print__ overrides +// --------------------------------------------------------------------------- + +// Static helper for SBlockRealize/SBlock printing +static ::tvm::ffi::pyast::NodeAST PrintSBlockRealize( + SBlockRealize realize, ::tvm::ffi::pyast::IRPrinter printer, + ::tvm::ffi::pyast::AccessPath path) { + using namespace printer; + SBlock block = realize->block; + text::AccessPath block_p = path->Attr("block"); + + text::DefaultFrame frame; + printer->FramePush(frame); + + // Build context expr: T.sblock("name") + text::ExprAST ctx = text::ExprCall(TIR("sblock"), {text::LiteralAST::Str(block->name_hint)}); + + // Define iter_vars and build as_var + for (int i = 0; i < static_cast(block->iter_vars.size()); ++i) { + IterVar iv = block->iter_vars[i]; + text::AccessPath iv_p = block_p->Attr("iter_vars")->ArrayItem(i); + + text::ExprAST var_id = DefineVar(iv->var, printer, iv_p->Attr("var")); + + std::string axis_type; + switch (static_cast(iv->iter_type)) { + case kDataPar: axis_type = "spatial"; break; + case kCommReduce: axis_type = "reduce"; break; + case kOrdered: axis_type = "scan"; break; + default: axis_type = "opaque"; break; + } + + text::ExprAST dom(ffi::UnsafeInit{}); + if (iv->dom.defined()) { + if (is_zero(iv->dom->min)) { + dom = Print(printer, iv->dom->extent, iv_p->Attr("dom")->Attr("extent")); + } else { + dom = text::TupleAST({}, {Print(printer, iv->dom->min, iv_p->Attr("dom")->Attr("min")), + Print(printer, iv->dom->min + iv->dom->extent, + iv_p->Attr("dom")->Attr("extent"))}); + } + } else { + dom = text::LiteralAST::Null(); + } + + text::ExprAST val(ffi::UnsafeInit{}); + if (i < static_cast(realize->iter_values.size())) { + val = Print(printer, realize->iter_values[i], path->Attr("iter_values")->ArrayItem(i)); + } else { + val = text::LiteralAST::Null(); + } + + text::ExprAST rhs = text::ExprCall(text::ExprAttr(text::ExprAttr(text::IdAST("T"), "axis"), axis_type), {dom, val}); + frame->stmts.push_back( + text::AssignAST(var_id, rhs, ffi::Optional())); + } + + // Predicate + if (!is_one(realize->predicate)) { + text::ExprAST pred = Print(printer, realize->predicate, path->Attr("predicate")); + frame->stmts.push_back( + text::ExprStmtAST(text::ExprCall(TIR("where"), {pred}))); + } + + // Reads + { + ffi::List reads; + for (int i = 0; i < static_cast(block->reads.size()); ++i) { + reads.push_back( + Print(printer, block->reads[i], block_p->Attr("reads")->ArrayItem(i))); + } + frame->stmts.push_back( + text::ExprStmtAST(text::ExprCall(TIR("reads"), std::move(reads)))); + } + + // Writes + { + ffi::List writes; + for (int i = 0; i < static_cast(block->writes.size()); ++i) { + writes.push_back( + Print(printer, block->writes[i], block_p->Attr("writes")->ArrayItem(i))); + } + frame->stmts.push_back( + text::ExprStmtAST(text::ExprCall(TIR("writes"), std::move(writes)))); + } + + // Annotations + if (!block->annotations.empty()) { + text::ExprAST annot = Print(printer, block->annotations, block_p->Attr("annotations")); + frame->stmts.push_back( + text::ExprStmtAST(text::ExprCall(TIR("sblock_attr"), {annot}))); + } + + // Alloc buffers + for (int i = 0; i < static_cast(block->alloc_buffers.size()); ++i) { + Buffer buf = block->alloc_buffers[i]; + text::AccessPath buffer_p = block_p->Attr("alloc_buffers")->ArrayItem(i); + std::string buf_name = buf->name; + if (buf_name.empty()) buf_name = "buffer"; + printer->VarDef(buf_name, buf, frame); + text::ExprAST buf_id = printer->VarGet(buf).value(); + ffi::List no_extra; + text::ExprAST rhs = PrintBufferDecl(buf, "sblock_alloc_buffer", std::move(no_extra), + printer, buffer_p); + DefineBufferDataVar(buf, printer); + frame->stmts.push_back( + text::AssignAST(buf_id, rhs, ffi::Optional())); + } + + // Match buffers + for (int i = 0; i < static_cast(block->match_buffers.size()); ++i) { + text::NodeAST s = printer->operator()(ffi::Any(block->match_buffers[i]), + block_p->Attr("match_buffers")->ArrayItem(i)) + .cast(); + if (s->IsInstance()) { + frame->stmts.push_back(Downcast(s)); + } + } + + // Init + if (block->init.defined()) { + ffi::List init_body = PrintBodyStmts(block->init.value(), printer, + block_p->Attr("init")); + text::ExprAST init_ctx = text::ExprCall(TIR("init"), {}); + frame->stmts.push_back( + text::WithAST(ffi::Optional(), init_ctx, init_body)); + } + + // Body + ffi::List body = PrintBodyStmts(block->body, printer, block_p->Attr("body")); + + // Merge + ffi::List all_body; + for (const auto& s : frame->stmts) all_body.push_back(s); + for (const auto& s : body) all_body.push_back(s); + + printer->FramePop(); + return text::WithAST(ffi::Optional(), ctx, all_body); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + using namespace printer; + + // AllocBuffer, DeclBuffer, MatchBufferRegion + refl::TypeAttrDef().def( + "__ffi_text_print__", + [](AllocBuffer node, text::IRPrinter printer, text::AccessPath path) -> text::NodeAST { + using namespace printer; + Buffer buf = node->buffer; + text::DefaultFrame frame = printer->frames.back().cast(); + printer->VarDef(buf->name, buf, frame); + text::ExprAST buf_id = printer->VarGet(buf).value(); + text::AccessPath buffer_p = path->Attr("buffer"); + ffi::List no_extra; + text::ExprAST rhs = PrintBufferDecl(buf, "alloc_buffer", std::move(no_extra), + printer, buffer_p, + node->annotations, path->Attr("annotations")); + DefineBufferDataVar(buf, printer); + return text::AssignAST(buf_id, rhs, ffi::Optional()); + }); + + refl::TypeAttrDef().def( + "__ffi_text_print__", + [](DeclBuffer node, text::IRPrinter printer, text::AccessPath path) -> text::NodeAST { + using namespace printer; + Buffer buf = node->buffer; + text::DefaultFrame frame = printer->frames.back().cast(); + DefineBufferVars(buf, printer, frame); + printer->VarDef(buf->name, buf, frame); + text::ExprAST buf_id = printer->VarGet(buf).value(); + ffi::List no_extra; + text::ExprAST rhs = PrintBufferDecl(buf, "decl_buffer", std::move(no_extra), + printer, path->Attr("buffer")); + DefineBufferDataVar(buf, printer); + return text::AssignAST(buf_id, rhs, ffi::Optional()); + }); + + refl::TypeAttrDef().def( + "__ffi_text_print__", + [](MatchBufferRegion node, text::IRPrinter printer, text::AccessPath path) -> text::NodeAST { + using namespace printer; + Buffer buf = node->buffer; + text::DefaultFrame frame = printer->frames.back().cast(); + DefineBufferVars(buf, printer, frame); + printer->VarDef(buf->name, buf, frame); + text::ExprAST buf_id = printer->VarGet(buf).value(); + text::ExprAST source = Print(printer, node->source, path->Attr("source")); + ffi::List extra_args; + extra_args.push_back(source); + text::ExprAST rhs = PrintBufferDecl(buf, "match_buffer", std::move(extra_args), + printer, path->Attr("buffer")); + DefineBufferDataVar(buf, printer); + return text::AssignAST(buf_id, rhs, ffi::Optional()); + }); + + // AttrStmt + refl::TypeAttrDef().def( + "__ffi_text_print__", + [](AttrStmt node, text::IRPrinter printer, text::AccessPath path) -> text::NodeAST { + using namespace printer; + bool is_iter_var = (node->node.type_index() >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) && + node->node.cast()->IsInstance(); + if ((node->attr_key == "thread_extent" || node->attr_key == "virtual_thread") && + is_iter_var) { + IterVar iv = node->node.cast(); + bool was_defined = printer->VarGet(iv->var).has_value(); + if (was_defined) { + ffi::List body = PrintBodyStmts(node->body, printer, path->Attr("body")); + text::ExprAST var = printer->VarGet(iv->var).value(); + text::ExprAST ctx = text::ExprCall(TIR("launch_thread"), + {var, Print(printer, node->value, path->Attr("value"))}); + return text::WithAST(ffi::Optional(), ctx, body); + } else { + text::DefaultFrame inner_frame; + printer->FramePush(inner_frame); + DefineVar(iv->var, printer, path->Attr("node")->Attr("var")); + ffi::List body = PrintBodyStmts(node->body, printer, path->Attr("body")); + text::ExprAST var = printer->VarGet(iv->var).value(); + printer->FramePop(); + text::ExprAST ctx = text::ExprCall(TIR("launch_thread"), + {text::LiteralAST::Str(iv->thread_tag), + Print(printer, node->value, path->Attr("value"))}); + return text::WithAST(ffi::Optional(var), ctx, body); + } + } + ffi::List body = PrintBodyStmts(node->body, printer, path->Attr("body")); + text::ExprAST ctx = text::ExprCall(TIR("attr"), + {Print(printer, node->node, path->Attr("node")), + text::LiteralAST::Str(node->attr_key, {path->Attr("attr_key")}), + Print(printer, node->value, path->Attr("value"))}); + return text::WithAST(ffi::Optional(), ctx, body); + }); + + // ForNode + refl::TypeAttrDef().def( + "__ffi_text_print__", + [](tirx::For loop, text::IRPrinter printer, text::AccessPath path) -> text::NodeAST { + using namespace printer; + // Step 1. Check syntactic sugar: T.grid + std::vector grid; + std::unordered_set grid_loop_vars; + { + for (const ForNode* l = loop.get(); l != nullptr; + l = l->body.as()) { + if (l->kind != ForKind::kSerial || + !is_zero(l->min) || + !l->annotations.empty() || + !l->HasTrivialStep() || + tirx::UsesVar(l->extent, [&grid_loop_vars](const tirx::VarNode* v) { + return grid_loop_vars.count(v) > 0; + })) { + break; + } + grid.push_back(l); + grid_loop_vars.insert(l->loop_var.get()); + } + } + + // Step 2. If grid.size() > 1, print as T.grid + if (grid.size() > 1) { + text::DefaultFrame frame; + printer->FramePush(frame); + int n = grid.size(); + ffi::List lhs_vars; + ffi::List extents; + text::AccessPath cur_p = path; + for (int i = 0; i < n; ++i) { + const ForNode* g = grid[i]; + lhs_vars.push_back( + DefineVar(ffi::GetRef(static_cast(g->loop_var.get())), + printer, cur_p->Attr("loop_var"))); + extents.push_back(Print(printer, g->extent, cur_p->Attr("extent"))); + cur_p = cur_p->Attr("body"); + } + text::ExprAST lhs = text::TupleAST({}, std::move(lhs_vars)); + text::ExprAST rhs = text::ExprCall(TIR("grid"), std::move(extents)); + ffi::List body = PrintBodyStmts( + ffi::GetRef(static_cast(grid.back()->body.get())), + printer, cur_p); + ffi::List all_body; + for (const auto& s : frame->stmts) all_body.push_back(s); + for (const auto& s : body) all_body.push_back(s); + printer->FramePop(); + return text::ForAST(lhs, rhs, all_body); + } + + // Step 3. Single for loop (no grid sugar) + text::DefaultFrame frame; + printer->FramePush(frame); + text::ExprAST lhs = DefineVar(loop->loop_var, printer, path->Attr("loop_var")); + + text::ExprAST rhs(ffi::UnsafeInit{}); + bool is_zero_min = is_zero(loop->min); + bool has_trivial_step = loop->HasTrivialStep(); + + if (loop->kind == ForKind::kSerial && loop->annotations.empty()) { + ffi::List range_args; + if (is_zero_min && has_trivial_step) { + range_args.push_back(Print(printer, loop->extent, path->Attr("extent"))); + } else { + PrimExpr end = loop->min + loop->extent; + range_args.push_back(Print(printer, loop->min, path->Attr("min"))); + range_args.push_back(Print(printer, end, path->Attr("extent"))); + } + if (!has_trivial_step) { + range_args.push_back(Print(printer, *loop->step, path->Attr("step"))); + } + rhs = text::ExprCall(text::IdAST("range"), std::move(range_args)); + } else { + std::string prefix; + switch (loop->kind) { + case ForKind::kSerial: prefix = "serial"; break; + case ForKind::kParallel: prefix = "parallel"; break; + case ForKind::kUnrolled: prefix = "unroll"; break; + case ForKind::kVectorized: prefix = "vectorized"; break; + case ForKind::kThreadBinding: prefix = "thread_binding"; break; + default: prefix = "serial"; break; + } + ffi::List args; + if (is_zero_min && has_trivial_step) { + args.push_back(Print(printer, loop->extent, path->Attr("extent"))); + } else { + PrimExpr end = loop->min + loop->extent; + args.push_back(Print(printer, loop->min, path->Attr("min"))); + args.push_back(Print(printer, end, path->Attr("extent"))); + } + ffi::List kw_keys; + ffi::List kw_vals; + if (!loop->annotations.empty()) { + kw_keys.push_back(ffi::String("annotations")); + kw_vals.push_back(Print(printer, loop->annotations, path->Attr("annotations"))); + } + if (loop->kind == ForKind::kThreadBinding && loop->thread_binding.defined()) { + kw_keys.push_back(ffi::String("thread")); + kw_vals.push_back( + text::LiteralAST::Str(loop->thread_binding.value()->thread_tag, + {path->Attr("thread_binding")})); + } + if (!has_trivial_step) { + kw_keys.push_back(ffi::String("step")); + kw_vals.push_back(Print(printer, *loop->step, path->Attr("step"))); + } + rhs = !kw_keys.empty() + ? text::ExprCallKw(TIR(prefix), std::move(args), std::move(kw_keys), std::move(kw_vals)) + : text::ExprCall(TIR(prefix), std::move(args)); + } + + ffi::List body = PrintBodyStmts(loop->body, printer, path->Attr("body")); + ffi::List all_body; + for (const auto& s : frame->stmts) all_body.push_back(s); + for (const auto& s : body) all_body.push_back(s); + printer->FramePop(); + return text::ForAST(lhs, rhs, all_body); + }); + + // SBlockRealize + SBlock + refl::TypeAttrDef().def( + "__ffi_text_print__", + [](SBlockRealize node, text::IRPrinter printer, text::AccessPath path) -> text::NodeAST { + return PrintSBlockRealize(node, printer, path); + }); + + refl::TypeAttrDef().def( + "__ffi_text_print__", + [](SBlock block, text::IRPrinter printer, text::AccessPath path) -> text::NodeAST { + ffi::Array iter_values; + for (const auto& iv : block->iter_vars) { + iter_values.push_back(iv->var); + } + SBlockRealize realize(iter_values, Bool(true), block); + return PrintSBlockRealize(realize, printer, path); + }); +} + } // namespace tirx } // namespace tvm