From 320af0f3bb05252ecee0a30c6e386b3d8d566fd4 Mon Sep 17 00:00:00 2001 From: rubik Date: Tue, 30 Jun 2026 15:15:23 +0800 Subject: [PATCH] issue/467 [BUG]: enforce safetensors index-based weight loading for robustness Replace fragile glob with model.safetensors.index.json lookup to handle non-standard names (e.g., Mistral-Large-Instruct). Includes glob fallback --- csrc/models/mistral/mistral_for_causal_lm.cpp | 4 ++++ python/infinilm/infer_engine.py | 8 +++++--- python/infinilm/modeling_utils.py | 16 +++++++++++++++- 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/csrc/models/mistral/mistral_for_causal_lm.cpp b/csrc/models/mistral/mistral_for_causal_lm.cpp index a862add7f..181e39ceb 100644 --- a/csrc/models/mistral/mistral_for_causal_lm.cpp +++ b/csrc/models/mistral/mistral_for_causal_lm.cpp @@ -22,6 +22,10 @@ std::shared_ptr create_mistral_model_config(std:: config_json["attention_bias"] = false; } + if (!config_json.contains("torch_dtype")) { + config_json["torch_dtype"] = "bfloat16"; + } + return model_config; } diff --git a/python/infinilm/infer_engine.py b/python/infinilm/infer_engine.py index 10cf58be2..31c0ecf1b 100644 --- a/python/infinilm/infer_engine.py +++ b/python/infinilm/infer_engine.py @@ -98,9 +98,11 @@ def __init__( @property def dtype(self): - torch_dtype = self.hf_config.get("torch_dtype") - if torch_dtype is None: - torch_dtype = self.hf_config.get("dtype") + torch_dtype = ( + self.hf_config.get("torch_dtype") or + self.hf_config.get("dtype") or + "bfloat16" + ) return parse_dtype(torch_dtype) @property diff --git a/python/infinilm/modeling_utils.py b/python/infinilm/modeling_utils.py index 94f4016a9..e39d9a938 100644 --- a/python/infinilm/modeling_utils.py +++ b/python/infinilm/modeling_utils.py @@ -183,7 +183,21 @@ def load_model_state_dict_by_file( already_loaded_keys = [] embed_tokens_torch_unscaled = None - file_list = glob.glob(os.path.join(model_path, "*.safetensors")) + index_file_path = os.path.join(model_path, "model.safetensors.index.json") + if os.path.exists(index_file_path): + # Priority 1: If the index file exists, strictly load exactly what it maps to. + # This handles all standard sharded models perfectly, regardless of their actual prefix. + print(f"Found index file: {index_file_path}. Loading shards by index.") + with open(index_file_path, "r") as f: + index_data = json.load(f) + weight_map = index_data.get("weight_map", {}) + unique_filenames = set(weight_map.values()) + file_list = [os.path.join(model_path, fname) for fname in unique_filenames] + else: + # Priority 2: If no index file, scan all safetensors files. + print("No index file found. Scanning all safetensors files...") + file_list = glob.glob(os.path.join(model_path, "*.safetensors")) + if len(file_list) > 0: for file_path in tqdm(file_list, desc="Processing files"): tqdm.write(f"Processing: {os.path.basename(file_path)}")