55#include < Formats/NativeWriter.h>
66#include < Formats/TemporaryFileStream.h>
77
8+ #include < Common/logger_useful.h>
9+ #include < Common/thread_local_rng.h>
810#include < Compression/CompressedWriteBuffer.h>
11+ #include < Core/ProtocolDefines.h>
912#include < Disks/IVolume.h>
1013#include < Disks/TemporaryFileOnDisk.h>
1114#include < IO/WriteBufferFromTemporaryFile.h>
12- #include < Common/logger_useful.h>
13- #include < Common/thread_local_rng.h>
1415
16+ #include < base/FnTraits.h>
1517#include < fmt/format.h>
1618
1719namespace DB
@@ -32,7 +34,7 @@ namespace
3234 explicit FileBlockReader (const TemporaryFileOnDisk & file, const Block & header)
3335 : file_reader{file.getDisk ()->readFile (file.getPath ())}
3436 , compressed_reader{*file_reader}
35- , block_reader{compressed_reader, header, 0 }
37+ , block_reader{compressed_reader, header, DBMS_TCP_PROTOCOL_VERSION }
3638 {
3739 }
3840
@@ -83,7 +85,7 @@ namespace
8385 class MergingBlockReader
8486 {
8587 public:
86- explicit MergingBlockReader (FileBlockReaderPtr reader_, size_t desired_block_size = DEFAULT_BLOCK_SIZE)
88+ explicit MergingBlockReader (FileBlockReaderPtr reader_, size_t desired_block_size = DEFAULT_BLOCK_SIZE * 8 )
8789 : reader{std::move (reader_)}, accumulator{desired_block_size}
8890 {
8991 }
@@ -167,7 +169,7 @@ namespace
167169 void reset (const Block & sample)
168170 {
169171 header = sample.cloneEmpty ();
170- output.emplace (compressed_writer, 0 , header);
172+ output.emplace (compressed_writer, DBMS_TCP_PROTOCOL_VERSION , header);
171173 }
172174
173175 Block header;
@@ -180,6 +182,29 @@ namespace
180182 size_t num_blocks = 0 ;
181183 };
182184
185+ std::deque<size_t > generateRandomPermutation (size_t from, size_t to)
186+ {
187+ size_t size = to - from;
188+ std::deque<size_t > indices (size);
189+ std::iota (indices.begin (), indices.end (), from);
190+ std::shuffle (indices.begin (), indices.end (), thread_local_rng);
191+ return indices;
192+ }
193+
194+ // Try to apply @callback in the order specified in @indices
195+ // Until it returns true for each index in the @indices.
196+ void retryForEach (std::deque<size_t > indices, Fn<bool (size_t )> auto callback)
197+ {
198+ while (!indices.empty ())
199+ {
200+ size_t bucket = indices.front ();
201+ indices.pop_front ();
202+
203+ if (!callback (bucket))
204+ indices.push_back (bucket);
205+ }
206+ }
207+
183208}
184209
185210class GraceHashJoin ::FileBucket
@@ -202,30 +227,10 @@ class GraceHashJoin::FileBucket
202227 {
203228 }
204229
205- void addRightBlock (const Block & block)
206- {
207- ensureState (State::WRITING_BLOCKS);
208- right_file.write (block);
209- }
210-
211- bool tryAddLeftBlock (const Block & block)
212- {
213- ensureState (State::WRITING_BLOCKS);
214- std::unique_lock lock{left_file_mutex, std::try_to_lock};
215- if (!lock.owns_lock ())
216- {
217- return false ;
218- }
219- left_file.write (block);
220- return true ;
221- }
222-
223- void addLeftBlock (const Block & block)
224- {
225- ensureState (State::WRITING_BLOCKS);
226- std::unique_lock lock{left_file_mutex};
227- left_file.write (block);
228- }
230+ void addLeftBlock (const Block & block) { return addBlockImpl (block, left_file_mutex, left_file); }
231+ void addRightBlock (const Block & block) { return addBlockImpl (block, right_file_mutex, right_file); }
232+ bool tryAddLeftBlock (const Block & block) { return tryAddBlockImpl (block, left_file_mutex, left_file); }
233+ bool tryAddRightBlock (const Block & block) { return tryAddBlockImpl (block, right_file_mutex, right_file); }
229234
230235 void startJoining ()
231236 {
@@ -253,7 +258,28 @@ class GraceHashJoin::FileBucket
253258
254259 MergingBlockReader openRightTableReader () const { return right_file.makeReader (); }
255260
261+ std::scoped_lock<std::mutex> lockJoin () { return std::scoped_lock{join_mutex}; }
262+
256263private:
264+ bool tryAddBlockImpl (const Block & block, std::mutex & mutex, FileBlockWriter & writer)
265+ {
266+ ensureState (State::WRITING_BLOCKS);
267+ std::unique_lock lock{mutex, std::try_to_lock};
268+ if (!lock.owns_lock ())
269+ {
270+ return false ;
271+ }
272+ writer.write (block);
273+ return true ;
274+ }
275+
276+ void addBlockImpl (const Block & block, std::mutex & mutex, FileBlockWriter & writer)
277+ {
278+ ensureState (State::WRITING_BLOCKS);
279+ std::unique_lock lock{mutex};
280+ writer.write (block);
281+ }
282+
257283 void transition (State expected, State desired)
258284 {
259285 State prev = state.exchange (desired);
@@ -271,10 +297,18 @@ class GraceHashJoin::FileBucket
271297 FileBlockWriter left_file;
272298 FileBlockWriter right_file;
273299 std::mutex left_file_mutex;
300+ std::mutex right_file_mutex;
301+ std::mutex join_mutex; // / Protects external in-memory join
274302 const FileBucket * parent;
275303 std::atomic<State> state;
276304};
277305
306+ class GraceHashJoin ::InMemoryJoin : public HashJoin
307+ {
308+ public:
309+ using HashJoin::HashJoin;
310+ };
311+
278312GraceHashJoin::GraceHashJoin (
279313 ContextPtr context_, std::shared_ptr<TableJoin> table_join_, const Block & right_sample_block_, bool any_take_last_row_)
280314 : log{&Poco::Logger::get (" GraceHashJoin" )}
@@ -296,6 +330,8 @@ GraceHashJoin::GraceHashJoin(
296330 LOG_TRACE (log, " Initialize {} buckets" , initial_num_buckets);
297331}
298332
333+ GraceHashJoin::~GraceHashJoin () = default ;
334+
299335bool GraceHashJoin::addJoinedBlock (const Block & block, bool /* check_limits*/ )
300336{
301337 Block materialized = materializeBlock (block);
@@ -344,7 +380,8 @@ GraceHashJoin::BucketsSnapshot GraceHashJoin::rehash(size_t desired_size)
344380
345381 if (next_size > max_num_buckets)
346382 {
347- throw Exception (ErrorCodes::LIMIT_EXCEEDED, " Too many grace hash join buckets, consider increasing max_rows_in_join/max_bytes_in_join" );
383+ throw Exception (
384+ ErrorCodes::LIMIT_EXCEEDED, " Too many grace hash join buckets, consider increasing max_rows_in_join/max_bytes_in_join" );
348385 }
349386
350387 next_snapshot->reserve (next_size);
@@ -402,24 +439,17 @@ void GraceHashJoin::joinBlock(Block & block, std::shared_ptr<ExtraBlock> & /*not
402439 if (not_processed)
403440 throw Exception (ErrorCodes::LOGICAL_ERROR, " Unsupported hash join type" );
404441
405- std::deque<size_t > indices (snapshot->size () - 1 );
406- std::iota (indices.begin (), indices.end (), 1 );
407- std::shuffle (indices.begin (), indices.end (), thread_local_rng);
408- while (!indices.empty ())
409- {
410- size_t bucket = indices.front ();
411- indices.pop_front ();
412-
413- Block & block_shard = blocks[bucket];
414- if (block_shard.rows () == 0 )
415- {
416- continue ;
417- }
418- if (!snapshot->at (bucket)->tryAddLeftBlock (block_shard))
442+ // We need to skip the first bucket that is already joined in memory, so we start with 1.
443+ auto indices = generateRandomPermutation (1 , snapshot->size ());
444+ retryForEach (
445+ indices,
446+ [&](size_t bucket)
419447 {
420- indices.push_back (bucket);
421- }
422- }
448+ Block & block_shard = blocks[bucket];
449+ if (block_shard.rows () == 0 )
450+ return true ;
451+ return snapshot->at (bucket)->tryAddLeftBlock (block_shard);
452+ });
423453}
424454
425455size_t GraceHashJoin::getTotalRowCount () const
@@ -541,9 +571,9 @@ Block GraceHashJoin::joinNextBlockInBucket(DelayedBlocks & iterator)
541571 return block;
542572}
543573
544- std::unique_ptr<HashJoin > GraceHashJoin::makeInMemoryJoin ()
574+ std::unique_ptr<GraceHashJoin::InMemoryJoin > GraceHashJoin::makeInMemoryJoin ()
545575{
546- return std::make_unique<HashJoin >(table_join, right_sample_block, any_take_last_row);
576+ return std::make_unique<InMemoryJoin >(table_join, right_sample_block, any_take_last_row);
547577}
548578
549579void GraceHashJoin::fillInMemoryJoin (InMemoryJoinPtr & join, FileBucket * bucket)
@@ -562,31 +592,44 @@ void GraceHashJoin::addJoinedBlockImpl(InMemoryJoinPtr & join, size_t bucket_ind
562592 BucketsSnapshot snapshot = buckets.get ();
563593 Blocks blocks = scatterBlock<true >(block, snapshot->size ());
564594
565- join->addJoinedBlock (blocks[bucket_index], /* check_limits=*/ false );
566-
567- // We need to rebuild block without bucket_index part in case of overflow.
568- bool overflow = !fitsInMemory (join.get ());
569- Block to_write;
570- if (overflow)
595+ // Add block to the in-memory join
571596 {
572- blocks.erase (blocks.begin () + bucket_index);
573- to_write = concatenateBlocks (blocks);
574- }
597+ auto guard = snapshot->at (bucket_index)->lockJoin ();
598+ join->addJoinedBlock (blocks[bucket_index], /* check_limits=*/ false );
575599
576- while (overflow)
577- {
578- snapshot = rehash (snapshot->size () * 2 );
579- rehashInMemoryJoin (join, snapshot, bucket_index);
580- blocks = scatterBlock<true >(to_write, snapshot->size ());
581- overflow = !fitsInMemory (join.get ());
582- }
600+ // We need to rebuild block without bucket_index part in case of overflow.
601+ bool overflow = !fitsInMemory (join.get ());
602+ Block to_write;
603+ if (overflow)
604+ {
605+ blocks.erase (blocks.begin () + bucket_index);
606+ to_write = concatenateBlocks (blocks);
607+ }
583608
584- assert (blocks.empty () || blocks.size () == snapshot->size ());
585- for (size_t i = 1 ; i < blocks.size (); ++i)
586- {
587- if (i != bucket_index && blocks[i].rows ())
588- snapshot->at (i)->addRightBlock (blocks[i]);
609+ while (overflow)
610+ {
611+ snapshot = rehash (snapshot->size () * 2 );
612+ rehashInMemoryJoin (join, snapshot, bucket_index);
613+ blocks = scatterBlock<true >(to_write, snapshot->size ());
614+ overflow = !fitsInMemory (join.get ());
615+ }
589616 }
617+
618+ if (blocks.empty ())
619+ // All blocks were added to the @join
620+ return ;
621+
622+ // Write the rest of the blocks to the disk buckets
623+ assert (blocks.size () == snapshot->size ());
624+ auto indices = generateRandomPermutation (1 , snapshot->size ());
625+ retryForEach (
626+ indices,
627+ [&](size_t bucket)
628+ {
629+ if (bucket == bucket_index || !blocks[bucket].rows ())
630+ return true ;
631+ return snapshot->at (bucket)->tryAddRightBlock (blocks[bucket]);
632+ });
590633}
591634
592635template <bool right>
0 commit comments