Skip to content

Commit 0ac598d

Browse files
cjluo-nvrealAsma
andauthored
Update modelopt/torch/quantization/plugins/huggingface.py
Co-authored-by: realAsma <86726418+realAsma@users.noreply.github.com> Signed-off-by: Chenjie Luo <108829653+cjluo-nv@users.noreply.github.com>
1 parent 0f6f680 commit 0ac598d

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ def _gate_forward_hook(self, module, input, output):
486486
logits = output if not isinstance(output, tuple) else output[0]
487487
top_k = self.gate.top_k if hasattr(self.gate, "top_k") else self.top_k
488488
_, indices = torch.topk(logits.float(), top_k, dim=-1)
489-
counts = torch.bincount(indices.reshape(-1), minlength=len(self.expert_token_count))
489+
counts = torch.bincount(indices.reshape(-1), minlength=self.expert_token_count.shape[0])
490490
self.expert_token_count += counts.to(self.expert_token_count.device)
491491

492492
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)