@@ -55,62 +55,52 @@ def test_one_block_mask(self):
5555 bidirectional_mask = np .asarray ([[0 , 1 , 1 , 1 , 0 , 0 ]])
5656 # pylint: disable=protected-access
5757 block_mask = _make_bidirectional_block_mask (bidirectional_mask )
58- expected_mask = np .asarray (
59- [
60- [
61- [False , False , False , False , False , False ],
62- [False , True , True , True , False , False ],
63- [False , True , True , True , False , False ],
64- [False , True , True , True , False , False ],
65- [False , False , False , False , False , False ],
66- [False , False , False , False , False , False ],
67- ]
68- ]
69- )
58+ expected_mask = np .asarray ([[
59+ [False , False , False , False , False , False ],
60+ [False , True , True , True , False , False ],
61+ [False , True , True , True , False , False ],
62+ [False , True , True , True , False , False ],
63+ [False , False , False , False , False , False ],
64+ [False , False , False , False , False , False ],
65+ ]])
7066 np .testing .assert_array_equal (block_mask , expected_mask )
7167
7268 def test_two_blocks_mask (self ):
7369 bidirectional_mask = np .asarray ([[0 , 1 , 1 , 0 , 1 , 1 ]])
7470 # pylint: disable=protected-access
7571 block_mask = _make_bidirectional_block_mask (bidirectional_mask )
76- expected_mask = np .asarray (
77- [
78- [
79- [False , False , False , False , False , False ],
80- [False , True , True , False , False , False ],
81- [False , True , True , False , False , False ],
82- [False , False , False , False , False , False ],
83- [False , False , False , False , True , True ],
84- [False , False , False , False , True , True ],
85- ]
86- ]
87- )
72+ expected_mask = np .asarray ([[
73+ [False , False , False , False , False , False ],
74+ [False , True , True , False , False , False ],
75+ [False , True , True , False , False , False ],
76+ [False , False , False , False , False , False ],
77+ [False , False , False , False , True , True ],
78+ [False , False , False , False , True , True ],
79+ ]])
8880 np .testing .assert_array_equal (block_mask , expected_mask )
8981
9082 def test_batch_block_masks (self ):
9183 bidirectional_mask = np .asarray ([[0 , 1 , 1 , 1 , 0 , 0 ], [0 , 1 , 1 , 0 , 1 , 1 ]])
9284 # pylint: disable=protected-access
9385 block_mask = _make_bidirectional_block_mask (bidirectional_mask )
94- expected_mask = np .asarray (
86+ expected_mask = np .asarray ([
9587 [
96- [
97- [False , False , False , False , False , False ],
98- [False , True , True , True , False , False ],
99- [False , True , True , True , False , False ],
100- [False , True , True , True , False , False ],
101- [False , False , False , False , False , False ],
102- [False , False , False , False , False , False ],
103- ],
104- [
105- [False , False , False , False , False , False ],
106- [False , True , True , False , False , False ],
107- [False , True , True , False , False , False ],
108- [False , False , False , False , False , False ],
109- [False , False , False , False , True , True ],
110- [False , False , False , False , True , True ],
111- ],
112- ]
113- )
88+ [False , False , False , False , False , False ],
89+ [False , True , True , True , False , False ],
90+ [False , True , True , True , False , False ],
91+ [False , True , True , True , False , False ],
92+ [False , False , False , False , False , False ],
93+ [False , False , False , False , False , False ],
94+ ],
95+ [
96+ [False , False , False , False , False , False ],
97+ [False , True , True , False , False , False ],
98+ [False , True , True , False , False , False ],
99+ [False , False , False , False , False , False ],
100+ [False , False , False , False , True , True ],
101+ [False , False , False , False , True , True ],
102+ ],
103+ ])
114104 np .testing .assert_array_equal (block_mask , expected_mask )
115105
116106 def test_empty_block_mask (self ):
@@ -140,34 +130,24 @@ def test_combine_with_causal_mask(self):
140130 # pylint: disable=protected-access
141131 image_mask = _make_bidirectional_block_mask (bidirectional_mask )
142132 combined_mask = causal_mask | image_mask [:, None , None , ...]
143- expected_mask = np .asarray (
144- [
145- [
146- [
147- [
148- [True , False , False , False , False , False ],
149- [True , True , True , True , False , False ],
150- [True , True , True , True , False , False ],
151- [True , True , True , True , False , False ],
152- [True , True , True , True , True , False ],
153- [True , True , True , True , True , True ],
154- ]
155- ]
156- ],
157- [
158- [
159- [
160- [True , False , False , False , False , False ],
161- [True , True , True , False , False , False ],
162- [True , True , True , False , False , False ],
163- [True , True , True , True , False , False ],
164- [True , True , True , True , True , True ],
165- [True , True , True , True , True , True ],
166- ]
167- ]
168- ],
169- ]
170- )
133+ expected_mask = np .asarray ([
134+ [[[
135+ [True , False , False , False , False , False ],
136+ [True , True , True , True , False , False ],
137+ [True , True , True , True , False , False ],
138+ [True , True , True , True , False , False ],
139+ [True , True , True , True , True , False ],
140+ [True , True , True , True , True , True ],
141+ ]]],
142+ [[[
143+ [True , False , False , False , False , False ],
144+ [True , True , True , False , False , False ],
145+ [True , True , True , False , False , False ],
146+ [True , True , True , True , False , False ],
147+ [True , True , True , True , True , True ],
148+ [True , True , True , True , True , True ],
149+ ]]],
150+ ])
171151 np .testing .assert_array_equal (combined_mask , expected_mask )
172152
173153
0 commit comments