@@ -213,7 +213,15 @@ def twice(x):
213213 allocator = alloc , use_memory_pool = use_memory_pool )
214214
215215 from pyopencl .tools import ImmediateAllocator , MemoryPool
216- assert isinstance (actx .allocator ,
216+
217+ from arraycontext .impl .pytato import _PaddedAllocator
218+ alloc_to_check = actx .allocator
219+ if isinstance (alloc_to_check , _PaddedAllocator ):
220+ # On the Intel CPU runtime the actx wraps its allocator to pad
221+ # buffers (working around an out-of-bounds runtime store); check
222+ # the wrapped allocator's type.
223+ alloc_to_check = alloc_to_check ._allocator
224+ assert isinstance (alloc_to_check ,
217225 MemoryPool if use_memory_pool else ImmediateAllocator )
218226
219227 f = actx .compile (twice )
@@ -397,6 +405,26 @@ def twice(x):
397405 actx2 ._enable_profiling (True )
398406
399407
408+ def _auto_test_vs_ref (
409+ ref_t_unit : lp .TranslationUnit , cl_ctx : cl .Context ,
410+ t_unit : lp .TranslationUnit ):
411+ from pyopencl .tools import ImmediateAllocator
412+
413+ queue = cl .CommandQueue (cl_ctx )
414+ allocator = ImmediateAllocator (queue )
415+
416+ # The Intel CPU OpenCL runtime writes out of bounds past kernel output
417+ # buffers when executing partial work-groups, corrupting the host heap.
418+ # auto_test_vs_ref allocates its own buffers, so on that runtime pad them
419+ # (via _PaddedAllocator) so the stray stores land in valid memory.
420+ dev = cl_ctx .devices [0 ]
421+ if dev .type & cl .device_type .CPU and "intel" in dev .platform .name .lower ():
422+ from arraycontext .impl .pytato import _PaddedAllocator
423+ allocator = _PaddedAllocator (allocator )
424+
425+ lp .auto_test_vs_ref (ref_t_unit , cl_ctx , t_unit , allocator = allocator )
426+
427+
400428def test_parallelize_disjoint_loop_sets_scalar ():
401429 from loopy .kernel .data import GroupInameTag , LocalInameTag
402430
@@ -484,7 +512,7 @@ def test_parallelize_disjoint_loop_sets_single_non_redn_iname():
484512 == {GroupInameTag (0 )}
485513 assert knl .iname_tags_of_type ("k" , (GroupInameTag , LocalInameTag )) == set ()
486514
487- lp . auto_test_vs_ref (ref_t_unit , cl_ctx , t_unit )
515+ _auto_test_vs_ref (ref_t_unit , cl_ctx , t_unit )
488516
489517
490518def test_parallelize_disjoint_loop_sets_multiple_non_redn_inames ():
@@ -524,7 +552,7 @@ def test_parallelize_disjoint_loop_sets_multiple_non_redn_inames():
524552 == {LocalInameTag (0 )}
525553 assert knl .iname_tags_of_type ("k" , (GroupInameTag , LocalInameTag )) == set ()
526554
527- lp . auto_test_vs_ref (ref_t_unit , cl_ctx , t_unit )
555+ _auto_test_vs_ref (ref_t_unit , cl_ctx , t_unit )
528556
529557
530558def test_parallelize_disjoint_loop_sets_only_redn_iname ():
@@ -563,7 +591,7 @@ def test_parallelize_disjoint_loop_sets_only_redn_iname():
563591 == {GroupInameTag (0 )}
564592 assert knl .iname_tags_of_type ("k" , (GroupInameTag , LocalInameTag )) == set ()
565593
566- lp . auto_test_vs_ref (ref_t_unit , cl_ctx , t_unit )
594+ _auto_test_vs_ref (ref_t_unit , cl_ctx , t_unit )
567595
568596
569597def test_parallelize_disjoint_loop_sets_mixed ():
@@ -601,7 +629,7 @@ def test_parallelize_disjoint_loop_sets_mixed():
601629 == {LocalInameTag (0 )}
602630 assert knl .iname_tags_of_type ("k" , (GroupInameTag , LocalInameTag )) == set ()
603631
604- lp . auto_test_vs_ref (ref_t_unit , cl_ctx , t_unit )
632+ _auto_test_vs_ref (ref_t_unit , cl_ctx , t_unit )
605633
606634
607635def test_parallelize_disjoint_loop_sets_multiple_independent_loop_sets ():
@@ -665,7 +693,7 @@ def test_parallelize_disjoint_loop_sets_multiple_independent_loop_sets():
665693 and insn .synchronization_kind == "global" ]
666694 assert len (gbarriers ) == 1
667695
668- lp . auto_test_vs_ref (ref_t_unit , cl_ctx , t_unit )
696+ _auto_test_vs_ref (ref_t_unit , cl_ctx , t_unit )
669697
670698
671699def test_parallelize_disjoint_loop_sets_multiple_dependent_loop_sets ():
@@ -733,7 +761,7 @@ def test_parallelize_disjoint_loop_sets_multiple_dependent_loop_sets():
733761 assert gbarrier .id in knl .id_to_insn ["loopset2insn1" ].depends_on
734762 assert gbarrier .id in knl .id_to_insn ["loopset2insn2" ].depends_on
735763
736- lp . auto_test_vs_ref (ref_t_unit , cl_ctx , t_unit )
764+ _auto_test_vs_ref (ref_t_unit , cl_ctx , t_unit )
737765
738766
739767def test_alias_global_temporaries ():
@@ -789,7 +817,7 @@ def global_temp(name: str):
789817 assert base_storages ["tmp2" ] != base_storages ["tmp1" ]
790818 assert len (set (base_storages .values ())) == 2
791819
792- lp . auto_test_vs_ref (ref_t_unit , cl_ctx , t_unit )
820+ _auto_test_vs_ref (ref_t_unit , cl_ctx , t_unit )
793821
794822
795823if __name__ == "__main__" :
0 commit comments