From 8f8ce91c0eadbb7dcc050b3cbb8808da7676fb25 Mon Sep 17 00:00:00 2001
From: Snehal Verma
Date: Mon, 11 May 2026 21:48:12 +0000
Subject: [PATCH] write dequantization scripts for DeepSeek V4 FP4/FP8 weights

---
 .../standalone_scripts/deepseek_dequantize.py | 192 ++++++++++++++++++
 .../deepseek_fp8_to_bf16.py                   | 174 ----------------
 2 files changed, 192 insertions(+), 174 deletions(-)
 create mode 100644 src/maxtext/checkpoint_conversion/standalone_scripts/deepseek_dequantize.py
 delete mode 100644 src/maxtext/checkpoint_conversion/standalone_scripts/deepseek_fp8_to_bf16.py

diff --git a/src/maxtext/checkpoint_conversion/standalone_scripts/deepseek_dequantize.py b/src/maxtext/checkpoint_conversion/standalone_scripts/deepseek_dequantize.py
new file mode 100644
index 0000000000..3798259476
--- /dev/null
+++ b/src/maxtext/checkpoint_conversion/standalone_scripts/deepseek_dequantize.py
@@ -0,0 +1,192 @@
+r"""Convert weights from FP8/FP4 to BF16 for a DeepSeek HF model.
+
+Example cmd:
+python3 deepseek_dequantize.py --input-path \
+    --output-path
+"""
+
+import os
+import json
+import torch
+from argparse import ArgumentParser
+from glob import glob
+from tqdm import tqdm
+from safetensors.torch import load_file, save_file
+
+# Lookup table for E2M1 FP4 (two e2m1 nibbles packed per int8/uint8 byte)
+_FP4_E2M1_LUT = torch.tensor(
+    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
+    dtype=torch.float32
+)
+
+
+def weight_dequant_cpu(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
+    """
+    Dequantizes the given FP8 weight tensor using the provided scale tensor on CPU.
+    """
+    assert x.dim() == 2 and s.dim() == 2, "Both x and s must be 2D tensors"
+
+    M, N = x.shape
+
+    x = x.to(torch.float32)
+    y = torch.empty_like(x, dtype=torch.bfloat16)
+
+    for i in range(0, M, block_size):
+        for j in range(0, N, block_size):
+            row_start = i
+            row_end = min(i + block_size, M)
+            col_start = j
+            col_end = min(j + block_size, N)
+            block = x[row_start:row_end, col_start:col_end]
+            scale = s[i // block_size, j // block_size]
+            y[row_start:row_end, col_start:col_end] = (block * scale).to(torch.bfloat16)
+
+    return y
+
+
+# reference: https://github.com/huggingface/transformers/blob/da6c53e431f7c9ef0691239d4ce89b0f711ecad7/src/transformers/integrations/finegrained_fp8.py#L933-L1046
+def unpack_fp4(packed: torch.Tensor) -> torch.Tensor:
+    """
+    Unpacks packed E2M1 FP4 values (two 4-bit nibbles per byte) into standard float32 values.
+    """
+    u8 = packed.contiguous().view(torch.uint8)
+    low = (u8 & 0xF).long()
+    high = ((u8 >> 4) & 0xF).long()
+
+    unpacked = torch.stack([_FP4_E2M1_LUT[low], _FP4_E2M1_LUT[high]], dim=-1)
+    return unpacked.reshape(*packed.shape[:-1], 2 * packed.shape[-1])
+
+
+# reference: https://github.com/huggingface/transformers/blob/da6c53e431f7c9ef0691239d4ce89b0f711ecad7/src/transformers/integrations/finegrained_fp8.py#L933-L1046
+def dequantize_mxfp4(quantized: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
+    """
+    Dequantizes FP4 (E2M1) or FP8 (E4M3) weights using their block-wise scale grids.
+    """
+    fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None)
+    is_fp4 = quantized.dtype == torch.int8 or (fp4_dtype is not None and quantized.dtype == fp4_dtype)
+
+    if is_fp4:
+        quantized_fp32 = unpack_fp4(quantized)
+    else:
+        quantized_fp32 = quantized.to(torch.float32)
+
+    rows, cols = quantized_fp32.shape[-2:]
+    scale_rows, scale_cols = scales.shape[-2:]
+
+    if rows % scale_rows != 0 or cols % scale_cols != 0:
+        raise ValueError(f"Weight shape ({rows}, {cols}) not divisible by scale grid ({scale_rows}, {scale_cols}).")
+
+    block_m = rows // scale_rows
+    block_n = cols // scale_cols
+
+    original_shape = quantized_fp32.shape
+    q = quantized_fp32.reshape(-1, scale_rows, block_m, scale_cols, block_n)
+    s = scales.to(torch.float32).reshape(-1, scale_rows, scale_cols).unsqueeze(-1).unsqueeze(2)
+
+    return (q * s).to(torch.bfloat16).reshape(original_shape)
+
+
+def convert_model(input_path: str, output_path: str, cache_file_num: int = 2):
+    """
+    Scans, converts, and saves a DeepSeek FP8/FP4 checkpoint directory to BF16.
+    """
+    torch.set_default_dtype(torch.bfloat16)
+    os.makedirs(output_path, exist_ok=True)
+    model_index_file = os.path.join(input_path, "model.safetensors.index.json")
+
+    if not os.path.exists(model_index_file):
+        raise FileNotFoundError(f"Could not locate {model_index_file}. Ensure the path is correct.")
+
+    with open(model_index_file, "r", encoding="utf8") as f:
+        model_index = json.load(f)
+    weight_map = model_index["weight_map"]
+
+    # Cache of loaded safetensor shards and the scale tensors folded into weights
+    loaded_files = {}
+    converted_scales = []
+
+    def get_tensor(tensor_name):
+        file_name = weight_map[tensor_name]
+        if file_name not in loaded_files:
+            file_path = os.path.join(input_path, file_name)
+            loaded_files[file_name] = load_file(file_path, device="cpu")
+        return loaded_files[file_name][tensor_name]
+
+    safetensor_files = sorted(glob(os.path.join(input_path, "*.safetensors")))
+    print(f"Found {len(safetensor_files)} weight shards to process...")
+
+    for safetensor_file in tqdm(safetensor_files, desc="Converting Shards"):
+        file_name = os.path.basename(safetensor_file)
+        current_state_dict = load_file(safetensor_file, device="cpu")
+        loaded_files[file_name] = current_state_dict
+
+        new_state_dict = {}
+
+        for name, tensor in current_state_dict.items():
+            # Skip scale tensors; they will be folded directly into the weights
+            if name.endswith("_scale_inv") or name.endswith(".scale"):
+                continue
+
+            if tensor.dtype in (torch.int8, torch.float8_e4m3fn) or tensor.element_size() == 1:
+                # Handle both DeepSeek-V3 (_scale_inv) and V4 (.scale) naming conventions
+                scale_name_v3 = f"{name}_scale_inv"
+                scale_name_v4 = f"{name[:-len('.weight')]}.scale" if name.endswith(".weight") else None
+
+                scale_inv = None
+                used_scale_name = None
+
+                try:
+                    scale_inv = get_tensor(scale_name_v3)
+                    used_scale_name = scale_name_v3
+                except KeyError:
+                    if scale_name_v4:
+                        try:
+                            scale_inv = get_tensor(scale_name_v4)
+                            used_scale_name = scale_name_v4
+                        except KeyError:
+                            pass
+
+                if scale_inv is not None:
+                    if tensor.dtype == torch.float8_e4m3fn:
+                        dequantized_tensor = weight_dequant_cpu(tensor, scale_inv, block_size=128)
+                    else:
+                        # Packed FP4 (int8 / float4_e2m1fn_x2) and any other one-byte dtype use the block-grid path
+                        dequantized_tensor = dequantize_mxfp4(tensor, scale_inv)
+
+                    new_state_dict[name] = dequantized_tensor
+                    converted_scales.append(used_scale_name)
+                else:
+                    print(f"\nWarning: scale missing for {name}. Keeping original tensor.")
+                    new_state_dict[name] = tensor
+            else:
+                # Keep other non-quantized tensors (like biases, embeds, layer norms) intact in BF16
+                new_state_dict[name] = tensor.to(torch.bfloat16)
+
+        save_file(new_state_dict, os.path.join(output_path, file_name))
+
+        # Memory management: keep only the `cache_file_num` most recently used files
+        while len(loaded_files) > cache_file_num:
+            oldest_file = next(iter(loaded_files))
+            del loaded_files[oldest_file]
+
+    # Clean up the JSON index map: drop scale entries that were folded into weights
+    print("Saving updated model index map...")
+    for scale_name in set(converted_scales):
+        if scale_name in weight_map:
+            weight_map.pop(scale_name)
+
+    new_model_index_file = os.path.join(output_path, "model.safetensors.index.json")
+    with open(new_model_index_file, "w", encoding="utf8") as f:
+        json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
+
+    print(f"Successfully saved dequantized BF16 model to: {output_path}")
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser(description="Dequantize DeepSeek hybrid checkpoints (FP8/FP4) to BF16.")
+    parser.add_argument("--input-path", "--input-fp8-hf-path", type=str, required=True,
+                        help="Path to the DeepSeek FP8/FP4 Hugging Face folder")
+    parser.add_argument("--output-path", "--output-bf16-hf-path", type=str, required=True,
+                        help="Directory to save the output BF16 weights")
+    parser.add_argument("--cache-size", "--cache-file-num", type=int, default=2,
+                        help="Maximum number of safetensor shards kept cached in RAM")
+    args = parser.parse_args()
+
+    convert_model(args.input_path, args.output_path, args.cache_size)
+
diff --git a/src/maxtext/checkpoint_conversion/standalone_scripts/deepseek_fp8_to_bf16.py b/src/maxtext/checkpoint_conversion/standalone_scripts/deepseek_fp8_to_bf16.py
deleted file mode 100644
index c973d7f1d9..0000000000
--- a/src/maxtext/checkpoint_conversion/standalone_scripts/deepseek_fp8_to_bf16.py
+++ /dev/null
@@ -1,174 +0,0 @@
-# Copyright 2023–2025 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-r"""Convert weights from FP8 to BF16 for a HF model.
-
-Install these dependencies before running this script:
-
-pip install torch==2.4.1 safetensors==0.4.5
-
-Example cmd:
-
-python3 -m maxtext.checkpoint_conversion.standalone_scripts.deepseek_fp8_to_bf16 --input-fp8-hf-path \
-    --output-bf16-hf-path
-"""
-
-
-import os
-import json
-from argparse import ArgumentParser
-from glob import glob
-
-from tqdm import tqdm
-
-import torch
-
-from safetensors.torch import load_file, save_file
-
-
-def weight_dequant_cpu(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
-    """
-    Dequantizes the given FP8 weight tensor using the provided scale tensor on CPU.
-
-    Args:
-        x (torch.Tensor): The quantized FP8 weight tensor of shape (M, N), dtype=torch.float8.
-        s (torch.Tensor): The scale tensor, dtype=torch.bfloat16 or float32.
-        block_size (int, optional): Size of the block used in quantization.
-
-    Returns:
-        torch.Tensor: The dequantized weight tensor, dtype=torch.bfloat16.
-
-    Raises:
-        AssertionError: If the input tensors are not 2D.
-    """
-    assert x.dim() == 2 and s.dim() == 2, "Both x and s must be 2D tensors"
-
-    M, N = x.shape
-
-    x = x.to(torch.float32)
-    y = torch.empty_like(x, dtype=torch.get_default_dtype())
-
-    for i in range(0, M, block_size):
-        for j in range(0, N, block_size):
-            row_start = i
-            row_end = min(i + block_size, M)
-            col_start = j
-            col_end = min(j + block_size, N)
-            block = x[row_start:row_end, col_start:col_end]
-            scale = s[i // block_size, j // block_size]
-            y[row_start:row_end, col_start:col_end] = (block * scale).to(torch.get_default_dtype())
-
-    return y
-
-
-def convert_fp8_to_bf16(fp8_path: str, bf16_path: str, cache_file_num: int = 2):
-    """
-    Converts a FP8 model to a BF16 model and saves the converted weights.
-
-    This function reads FP8 weights from the specified directory, converts them to BF16,
-    and saves the converted weights to another specified directory. It also updates the
-    model index file to reflect the changes. The conversion process runs on CPU devices.
-
-    Args:
-        fp8_path (str): The path to the directory containing the FP8 weights and model index file.
-        bf16_path (str): The path to the directory where the converted BF16 weights will be saved.
-
-    Raises:
-        KeyError: If a required scale_inv tensor is missing for a weight.
-
-    Notes:
-        - The function assumes that the FP8 weights are stored in safetensor files.
-        - The function caches loaded safetensor files to optimize memory usage.
-        - The function updates the model index file to remove references to scale_inv tensors.
-    """
-    torch.set_default_dtype(torch.bfloat16)
-    os.makedirs(bf16_path, exist_ok=True)
-    model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
-    with open(model_index_file, "rt", encoding="utf8") as f:
-        model_index = json.load(f)
-    weight_map = model_index["weight_map"]
-
-    # Cache for loaded safetensor files
-    loaded_files = {}
-    fp8_weight_names = []
-
-    # Helper function to get tensor from the correct file
-    def get_tensor(tensor_name):
-        """
-        Retrieves a tensor from the cached safetensor files or loads it from disk if not cached.
-
-        Args:
-            tensor_name (str): The name of the tensor to retrieve.
-
-        Returns:
-            torch.Tensor: The retrieved tensor.
-
-        Raises:
-            KeyError: If the tensor does not exist in the safetensor file.
-        """
-        file_name = weight_map[tensor_name]
-        if file_name not in loaded_files:
-            file_path = os.path.join(fp8_path, file_name)
-            loaded_files[file_name] = load_file(file_path, device="cpu")
-        return loaded_files[file_name][tensor_name]
-
-    safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
-    safetensor_files.sort()
-    for safetensor_file in tqdm(safetensor_files):
-        file_name = os.path.basename(safetensor_file)
-        current_state_dict = load_file(safetensor_file, device="cpu")
-        loaded_files[file_name] = current_state_dict
-
-        new_state_dict = {}
-        for weight_name, weight in current_state_dict.items():
-            if weight_name.endswith("_scale_inv"):
-                continue
-            elif weight.element_size() == 1:  # FP8 weight
-                scale_inv_name = f"{weight_name}_scale_inv"
-                try:
-                    # Get scale_inv from the correct file
-                    scale_inv = get_tensor(scale_inv_name)
-                    fp8_weight_names.append(weight_name)
-                    new_state_dict[weight_name] = weight_dequant_cpu(weight, scale_inv)
-                except KeyError:
-                    print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
-                    new_state_dict[weight_name] = weight
-            else:
-                new_state_dict[weight_name] = weight
-
-        new_safetensor_file = os.path.join(bf16_path, file_name)
-        save_file(new_state_dict, new_safetensor_file)
-
-        # Memory management: keep only the `cache_file_num` most recently used files
-        while len(loaded_files) > cache_file_num:
-            oldest_file = next(iter(loaded_files))
-            del loaded_files[oldest_file]
-
-    # Update model index
-    for weight_name in fp8_weight_names:
-        scale_inv_name = f"{weight_name}_scale_inv"
-        if scale_inv_name in weight_map:
-            weight_map.pop(scale_inv_name)
-    new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
-    with open(new_model_index_file, "wt", encoding="utf8") as f:
-        json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
-
-
-if __name__ == "__main__":
-    parser = ArgumentParser()
-    parser.add_argument("--input-fp8-hf-path", type=str, required=True)
-    parser.add_argument("--output-bf16-hf-path", type=str, required=True)
-    parser.add_argument("--cache-file-num", type=int, required=False, default=2)
-    args = parser.parse_args()
-    convert_fp8_to_bf16(args.input_fp8_hf_path, args.output_bf16_hf_path, args.cache_file_num)
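
Reviewer note: a quick, self-contained sanity check of the E2M1 (FP4) decode path that deepseek_dequantize.py relies on. This is a minimal sketch, not part of the patch; the packed byte and the 1x1 block scale below are made-up values, and the logic simply mirrors unpack_fp4/dequantize_mxfp4.

import torch

# E2M1 lookup table, same ordering as _FP4_E2M1_LUT in deepseek_dequantize.py
lut = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
    dtype=torch.float32,
)

# One byte packs two nibbles: low nibble 0x2 -> 1.0, high nibble 0xD -> -3.0
packed = torch.tensor([[0xD2]], dtype=torch.uint8)
low = (packed & 0xF).long()
high = ((packed >> 4) & 0xF).long()
unpacked = torch.stack([lut[low], lut[high]], dim=-1).reshape(1, 2)
print(unpacked)                               # tensor([[ 1., -3.]])

# Multiplying by a 1x1 block scale reproduces the dequantized BF16 values
scale = torch.tensor([[0.25]], dtype=torch.float32)
print((unpacked * scale).to(torch.bfloat16))  # tensor([[ 0.2500, -0.7500]], dtype=torch.bfloat16)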