Skip to content

Commit 83102c8

Browse files
committed
[RELAX][LAYOUT] Support multiple axis packing.
Adds support for layouts such as OIHW[4o4i], where multiple axes are packed together. This is helpful when handling complex target layouts. This PR covers the layout representation and the transforms for such layouts.
1 parent fa51ea2 commit 83102c8

11 files changed

Lines changed: 394 additions & 149 deletions

File tree

include/tvm/tir/data_layout.h

Lines changed: 60 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
#include <utility>
3636
#include <vector>
3737

38+
#include "tvm/tir/var.h"
39+
3840
namespace tvm {
3941
namespace tir {
4042

@@ -158,6 +160,22 @@ class Layout : public ObjectRef {
158160
return undef;
159161
}
160162

163+
/*!
164+
* \brief Packs the given array of IterVars into a single IterVar. Each IterVar in the array
165+
* should represent either a single primal axis or one or more subordinate axes.
166+
* \param iters Array of iter vars to be packed
167+
* \return A packed iter var
168+
*/
169+
static IterVar PackIterVar(ffi::Array<IterVar> iters);
170+
171+
/*!
172+
* \brief Unpacks a Packed IterVar into its constituents
173+
* \param packed_iter A Packed IterVar containing a single primal axis or one or more subordinate
174+
* axes
175+
* \return Constituent IterVars
176+
*/
177+
static ffi::Array<IterVar> UnpackIterVar(IterVar packed_iter);
178+
161179
/*!
162180
* \brief Returns a sub-layout which is the portion of the object
163181
* that starts at dimension \p pos and spans \p len dimensions
@@ -187,9 +205,12 @@ class Layout : public ObjectRef {
187205
inline size_t ndim_primal() const {
188206
if (!defined()) return 0;
189207
size_t ct = 0;
190-
for (auto x : operator->()->axes) {
191-
if (LayoutAxis::Get(x).IsPrimal()) {
192-
ct++;
208+
for (auto px : operator->()->axes) {
209+
auto iter_vars = UnpackIterVar(px);
210+
for (auto x : iter_vars) {
211+
if (LayoutAxis::Get(x).IsPrimal()) {
212+
ct++;
213+
}
193214
}
194215
}
195216
return ct;
@@ -204,10 +225,13 @@ class Layout : public ObjectRef {
204225
Layout new_src_layout;
205226
// 1) Find the axes that are missing in the current layout. Make them the prefix.
206227
std::string new_src_layout_str = "";
207-
for (auto dst_axis : dst_layout->axes) {
208-
if (LayoutAxis::Get(dst_axis).IsPrimal()) {
209-
if (!this->Contains(LayoutAxis::Get(dst_axis))) {
210-
new_src_layout_str += dst_axis->var->name_hint;
228+
for (auto packed_axis : dst_layout->axes) {
229+
auto iter_vars = UnpackIterVar(packed_axis);
230+
for (auto dst_axis : iter_vars) {
231+
if (LayoutAxis::Get(dst_axis).IsPrimal()) {
232+
if (!this->Contains(LayoutAxis::Get(dst_axis))) {
233+
new_src_layout_str += dst_axis->var->name_hint;
234+
}
211235
}
212236
}
213237
}
@@ -221,18 +245,36 @@ class Layout : public ObjectRef {
221245
* \brief return the index of the input axis.
222246
* If it is not found in the layout or the layout is undefined,
223247
* return -1.
224-
* \param axis the input axis.
248+
* \param axis The name of the input axis: either a layout axis or a packed axis.
225249
* \return the index or -1 if not found.
226250
*/
227-
inline int32_t IndexOf(const LayoutAxis& axis) const {
251+
inline int32_t IndexOf(const std::string& axis) const {
228252
if (!this->defined()) return -1;
229253
const auto axes = operator->()->axes;
230254
for (size_t i = 0; i < axes.size(); ++i) {
231-
if (axes[i]->var->name_hint == axis.name()) return static_cast<int32_t>(i);
255+
if (axes[i]->var->name_hint == axis) return static_cast<int32_t>(i);
232256
}
233257
return -1;
234258
}
235259

260+
/*!
261+
* \brief return the index of the input axis.
262+
* If it is not found in the layout or the layout is undefined,
263+
* return -1.
264+
* \param axis the input layout axis.
265+
* \return the index or -1 if not found.
266+
*/
267+
inline int32_t IndexOf(const LayoutAxis& axis) const { return IndexOf(axis.name()); }
268+
269+
/*!
270+
* \brief return the index of the input axis.
271+
* If it is not found in the layout or the layout is undefined,
272+
* return -1.
273+
* \param iter the input iter var.
274+
* \return the index or -1 if not found.
275+
*/
276+
inline int32_t IndexOf(const tir::IterVar& iter) const { return IndexOf(iter->var->name_hint); }
277+
236278
/*!
237279
* \brief Get the factor size of the subordinate axis.
238280
* \param axis the input primal-axis or subordinate-axis.
@@ -249,20 +291,23 @@ class Layout : public ObjectRef {
249291
*/
250292
bool Contains(const LayoutAxis& axis) const {
251293
if (!defined()) return false;
252-
for (const tir::IterVar var : operator->()->axes) {
253-
if (var->var->name_hint == axis.name()) {
254-
return true;
294+
for (const tir::IterVar packed_var : operator->()->axes) {
295+
auto iter_vars = UnpackIterVar(packed_var);
296+
for (auto var : iter_vars) {
297+
if (var->var->name_hint == axis.name()) {
298+
return true;
299+
}
255300
}
256301
}
257302
return false;
258303
}
259304

260-
const LayoutAxis& operator[](int32_t i) const {
305+
IterVar operator[](int32_t i) const {
261306
ICHECK(defined()) << "Try to access axis from an undefined layout.";
262307
int32_t index = i < 0 ? static_cast<int32_t>(ndim() + i) : i;
263308
ICHECK(index >= 0 && static_cast<size_t>(index) < ndim()) << "Invalid index " << i;
264309
const tir::IterVar axis = operator->()->axes[index];
265-
return LayoutAxis::Get(axis);
310+
return axis;
266311
}
267312

268313
/*! \return the string description of the layout */

python/tvm/tir/data_layout.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ def __len__(self):
4141
return _ffi_api.LayoutNdim(self) # type: ignore
4242

4343
def __contains__(self, axis):
44-
return len(axis) == 1 and axis[0].isalpha() and axis[0] in self.name
44+
# Note: We do a weaker check for packed axes, assuming the layout is valid.
45+
return not any(bkt in axis for bkt in "[]") and axis in self.name
4546

4647
def __getitem__(self, index):
4748
if index >= len(self):
@@ -54,7 +55,7 @@ def index_of(self, axis):
5455
Parameters
5556
----------
5657
axis : str
57-
The axis name, need to be [a-z,A-Z]
58+
The axis name, needs to be [a-z,A-Z] or a packed axis
5859
5960
Returns
6061
-------

src/contrib/msc/core/ir/graph_builder.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,8 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const ffi::Optional<Expr>
454454
} else if (input_types[i] == "weight" &&
455455
(optype == "msc.linear" || optype == "msc.linear_bias")) {
456456
if (ref->layout.name() == "IO") {
457-
ffi::String valid_layout = ref->layout[1].name() + ref->layout[0].name();
457+
ffi::String valid_layout =
458+
ref->layout[1]->var->name_hint + ref->layout[0]->var->name_hint;
458459
const auto& valid_shape = ffi::Array<Integer>({ref->shape[1], ref->shape[0]});
459460
weight = MSCTensor(weight_name, ref->dtype, valid_layout, valid_shape);
460461
} else {

src/contrib/msc/core/transform/layout_utils.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ const LayoutDecision LayoutUtils::ExpandLayout(const LayoutDecision& src_layout,
158158
<< "Only support normal layout, get " << src_layout->layout;
159159
std::set<std::string> used_axes;
160160
for (size_t i = 0; i < src_layout->layout.ndim(); i++) {
161-
used_axes.insert(src_layout->layout[i].name());
161+
used_axes.insert(src_layout->layout[i]->var->name_hint);
162162
}
163163
std::vector<std::string> prefer_axes{"N", "C", "H", "W", "D"};
164164
for (const auto& a : axes) {
@@ -198,7 +198,7 @@ const LayoutDecision LayoutUtils::ReduceLayout(const LayoutDecision& src_layout,
198198
if (reduce_axes_set.count(i)) {
199199
continue;
200200
}
201-
new_layout += src_layout->layout[i].name();
201+
new_layout += src_layout->layout[i]->var->name_hint;
202202
}
203203
return LayoutDecision(new_layout);
204204
}
@@ -207,7 +207,7 @@ const LayoutDecision LayoutUtils::PermuteLayout(const LayoutDecision& src_layout
207207
const ffi::Array<Integer>& axes) {
208208
ffi::String layout_str;
209209
for (const auto& a : axes) {
210-
layout_str = layout_str + src_layout->layout[a->value].name();
210+
layout_str = layout_str + src_layout->layout[a->value]->var->name_hint;
211211
}
212212
return LayoutDecision(layout_str);
213213
}
@@ -216,7 +216,7 @@ const LayoutDecision LayoutUtils::PermuteLayout(const LayoutDecision& src_layout
216216
const std::vector<size_t>& axes) {
217217
ffi::String layout_str;
218218
for (const auto& a : axes) {
219-
layout_str = layout_str + src_layout->layout[a].name();
219+
layout_str = layout_str + src_layout->layout[a]->var->name_hint;
220220
}
221221
return LayoutDecision(layout_str);
222222
}
@@ -226,7 +226,7 @@ int LayoutUtils::InferBatchDim(const LayoutDecision& layout) {
226226
return -1;
227227
}
228228
for (size_t i = 0; i < layout->layout.ndim(); i++) {
229-
if (layout->layout[i].name() == "N") {
229+
if (layout->layout[i]->var->name_hint == "N") {
230230
return static_cast<int>(i);
231231
}
232232
}

src/contrib/msc/core/transform/set_expr_layout.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ InferLayoutOutput ForwardInferLayoutBinary(
246246
input_layouts.push_back(LayoutDecision(""));
247247
} else if (t_info->ndim == 1) {
248248
const auto& ref_layout = output->output_layouts[0].LeafValue()->layout;
249-
input_layouts.push_back(LayoutDecision(ref_layout[ref_layout.ndim() - 1].name()));
249+
input_layouts.push_back(LayoutDecision(ref_layout[ref_layout.ndim() - 1]->var->name_hint));
250250
} else {
251251
input_layouts.push_back(output->input_layouts[i]);
252252
}
@@ -361,7 +361,7 @@ InferLayoutOutput ForwardInferLayoutMatmul(
361361
size_t start = a_layout->layout.ndim() - b_shape.size();
362362
ffi::String pre_layout;
363363
for (size_t i = start; i < a_layout->layout.ndim() - 2; i++) {
364-
pre_layout = pre_layout + a_layout->layout[i].name();
364+
pre_layout = pre_layout + a_layout->layout[i]->var->name_hint;
365365
}
366366
LayoutDecision b_layout = LayoutDecision(pre_layout + "IO");
367367
return InferLayoutOutput({a_layout, b_layout}, {a_layout}, Attrs());
@@ -671,7 +671,7 @@ InferLayoutOutput BackwardInferLayoutBinary(
671671
input_layouts.push_back(LayoutDecision(""));
672672
} else if (t_info->ndim == 1) {
673673
const auto& ref_layout = output->output_layouts[0].LeafValue()->layout;
674-
input_layouts.push_back(LayoutDecision(ref_layout[ref_layout.ndim() - 1].name()));
674+
input_layouts.push_back(LayoutDecision(ref_layout[ref_layout.ndim() - 1]->var->name_hint));
675675
} else {
676676
input_layouts.push_back(output->input_layouts[i]);
677677
}
@@ -766,7 +766,7 @@ InferLayoutOutput BackwardInferLayoutMatmul(
766766
size_t start = output_layout->layout.ndim() - b_shape.size();
767767
ffi::String pre_layout;
768768
for (size_t i = start; i < output_layout->layout.ndim() - 2; i++) {
769-
pre_layout = pre_layout + output_layout->layout[i].name();
769+
pre_layout = pre_layout + output_layout->layout[i]->var->name_hint;
770770
}
771771
LayoutDecision b_layout = LayoutDecision(pre_layout + "IO");
772772
return InferLayoutOutput({output_layout, b_layout}, {output_layout}, Attrs());

src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ class TensorRTConvCodeGen : public TensorRTOpCode {
329329
const auto& weight = node()->WeightAt("weight");
330330
std::vector<int64_t> kernel_size;
331331
for (size_t i = 0; i < weight->Ndim(); i++) {
332-
if (weight->layout[i].name() == "I" || weight->layout[i].name() == "O") {
332+
if (weight->layout[i]->var->name_hint == "I" || weight->layout[i]->var->name_hint == "O") {
333333
continue;
334334
}
335335
kernel_size.push_back(weight->DimAt(i)->value);
@@ -442,10 +442,10 @@ class TensorRTPadCodeGen : public TensorRTOpCode {
442442
std::vector<int> pre_padding{2, 0}, post_padding{2, 0};
443443
const auto& input = node()->InputAt(0);
444444
for (size_t i = 0; i < input->Ndim(); i++) {
445-
if (input->layout[i].name() == "H") {
445+
if (input->layout[i]->var->name_hint == "H") {
446446
pre_padding[0] = pad_width[i * 2];
447447
post_padding[0] = pad_width[i * 2 + 1];
448-
} else if (input->layout[i].name() == "W") {
448+
} else if (input->layout[i]->var->name_hint == "W") {
449449
pre_padding[1] = pad_width[i * 2];
450450
post_padding[1] = pad_width[i * 2 + 1];
451451
}

src/contrib/msc/framework/torch/torch_opcode.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ class TorchConvCodeGen : public TorchOpCode {
276276
const auto& weight = node()->WeightAt("weight");
277277
std::vector<int64_t> kernel_size;
278278
for (size_t i = 0; i < weight->Ndim(); i++) {
279-
if (weight->layout[i].name() == "I" || weight->layout[i].name() == "O") {
279+
if (weight->layout[i]->var->name_hint == "I" || weight->layout[i]->var->name_hint == "O") {
280280
continue;
281281
}
282282
kernel_size.push_back(weight->DimAt(i)->value);

src/contrib/msc/framework/tvm/relax_opcode.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ class RelaxConvCodeGen : public RelaxOpCode {
288288
const auto& out_layout = tir::Layout(out_layout_str);
289289
ffi::Array<Integer> expand_shape;
290290
for (size_t i = 0; i < node()->OutputAt(0)->Ndim(); i++) {
291-
if (out_layout[i].name() == "C") {
291+
if (out_layout[i]->var->name_hint == "C") {
292292
expand_shape.push_back(node()->OutputAt(0)->DimAt(i));
293293
} else {
294294
expand_shape.push_back(Integer(1));

src/relax/op/tensor/manipulate.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1968,7 +1968,7 @@ InferLayoutOutput InferLayoutTile(
19681968
// Same dimension: reorder repeats according to layout transformation.
19691969
// If len(repeats) < ndim, it's padded with 1s at the beginning.
19701970
for (int i = 0; i < ndim; ++i) {
1971-
const tir::LayoutAxis& axis = existing_layout_obj[i];
1971+
const tir::LayoutAxis& axis = tir::LayoutAxis::Get(existing_layout_obj[i]);
19721972
int pos_in_initial = initial_layout.IndexOf(axis);
19731973
ICHECK_NE(pos_in_initial, -1) << "Axis not found in initial layout";
19741974
// If len(repeats) < ndim, repeats are right-aligned.
@@ -1990,7 +1990,7 @@ InferLayoutOutput InferLayoutTile(
19901990
}
19911991
// Repeats for existing dimensions need to be permuted.
19921992
for (int i = 0; i < ndim; ++i) {
1993-
const tir::LayoutAxis& axis = existing_layout_obj[i];
1993+
const tir::LayoutAxis& axis = tir::LayoutAxis::Get(existing_layout_obj[i]);
19941994
int pos_in_initial = initial_layout.IndexOf(axis);
19951995
ICHECK_NE(pos_in_initial, -1) << "Axis not found in initial layout";
19961996
new_repeats.push_back(attrs->repeats[pos_in_initial + num_new_dims]);

0 commit comments

Comments
 (0)