Conversation
Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
📝 Walkthrough

Two example files are updated: one adds an explicit `dtype` argument to its `weight_dequant` call, and the other updates its quantization-utility imports.
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes

🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed (1 warning, 1 inconclusive)
⚠️ Caution: Some comments are outside the diff and can't be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/deepseek/quantize_to_nvfp4.py (1)
Line 210: ⚠️ Potential issue | 🟠 Major

Missing `dtype=torch.bfloat16`: the same stale-default bug that was fixed in `ptq.py` but not here.

`torch.set_default_dtype(torch.bfloat16)` at line 173 sets the runtime default, but `weight_dequant`'s parameter `dtype=torch.get_default_dtype()` is evaluated once at import time (when the module is defined), capturing `torch.float32`. The subsequent `torch.set_default_dtype` call has no effect on the already-captured default. Calling `weight_dequant(item, scale_inv)` without an explicit `dtype` therefore produces a float32 tensor, even though the result is stored in `bf16_state_dict` and fed into NVFP4/FP8 quantization, causing silent dtype mismatches and 2× memory overhead for the dequantized weights.

🐛 Proposed fix, consistent with the `ptq.py` fix:
```diff
- bf16_state_dict[key] = weight_dequant(item, scale_inv)
+ bf16_state_dict[key] = weight_dequant(item, scale_inv, dtype=torch.bfloat16)
```
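To see why the default goes stale, here is a minimal standalone sketch of the pitfall (a hypothetical repro with an invented name, `dequant_sketch`, not the actual `weight_dequant` implementation): Python evaluates default-argument expressions once, when the `def` statement runs, so the dtype is frozen at import time.

```python
import torch

def dequant_sketch(w, dtype=torch.get_default_dtype()):
    # The default above was evaluated when `def` ran (import time),
    # permanently binding torch.float32 to this function object.
    return w.float().to(dtype)

torch.set_default_dtype(torch.bfloat16)  # too late: the default is already bound
out = dequant_sketch(torch.randn(2, 2))
print(out.dtype)  # torch.float32, not torch.bfloat16
```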
🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@examples/deepseek/quantize_to_nvfp4.py` at line 210, the call to `weight_dequant(item, scale_inv)` relies on a default dtype captured at import time (`torch.get_default_dtype()`) and thus produces float32 despite `torch.set_default_dtype(torch.bfloat16)`; fix by passing an explicit `dtype=torch.bfloat16` when calling `weight_dequant` so `bf16_state_dict` entries are actually bfloat16 (or alternatively update `weight_dequant`'s signature to default to `torch.bfloat16`), e.g., change the call site that writes into `bf16_state_dict` to call `weight_dequant(item, scale_inv, dtype=torch.bfloat16)`.
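For the alternative the prompt mentions (fixing the signature rather than every call site), the usual Python idiom is a `None` sentinel resolved at call time. The block below is a simplified sketch under that assumption, not the real block-wise `weight_dequant` from the example:

```python
import torch

def weight_dequant(weight, scale_inv, dtype=None):
    # Resolve the default at call time, so a torch.set_default_dtype(...)
    # issued before the call is honored instead of an import-time capture.
    if dtype is None:
        dtype = torch.get_default_dtype()
    # Simplified dequant: upcast, apply the inverse scale, cast to target.
    return (weight.float() * scale_inv.float()).to(dtype)

torch.set_default_dtype(torch.bfloat16)
print(weight_dequant(torch.randn(4, 4), torch.ones(4, 4)).dtype)  # torch.bfloat16
```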
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #912      +/-   ##
==========================================
- Coverage   73.54%   73.10%   -0.44%
==========================================
  Files         205      205
  Lines       22000    22281     +281
==========================================
+ Hits        16179    16288     +109
- Misses       5821     5993     +172
```

☔ View full report in Codecov by Sentry.
## What does this PR do?

**Type of change:** Bug fix

**Overview:** Fix two bugs in the PTQ script

## Testing

Run DeepseekV3.2 PTQ and export

## Summary by CodeRabbit

* **Refactor**
  * Enhanced data type handling in quantization examples for bf16 operations
  * Updated internal dependencies for quantization utilities to improve modularity

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>