Skip to content

Commit 2fc4407

Browse files
author
Abhinav Goel
committed
Fix pyink formatting in attention_test.py
1 parent d7fd385 commit 2fc4407

1 file changed

Lines changed: 51 additions & 71 deletions

File tree

tests/unit/attention_test.py

Lines changed: 51 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -55,62 +55,52 @@ def test_one_block_mask(self):
5555
bidirectional_mask = np.asarray([[0, 1, 1, 1, 0, 0]])
5656
# pylint: disable=protected-access
5757
block_mask = _make_bidirectional_block_mask(bidirectional_mask)
58-
expected_mask = np.asarray(
59-
[
60-
[
61-
[False, False, False, False, False, False],
62-
[False, True, True, True, False, False],
63-
[False, True, True, True, False, False],
64-
[False, True, True, True, False, False],
65-
[False, False, False, False, False, False],
66-
[False, False, False, False, False, False],
67-
]
68-
]
69-
)
58+
expected_mask = np.asarray([[
59+
[False, False, False, False, False, False],
60+
[False, True, True, True, False, False],
61+
[False, True, True, True, False, False],
62+
[False, True, True, True, False, False],
63+
[False, False, False, False, False, False],
64+
[False, False, False, False, False, False],
65+
]])
7066
np.testing.assert_array_equal(block_mask, expected_mask)
7167

7268
def test_two_blocks_mask(self):
7369
bidirectional_mask = np.asarray([[0, 1, 1, 0, 1, 1]])
7470
# pylint: disable=protected-access
7571
block_mask = _make_bidirectional_block_mask(bidirectional_mask)
76-
expected_mask = np.asarray(
77-
[
78-
[
79-
[False, False, False, False, False, False],
80-
[False, True, True, False, False, False],
81-
[False, True, True, False, False, False],
82-
[False, False, False, False, False, False],
83-
[False, False, False, False, True, True],
84-
[False, False, False, False, True, True],
85-
]
86-
]
87-
)
72+
expected_mask = np.asarray([[
73+
[False, False, False, False, False, False],
74+
[False, True, True, False, False, False],
75+
[False, True, True, False, False, False],
76+
[False, False, False, False, False, False],
77+
[False, False, False, False, True, True],
78+
[False, False, False, False, True, True],
79+
]])
8880
np.testing.assert_array_equal(block_mask, expected_mask)
8981

9082
def test_batch_block_masks(self):
9183
bidirectional_mask = np.asarray([[0, 1, 1, 1, 0, 0], [0, 1, 1, 0, 1, 1]])
9284
# pylint: disable=protected-access
9385
block_mask = _make_bidirectional_block_mask(bidirectional_mask)
94-
expected_mask = np.asarray(
86+
expected_mask = np.asarray([
9587
[
96-
[
97-
[False, False, False, False, False, False],
98-
[False, True, True, True, False, False],
99-
[False, True, True, True, False, False],
100-
[False, True, True, True, False, False],
101-
[False, False, False, False, False, False],
102-
[False, False, False, False, False, False],
103-
],
104-
[
105-
[False, False, False, False, False, False],
106-
[False, True, True, False, False, False],
107-
[False, True, True, False, False, False],
108-
[False, False, False, False, False, False],
109-
[False, False, False, False, True, True],
110-
[False, False, False, False, True, True],
111-
],
112-
]
113-
)
88+
[False, False, False, False, False, False],
89+
[False, True, True, True, False, False],
90+
[False, True, True, True, False, False],
91+
[False, True, True, True, False, False],
92+
[False, False, False, False, False, False],
93+
[False, False, False, False, False, False],
94+
],
95+
[
96+
[False, False, False, False, False, False],
97+
[False, True, True, False, False, False],
98+
[False, True, True, False, False, False],
99+
[False, False, False, False, False, False],
100+
[False, False, False, False, True, True],
101+
[False, False, False, False, True, True],
102+
],
103+
])
114104
np.testing.assert_array_equal(block_mask, expected_mask)
115105

116106
def test_empty_block_mask(self):
@@ -140,34 +130,24 @@ def test_combine_with_causal_mask(self):
140130
# pylint: disable=protected-access
141131
image_mask = _make_bidirectional_block_mask(bidirectional_mask)
142132
combined_mask = causal_mask | image_mask[:, None, None, ...]
143-
expected_mask = np.asarray(
144-
[
145-
[
146-
[
147-
[
148-
[True, False, False, False, False, False],
149-
[True, True, True, True, False, False],
150-
[True, True, True, True, False, False],
151-
[True, True, True, True, False, False],
152-
[True, True, True, True, True, False],
153-
[True, True, True, True, True, True],
154-
]
155-
]
156-
],
157-
[
158-
[
159-
[
160-
[True, False, False, False, False, False],
161-
[True, True, True, False, False, False],
162-
[True, True, True, False, False, False],
163-
[True, True, True, True, False, False],
164-
[True, True, True, True, True, True],
165-
[True, True, True, True, True, True],
166-
]
167-
]
168-
],
169-
]
170-
)
133+
expected_mask = np.asarray([
134+
[[[
135+
[True, False, False, False, False, False],
136+
[True, True, True, True, False, False],
137+
[True, True, True, True, False, False],
138+
[True, True, True, True, False, False],
139+
[True, True, True, True, True, False],
140+
[True, True, True, True, True, True],
141+
]]],
142+
[[[
143+
[True, False, False, False, False, False],
144+
[True, True, True, False, False, False],
145+
[True, True, True, False, False, False],
146+
[True, True, True, True, False, False],
147+
[True, True, True, True, True, True],
148+
[True, True, True, True, True, True],
149+
]]],
150+
])
171151
np.testing.assert_array_equal(combined_mask, expected_mask)
172152

173153

0 commit comments

Comments
 (0)