Skip to content

Commit f8d808e

Browse files
committed
Replace O(N) recursive sequence_map_inverse with O(1) pack expansion
Use a constexpr loop in find_inverse to locate the permutation's inverse indices, then build the result via pack expansion, giving O(1) template instantiation depth instead of O(N) recursive template instantiation.
1 parent 44f481a commit f8d808e

1 file changed

Lines changed: 35 additions & 20 deletions

File tree

include/ck/utility/sequence.hpp

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -525,31 +525,46 @@ struct is_valid_sequence_map : is_same<typename arithmetic_sequence_gen<0, SeqMa
525525
{
526526
};
527527

528-
template <typename SeqMap>
529-
struct sequence_map_inverse
528+
// Invert a permutation sequence using O(1) template instantiations
529+
// For X2Y = {a, b, c, ...}, computes Y2X where Y2X[X2Y[i]] = i
530+
//
531+
// Uses a single constexpr member function instead of N template instantiations.
532+
// The function is called N times via pack expansion, but only one function exists.
533+
namespace detail {
534+
// TODO: Replace with std::array when HIPRTC supports it
535+
template <typename T, index_t N>
536+
struct ConstexprArray
530537
{
531-
template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
532-
struct sequence_map_inverse_impl
533-
{
534-
static constexpr auto new_y2x =
535-
WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});
538+
T data[N];
536539

537-
using type =
538-
typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::
539-
type;
540-
};
540+
constexpr const T& operator[](index_t i) const { return data[i]; }
541+
};
542+
} // namespace detail
543+
544+
template <index_t... Is>
545+
struct sequence_map_inverse<Sequence<Is...>>
546+
{
547+
private:
548+
static constexpr detail::ConstexprArray<index_t, sizeof...(Is)> values = {{Is...}};
541549

542-
template <typename X2Y, typename WorkingY2X, index_t XBegin>
543-
struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
550+
static constexpr index_t find_inverse(index_t target)
544551
{
545-
using type = WorkingY2X;
546-
};
552+
for(index_t i = 0; i < static_cast<index_t>(sizeof...(Is)); ++i)
553+
{
554+
if(values[i] == target)
555+
return i;
556+
}
557+
return -1; // should not reach for valid permutation
558+
}
547559

548-
using type =
549-
typename sequence_map_inverse_impl<SeqMap,
550-
typename uniform_sequence_gen<SeqMap::Size(), 0>::type,
551-
0,
552-
SeqMap::Size()>::type;
560+
template <index_t... Positions>
561+
static constexpr auto compute(Sequence<Positions...>)
562+
{
563+
return Sequence<find_inverse(Positions)...>{};
564+
}
565+
566+
public:
567+
using type = decltype(compute(make_index_sequence<sizeof...(Is)>{}));
553568
};
554569

555570
template <index_t... Xs, index_t... Ys>

0 commit comments

Comments
 (0)