@@ -45,27 +45,29 @@ struct TensorAdaptor
4545 return BottomDimensionHiddenIds{};
4646 }
4747
48- __host__ __device__ static constexpr auto InitializeElementSize (const Transforms& transforms)
48+ // Helper to get length of a top dimension from transforms
49+ template <index_t I>
50+ __host__ __device__ static constexpr auto
51+ GetTopDimLengthFromTransforms (const Transforms& transforms)
4952 {
50- const auto lengths = generate_tuple (
51- [&](auto idim_top) {
52- constexpr auto tmp = GetTransformAndItsUpperDimension (idim_top);
53-
54- constexpr index_t itran = tmp[Number<0 >{}];
55- constexpr index_t idim_up = tmp[Number<1 >{}];
56- constexpr bool found = tmp[Number<2 >{}];
57-
58- static_assert (found == true ,
59- " wrong! not found matching transformation and upper-dimension" );
60-
61- const auto length =
62- transforms[Number<itran>{}].GetUpperLengths ()[Number<idim_up>{}];
53+ constexpr auto result = find_in_tuple_of_sequences<TopDimensionHiddenIds::At (Number<I>{})>(
54+ UpperDimensionHiddenIdss{});
55+ static_assert (result.found , " wrong! not found matching transformation and upper-dimension" );
56+ return transforms[Number<result.itran >{}].GetUpperLengths ()[Number<result.idim_up >{}];
57+ }
6358
64- return length;
65- },
66- Number<ndim_top_>{});
59+ // Compute element size using pack expansion instead of generate_tuple with lambda
60+ template <index_t ... Is>
61+ __host__ __device__ static constexpr auto ComputeElementSizeImpl (const Transforms& transforms,
62+ Sequence<Is...>)
63+ {
64+ return (GetTopDimLengthFromTransforms<Is>(transforms) * ...);
65+ }
6766
68- return container_product (lengths);
67+ __host__ __device__ static constexpr auto InitializeElementSize (const Transforms& transforms)
68+ {
69+ return ComputeElementSizeImpl (transforms,
70+ typename arithmetic_sequence_gen<0 , ndim_top_, 1 >::type{});
6971 }
7072
7173 template <index_t IDim>
@@ -75,24 +77,10 @@ struct TensorAdaptor
7577
7678 constexpr index_t idim_hidden = TopDimensionHiddenIds::At (idim_top);
7779
78- index_t itran_found = 0 ;
79- index_t idim_up_found = 0 ;
80- bool found = false ;
81-
82- static_for<0 , ntransform_, 1 >{}([&](auto itran) {
83- constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran];
84-
85- static_for<0 , up_dim_ids.Size (), 1 >{}([&](auto idim_up) {
86- if constexpr (up_dim_ids[idim_up] == idim_hidden)
87- {
88- itran_found = itran;
89- idim_up_found = idim_up;
90- found = true ;
91- }
92- });
93- });
80+ // Use compile-time search helper instead of nested static_for with lambdas
81+ constexpr auto result = find_in_tuple_of_sequences<idim_hidden>(UpperDimensionHiddenIdss{});
9482
95- return make_tuple (itran_found, idim_up_found, found);
83+ return make_tuple (result. itran , result. idim_up , result. found );
9684 }
9785
9886 __host__ __device__ static constexpr index_t GetNumOfBottomDimension ()
0 commit comments