3535#include < utility>
3636#include < vector>
3737
38+ #include " tvm/tir/var.h"
39+
3840namespace tvm {
3941namespace tir {
4042
@@ -158,6 +160,22 @@ class Layout : public ObjectRef {
158160 return undef;
159161 }
160162
163+ /* !
164+ * \brief Packs the Given Array of IterVars into a Single IterVar. Each IterVar in the Array
165+ * should represent either a single primal axis or one or more subordinate axis
166+ * \param iters Array of iter vars to be packed
167+ * \return A packed iter var
168+ */
169+ static IterVar PackIterVar (ffi::Array<IterVar> iters);
170+
171+ /* !
172+ * \brief Unpacks a Packed IterVar into its constituents
173+ * \param packed_iter A Packed IterVar containing a single primal axis or one or more subordinate
174+ * axis
175+ * \return Constituent IterVars
176+ */
177+ static ffi::Array<IterVar> UnpackIterVar (IterVar packed_iter);
178+
161179 /* !
162180 * \brief Returns a sub-layout which is the portion of the object
163181 * that starts at dimension \p pos and spans \p len dimensions
@@ -187,9 +205,12 @@ class Layout : public ObjectRef {
187205 inline size_t ndim_primal () const {
188206 if (!defined ()) return 0 ;
189207 size_t ct = 0 ;
190- for (auto x : operator ->()->axes ) {
191- if (LayoutAxis::Get (x).IsPrimal ()) {
192- ct++;
208+ for (auto px : operator ->()->axes ) {
209+ auto iter_vars = UnpackIterVar (px);
210+ for (auto x : iter_vars) {
211+ if (LayoutAxis::Get (x).IsPrimal ()) {
212+ ct++;
213+ }
193214 }
194215 }
195216 return ct;
@@ -204,10 +225,13 @@ class Layout : public ObjectRef {
204225 Layout new_src_layout;
205226 // 1) Find the axis which are missing in the current layout. Make them the prefix.
206227 std::string new_src_layout_str = " " ;
207- for (auto dst_axis : dst_layout->axes ) {
208- if (LayoutAxis::Get (dst_axis).IsPrimal ()) {
209- if (!this ->Contains (LayoutAxis::Get (dst_axis))) {
210- new_src_layout_str += dst_axis->var ->name_hint ;
228+ for (auto packed_axis : dst_layout->axes ) {
229+ auto iter_vars = UnpackIterVar (packed_axis);
230+ for (auto dst_axis : iter_vars) {
231+ if (LayoutAxis::Get (dst_axis).IsPrimal ()) {
232+ if (!this ->Contains (LayoutAxis::Get (dst_axis))) {
233+ new_src_layout_str += dst_axis->var ->name_hint ;
234+ }
211235 }
212236 }
213237 }
@@ -221,18 +245,36 @@ class Layout : public ObjectRef {
221245 * \brief return the index of the input axis.
222246 * If it is not found in the layout or the layout is undefined,
223247 * return -1.
224- * \param axis the input axis.
248+ * \param axis The input axis either a layout axis, or a packed axis
225249 * \return the index or -1 if not found.
226250 */
227- inline int32_t IndexOf (const LayoutAxis & axis) const {
251+ inline int32_t IndexOf (const std::string & axis) const {
228252 if (!this ->defined ()) return -1 ;
229253 const auto axes = operator ->()->axes ;
230254 for (size_t i = 0 ; i < axes.size (); ++i) {
231- if (axes[i]->var ->name_hint == axis. name () ) return static_cast <int32_t >(i);
255+ if (axes[i]->var ->name_hint == axis) return static_cast <int32_t >(i);
232256 }
233257 return -1 ;
234258 }
235259
260+ /* !
261+ * \brief return the index of the input axis.
262+ * If it is not found in the layout or the layout is undefined,
263+ * return -1.
264+ * \param axis the input layout axis.
265+ * \return the index or -1 if not found.
266+ */
267+ inline int32_t IndexOf (const LayoutAxis& axis) const { return IndexOf (axis.name ()); }
268+
269+ /* !
270+ * \brief return the index of the input axis.
271+ * If it is not found in the layout or the layout is undefined,
272+ * return -1.
273+ * \param iter the input iter var.
274+ * \return the index or -1 if not found.
275+ */
276+ inline int32_t IndexOf (const tir::IterVar& iter) const { return IndexOf (iter->var ->name_hint ); }
277+
236278 /* !
237279 * \brief Get the factor size of the subordinate axis.
238280 * \param axis the input primal-axis or subordinate-axis.
@@ -249,20 +291,23 @@ class Layout : public ObjectRef {
249291 */
250292 bool Contains (const LayoutAxis& axis) const {
251293 if (!defined ()) return false ;
252- for (const tir::IterVar var : operator ->()->axes ) {
253- if (var->var ->name_hint == axis.name ()) {
254- return true ;
294+ for (const tir::IterVar packed_var : operator ->()->axes ) {
295+ auto iter_vars = UnpackIterVar (packed_var);
296+ for (auto var : iter_vars) {
297+ if (var->var ->name_hint == axis.name ()) {
298+ return true ;
299+ }
255300 }
256301 }
257302 return false ;
258303 }
259304
260- const LayoutAxis& operator [](int32_t i) const {
305+ IterVar operator [](int32_t i) const {
261306 ICHECK (defined ()) << " Try to access axis from an undefined layout." ;
262307 int32_t index = i < 0 ? static_cast <int32_t >(ndim () + i) : i;
263308 ICHECK (index >= 0 && static_cast <size_t >(index) < ndim ()) << " Invalid index " << i;
264309 const tir::IterVar axis = operator ->()->axes [index];
265- return LayoutAxis::Get ( axis) ;
310+ return axis;
266311 }
267312
268313 /* ! \return the string description of the layout */
0 commit comments