Skip to content

Commit e6040e1

Browse files
committed
Apply same optimization pattern to TensorAdaptor
TensorAdaptor has identical InitializeElementSize and GetTransformAndItsUpperDimension patterns as TensorDescriptor. Apply the same optimization: - Replace nested static_for lambdas with find_in_tuple_of_sequences - Replace generate_tuple lambda with pack expansion Results: generate_tuple lambdas 100 -> 96 (4 events, 17ms eliminated)
1 parent bbf5c5e commit e6040e1

1 file changed

Lines changed: 23 additions & 35 deletions

File tree

include/ck/tensor_description/tensor_adaptor.hpp

Lines changed: 23 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -45,27 +45,29 @@ struct TensorAdaptor
4545
return BottomDimensionHiddenIds{};
4646
}
4747

48-
__host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms)
48+
// Helper to get length of a top dimension from transforms
49+
template <index_t I>
50+
__host__ __device__ static constexpr auto
51+
GetTopDimLengthFromTransforms(const Transforms& transforms)
4952
{
50-
const auto lengths = generate_tuple(
51-
[&](auto idim_top) {
52-
constexpr auto tmp = GetTransformAndItsUpperDimension(idim_top);
53-
54-
constexpr index_t itran = tmp[Number<0>{}];
55-
constexpr index_t idim_up = tmp[Number<1>{}];
56-
constexpr bool found = tmp[Number<2>{}];
57-
58-
static_assert(found == true,
59-
"wrong! not found matching transformation and upper-dimension");
60-
61-
const auto length =
62-
transforms[Number<itran>{}].GetUpperLengths()[Number<idim_up>{}];
53+
constexpr auto result = find_in_tuple_of_sequences<TopDimensionHiddenIds::At(Number<I>{})>(
54+
UpperDimensionHiddenIdss{});
55+
static_assert(result.found, "wrong! not found matching transformation and upper-dimension");
56+
return transforms[Number<result.itran>{}].GetUpperLengths()[Number<result.idim_up>{}];
57+
}
6358

64-
return length;
65-
},
66-
Number<ndim_top_>{});
59+
// Compute element size using pack expansion instead of generate_tuple with lambda
60+
template <index_t... Is>
61+
__host__ __device__ static constexpr auto ComputeElementSizeImpl(const Transforms& transforms,
62+
Sequence<Is...>)
63+
{
64+
return (GetTopDimLengthFromTransforms<Is>(transforms) * ...);
65+
}
6766

68-
return container_product(lengths);
67+
__host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms)
68+
{
69+
return ComputeElementSizeImpl(transforms,
70+
typename arithmetic_sequence_gen<0, ndim_top_, 1>::type{});
6971
}
7072

7173
template <index_t IDim>
@@ -75,24 +77,10 @@ struct TensorAdaptor
7577

7678
constexpr index_t idim_hidden = TopDimensionHiddenIds::At(idim_top);
7779

78-
index_t itran_found = 0;
79-
index_t idim_up_found = 0;
80-
bool found = false;
81-
82-
static_for<0, ntransform_, 1>{}([&](auto itran) {
83-
constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran];
84-
85-
static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) {
86-
if constexpr(up_dim_ids[idim_up] == idim_hidden)
87-
{
88-
itran_found = itran;
89-
idim_up_found = idim_up;
90-
found = true;
91-
}
92-
});
93-
});
80+
// Use compile-time search helper instead of nested static_for with lambdas
81+
constexpr auto result = find_in_tuple_of_sequences<idim_hidden>(UpperDimensionHiddenIdss{});
9482

95-
return make_tuple(itran_found, idim_up_found, found);
83+
return make_tuple(result.itran, result.idim_up, result.found);
9684
}
9785

9886
__host__ __device__ static constexpr index_t GetNumOfBottomDimension()

0 commit comments

Comments
 (0)