@@ -113,6 +113,7 @@ class CodeGen_X86 : public CodeGen_Posix {
     void codegen_vector_reduce(const VectorReduce *, const Expr &init) override;
     // @}

+    std::vector<llvm::Value *> deinterleave_vector(llvm::Value *, int) override;
     llvm::Value *interleave_vectors(const std::vector<llvm::Value *> &) override;

 private:
@@ -910,6 +911,30 @@ void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init
     CodeGen_Posix::codegen_vector_reduce(op, init);
 }

+std::vector<Value *> CodeGen_X86::deinterleave_vector(Value *vec, int num_vecs) {
+    int vec_elements = get_vector_num_elements(vec->getType()) / num_vecs;
+    const size_t element_bits = vec->getType()->getScalarSizeInBits();
+    if (target.has_feature(Target::AVX) &&
+        is_power_of_two(num_vecs) &&
+        is_power_of_two(vec_elements) &&
+        (int)(vec_elements * num_vecs * element_bits) > native_vector_bits()) {
+
+        // Our interleaving logic below supports this case
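+        // Deinterleaving is a transpose viewed the other way: e.g. with
+        // num_vecs == 2 and vec_elements == 4, the input a0 b0 a1 b1 a2 b2 a3 b3
+        // is sliced into the pairs {a0 b0} {a1 b1} {a2 b2} {a3 b3}; interleaving
+        // those pairs yields a0 a1 a2 a3 b0 b1 b2 b3, and the deinterleaved
+        // results are then contiguous slices of that.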
+        std::vector<Value *> slices(vec_elements);
+        for (int i = 0; i < vec_elements; i++) {
+            slices[i] = slice_vector(vec, i * num_vecs, num_vecs);
+        }
+        vec = interleave_vectors(slices);
+        std::vector<Value *> result(num_vecs);
+        for (int i = 0; i < num_vecs; i++) {
+            result[i] = slice_vector(vec, i * vec_elements, vec_elements);
+        }
+        return result;
+    } else {
+        return CodeGen_Posix::deinterleave_vector(vec, num_vecs);
+    }
+}
+
 Value *CodeGen_X86::interleave_vectors(const std::vector<Value *> &vecs) {
     // Only use x86-specific interleaving for AVX and above
     if (vecs.empty() || !target.has_feature(Target::AVX)) {
@@ -1146,6 +1171,24 @@ Value *CodeGen_X86::interleave_vectors(const std::vector<Value *> &vecs) {

     // Now we define helpers for each instruction we are going to use

+    // Useful for debugging or enhancing this algorithm
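+    // (prints the current l, s, and v bit lists separated by "|")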
+    /*
+    auto dump_bits = [&]() {
+        for (int b : l_bits) {
+            debug(0) << b << " ";
+        }
+        debug(0) << "| ";
+        for (int b : s_bits) {
+            debug(0) << b << " ";
+        }
+        debug(0) << "| ";
+        for (int b : v_bits) {
+            debug(0) << b << " ";
+        }
+        debug(0) << "\n";
+    };
+    */
+
     // unpckl/h instruction
     auto unpck = [&](Value *a, Value *b) -> std::pair<Value *, Value *> {
         int n = get_vector_num_elements(a->getType());
@@ -1258,6 +1301,99 @@ Value *CodeGen_X86::interleave_vectors(const std::vector<Value *> &vecs) {
         s_bits.pop_back();
     }

+    // If adjacent vectors are shuffles of the same underlying vector(s),
+    // concatenate pairs, because this is probably free.
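+    // (Two shuffles of the same two operands collapse into a single wider
+    // shufflevector of those operands, so no new data movement is introduced.)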
+    while ((size_t)vec_elements < elems_per_native_vec && !v_bits.empty()) {
+        std::vector<Value *> new_v;
+        new_v.reserve(v.size() / 2);
+        bool fail = false;
+        std::vector<int> indices;
+        indices.reserve(vec_elements * 2);
+        for (size_t i = 0; i < v.size(); i += 2) {
+            ShuffleVectorInst *a = llvm::dyn_cast<ShuffleVectorInst>(v[i]);
+            ShuffleVectorInst *b = llvm::dyn_cast<ShuffleVectorInst>(v[i + 1]);
+            if (a &&
+                b &&
+                a->getOperand(0) == b->getOperand(0) &&
+                a->getOperand(1) == b->getOperand(1)) {
+
+                // Concatenate the two shuffles
+                indices.clear();
+                for (int j : a->getShuffleMask()) {
+                    indices.push_back(j);
+                }
+                for (int j : b->getShuffleMask()) {
+                    indices.push_back(j);
+                }
+                new_v.push_back(shuffle_vectors(a->getOperand(0), a->getOperand(1), indices));
+            } else {
+                fail = true;
+            }
+        }
+        if (fail) {
+            break;
+        }
+
+        v.swap(new_v);
+        // The lowest vector bit becomes the highest lane or slice bit
+        if ((size_t)vec_elements < elems_per_slice) {
+            l_bits.push_back(v_bits[0]);
+        } else {
+            s_bits.push_back(v_bits[0]);
+        }
+        v_bits.erase(v_bits.begin());
+        vec_elements *= 2;
+    }
+
+    if (final_num_s_bits > 1 &&
+        (size_t)vec_elements == elems_per_native_vec &&
+        (size_t)v_bits[0] >= l_bits.size() - 1) {
+        // A big binary shuffle of adjacent pairs will fix the l bits
+        // entirely. AVX-512 has these. Yes, this will use registers for the
+        // shuffle indices, but the alternative requires very many unpck
+        // operations to completely cycle out the v_bits that are hiding in the
+        // bottom of the l_bits.
+
+        std::vector<int> lo_indices(vec_elements);
+        std::vector<int> hi_indices(vec_elements);
+        std::vector<int> sorted_bits = l_bits;
+        sorted_bits.insert(sorted_bits.end(), s_bits.begin(), s_bits.end());
+        sorted_bits.push_back(v_bits[0]);
+        std::sort(sorted_bits.begin(), sorted_bits.end());
+        std::vector<int> idx_of_bit(l_bits.size() + s_bits.size() + v_bits.size(), 0);
+        for (size_t b = 0; b < sorted_bits.size(); b++) {
+            idx_of_bit[sorted_bits[b]] = b;
+        }
+
+        for (size_t dst_idx = 0; dst_idx < (size_t)vec_elements * 2; dst_idx++) {
+            size_t src_idx = 0;
+            for (size_t b = 0; b < l_bits.size(); b++) {
+                src_idx |= ((dst_idx >> idx_of_bit[l_bits[b]]) & 1) << b;
+            }
+            for (size_t b = 0; b < s_bits.size(); b++) {
+                src_idx |= ((dst_idx >> idx_of_bit[s_bits[b]]) & 1) << (b + l_bits.size());
+            }
+            src_idx |= ((dst_idx >> idx_of_bit[v_bits[0]]) & 1) << (l_bits.size() + s_bits.size());
+            if (dst_idx < (size_t)vec_elements) {
+                lo_indices[dst_idx] = (int)src_idx;
+            } else {
+                hi_indices[dst_idx - vec_elements] = (int)src_idx;
+            }
+        }
+
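+        // The shuffle below rearranges the tracked bits into sorted order: the
+        // l bits end up in the lowest of the tracked positions, the s bits in
+        // the next ones, and the chosen v bit in the highest. The bookkeeping
+        // after the for_all_pairs call records exactly this arrangement.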
+        for_all_pairs(0, [&](auto *a, auto *b) {
+            Value *lo = shuffle_vectors(*a, *b, lo_indices);
+            Value *hi = shuffle_vectors(*a, *b, hi_indices);
+            *a = lo;
+            *b = hi;
+        });
+
+        auto first_s_bit = sorted_bits.begin() + l_bits.size();
+        std::copy(sorted_bits.begin(), first_s_bit, l_bits.begin());
+        std::copy(first_s_bit, first_s_bit + s_bits.size(), s_bits.begin());
+        v_bits[0] = sorted_bits.back();
+    }
+
     // Interleave pairs if we have vectors smaller than a single slice. Choosing
     // which pairs to interleave is important because we want to pull down v
     // bits that are destined to end up as l bits, and we want to pull them down
@@ -1300,9 +1436,8 @@ Value *CodeGen_X86::interleave_vectors(const std::vector<Value *> &vecs) {

     // Concatenate/repack to get at least the desired number of slice bits.
     while ((int)s_bits.size() < final_num_s_bits && !v_bits.empty()) {
-        int desired_low_slice_bit = ctz64(elems_per_slice);
-        int desired_high_slice_bit = desired_low_slice_bit + 1;
-
+        const int desired_low_slice_bit = ctz64(elems_per_slice);
+        const int desired_high_slice_bit = desired_low_slice_bit + 1;
         int bit;
         if (!s_bits.empty() &&
             s_bits[0] == desired_low_slice_bit) {
@@ -1340,37 +1475,44 @@ Value *CodeGen_X86::interleave_vectors(const std::vector<Value *> &vecs) {
     // Now we have at least two whole vectors. Next we try to finalize lane bits using
     // unpck instructions.
     while (l_bits[0] != 0) {
-        int bit = std::min(l_bits[0], (int)ctz64(elems_per_slice)) - 1;
+
+        int first_s_bit = (int)ctz64(elems_per_slice);
+        int bit = std::min(l_bits[0], first_s_bit) - 1;

         auto vb_it = std::find(v_bits.begin(), v_bits.end(), bit);

         // internal_assert(vb_it != v_bits.end());
         if (vb_it == v_bits.end()) {
             // The next bit is not in vector bits. It must be hiding in the
             // slice bits due to earlier concatenation. Move it into the v_bits
-            // with a shufi
+            // with a shufi. We'll need to pick a v bit to take its place,
+            // ideally one destined to end up in the s bits.
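+            // (a v bit whose value is >= first_s_bit belongs in a slice
+            // position in the final layout, so parking it in the s bits now is
+            // not wasted work).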
+            vb_it = std::find_if(v_bits.begin(), v_bits.end(), [&](int b) { return b >= first_s_bit; });
+            if (vb_it == v_bits.end()) {
+                vb_it = v_bits.begin();
+            }
+
             if (s_bits.back() == bit) {
                 // It's the last (or sole) slice bit. Swap it with the first v bit
-                std::swap(s_bits.back(), v_bits[0]);
-                for_all_pairs(0, [&](auto *a, auto *b) {
+                std::swap(s_bits.back(), *vb_it);
+                for_all_pairs(vb_it - v_bits.begin(), [&](auto *a, auto *b) {
                     auto [lo, hi] = shufi(*a, *b, false);
                     *a = lo;
                     *b = hi;
                 });
             } else {
                 internal_assert(s_bits.size() == 2 && s_bits[0] == bit);
                 // It's the low slice bit. We need shufi with crossover.
-                int v_bit = v_bits[0];
-                v_bits[0] = s_bits[0];
+                int v_bit = *vb_it;
+                *vb_it = s_bits[0];
                 s_bits[0] = s_bits[1];
                 s_bits[1] = v_bit;
-                for_all_pairs(0, [&](auto *a, auto *b) {
+                for_all_pairs(vb_it - v_bits.begin(), [&](auto *a, auto *b) {
                     auto [lo, hi] = shufi(*a, *b, true);
                     *a = lo;
                     *b = hi;
                 });
             }
-            vb_it = v_bits.begin();
         }

         int j = vb_it - v_bits.begin();