@@ -167,6 +167,7 @@ def forward(
             num_activated_expert_per_token_offset,
         )

+        ctx.has_num_activated_expert_per_token_offset = num_activated_expert_per_token_offset is not None
         ctx.mark_non_differentiable(y1)
         ctx.set_materialize_grads(False)

@@ -260,7 +261,10 @@ def backward(ctx, _: None, dz: torch.Tensor):
         grads.extend([dx_reduced, dw1])
         if db1 is not None:
             grads.append(db1)
-        grads.extend([None] * 5)
+        if ctx.has_num_activated_expert_per_token_offset:
+            grads.extend([None] * 5)
+        else:
+            grads.extend([None] * 4)
         return tuple(grads)


@@ -280,7 +284,7 @@ def forward(
         x_gather_idx: torch.Tensor,
         s_scatter_idx: torch.Tensor,
         s_reverse_scatter_idx: torch.Tensor,
-        num_activated_expert_per_token_offset: torch.Tensor,
+        num_activated_expert_per_token_offset: torch.Tensor | None,
         is_varlen_K: bool,
         activation_type: ActivationType,
     ) -> torch.Tensor:
@@ -335,6 +339,7 @@ def forward(
             s_scatter_idx,
             s_reverse_scatter_idx,
         )
+        ctx.has_num_activated_expert_per_token_offset = num_activated_expert_per_token_offset is not None

         return o

@@ -436,7 +441,12 @@ def backward(ctx, dout: torch.Tensor):
         grads.extend([None, dz, dw2])
         if db2 is not None:
             grads.append(db2)
-        grads.extend([ds, *[None] * 5])
+
+        if ctx.has_num_activated_expert_per_token_offset:
+            grads.extend([ds, *[None] * 5])
+        else:
+            grads.extend([ds, *[None] * 4])
+
         return tuple(grads)

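The variable-length `None` padding above follows from the `torch.autograd.Function` contract: `backward` must return one gradient entry per argument passed to `apply()` (with `None` in the slots of non-differentiable inputs), so a trailing argument that the caller sometimes omits changes the expected tuple length. Below is a minimal sketch of that pattern with hypothetical names (`MulAddOptional`, `offset`), not this PR's MoE kernel:

```python
import torch


class MulAddOptional(torch.autograd.Function):
    """Toy illustration: an optional trailing tensor whose presence is
    recorded on ctx so backward can size its gradient tuple to match."""

    @staticmethod
    def forward(ctx, x, offset=None):
        # Record presence, mirroring ctx.has_num_activated_expert_per_token_offset.
        ctx.has_offset = offset is not None
        return x * 2 + offset if offset is not None else x * 2

    @staticmethod
    def backward(ctx, dout):
        grads = [dout * 2]  # gradient w.r.t. x
        if ctx.has_offset:
            # Extra None slot for the non-differentiable offset argument;
            # autograd errors out if the gradient count does not match
            # the number of forward inputs.
            grads.append(None)
        return tuple(grads)


x = torch.randn(4, requires_grad=True)
MulAddOptional.apply(x).sum().backward()                  # backward returns 1 grad
MulAddOptional.apply(x, torch.zeros(4)).sum().backward()  # backward returns 2 grads
```

With `offset` supplied, autograd expects two gradients; without it, one. This is the same reason the diff extends `grads` by five or four trailing `None`s depending on the flag stashed in `forward`.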