@@ -107,14 +107,14 @@ def flash_attention_block_masked(
107107 # `l` is initialized to 0 since no blocks have been processed yet and the sum
108108 # is 0.
109109 l = jnp .zeros (
110- (batch_size , num_kv_heads , q_groups , q_seq_len ), dtype = jnp . float32
110+ (batch_size , num_kv_heads , q_groups , q_seq_len ), dtype = data_type
111111 )
112112 # `m` is initialized to the mask_value so that the first block's maximum logit
113113 # correctly becomes the running maximum.
114114 m = jnp .full (
115115 (batch_size , num_kv_heads , q_groups , q_seq_len ),
116116 mask_value ,
117- dtype = jnp . float32 ,
117+ dtype = data_type ,
118118 )
119119
120120 output = jnp .zeros (
@@ -138,11 +138,12 @@ def outer_loop_body(j, carried):
138138 def inner_loop_body (i , carried_inner ):
139139 output , l , m = carried_inner
140140
141+ # Slice the current query block (block_q rows starting at i * block_q) along axis -2 of q.
142+ q_slice = jax .lax .dynamic_slice_in_dim (q , i * block_q , block_q , axis = - 2 )
143+
141144 # Calculates the attention computation (Q@K.T)@V with online softmax for
142145 # the current query and key/value blocks.
143146 def compute_attention_block (output , l , m ):
144- # let's get the slice of Q in N dimension
145- q_slice = jax .lax .dynamic_slice_in_dim (q , i * block_q , block_q , axis = - 2 )
146147 output_i_slice = jax .lax .dynamic_slice_in_dim (
147148 output , i * block_q , block_q , axis = - 2
148149 )
@@ -156,7 +157,7 @@ def compute_attention_block(output, l, m):
156157 "bxhqc,bxkc->bxhqk" ,
157158 q_slice ,
158159 k_j_slice ,
159- preferred_element_type = jnp . float32 ,
160+ preferred_element_type = data_type ,
160161 )
161162 full_mask_i_j_slice = jax .lax .dynamic_slice (
162163 mask_full ,
@@ -193,7 +194,7 @@ def compute_attention_block(output, l, m):
193194
194195 output_i_slice_new = numerator / divider
195196 output = jax .lax .dynamic_update_index_in_dim (
196- output , output_i_slice_new . astype ( data_type ) , i * block_q , axis = - 2
197+ output , output_i_slice_new , i * block_q , axis = - 2
197198 )
198199 l = jax .lax .dynamic_update_index_in_dim (
199200 l , l_i_new , i * block_q , axis = - 1
0 commit comments