diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 861992b705..0a72a6ffbe 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -8,8 +8,6 @@ import gc import inspect import logging -import shutil -import subprocess import warnings from abc import ABC, abstractmethod from pathlib import Path @@ -20,28 +18,24 @@ from QEfficient.base.onnx_transforms import ( BaseOnnxTransform, - FP16ClipTransform, + CustomOpTransform, OnnxTransformPipeline, + RenameFunctionOutputsTransform, SplitTensorsTransform, ) from QEfficient.base.pytorch_transforms import PytorchTransform from QEfficient.blocking.blocking_configurator import build_transformer_blocking_config_for_transform -from QEfficient.compile.qnn_compiler import compile as qnn_compile from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.transformers.models.pytorch_transforms import ( BlockingAttentionTransform, + ReplicateKVHeadTransform, ) from QEfficient.utils import ( constants, - create_json, create_model_params, dump_qconfig, - generate_mdp_partition_config, get_attr_or_key, - hash_dict_params, - load_json, require_value, - to_named_specializations, ) from QEfficient.utils.export_utils import export_wrapper @@ -58,6 +52,10 @@ class QEFFBaseModel(ABC): :_onnx_transforms: ONNX transformations to be applied after ONNX export. """ + _start = 0 + _end = 1 + _total_layers = None + _pytorch_transforms: List[PytorchTransform] _onnx_transforms = [BaseOnnxTransform] @@ -284,8 +282,16 @@ def _export( instance using from_pretrained() for re-export. 
""" + + idx = int(QEFFBaseModel._start) + # agent change start: generalized layerwise window + end_idx = int(getattr(QEFFBaseModel, "_end", idx + 1)) + if end_idx <= idx: + raise ValueError(f"Invalid export window: start={idx}, end={end_idx}") + # TODO: Hack for retain_full_kv, handle this outside export_kwargs.pop("retain_full_kv", None) + export_kwargs.pop("mla_absorption", None) onnx_path = export_dir / f"{self.model_name}.onnx" # Return early if ONNX already exists @@ -298,12 +304,44 @@ def _export( export_dir.mkdir(parents=True, exist_ok=True) + # Setup temporary paths + tmp_onnx_dir = export_dir / "onnx_layerwise_tmp" + tmp_onnx_dir.mkdir(parents=True, exist_ok=True) + + output_name = [] + output_name.append("logits") + # agent change start: emit retained states for all layers in current export window + for layer_idx in range(idx, end_idx): + output_name.append(f"compressed_kv.{layer_idx}_InternalRetainedState") + output_name.append(f"k_pe.{layer_idx}_InternalRetainedState") + + if idx >= 1: + z = example_inputs.pop("input_ids") + # z = example_inputs["input_ids"] + ################### model_dependent ############################ + inputs_embeds = torch.rand(z.shape[0], z.shape[1], 7168, device=z.device, dtype=torch.float16) + # example_inputs[f"layer_{QEFFBaseModel._start}/inputs_embeds"] = inputs_embeds + # dynamic_axes[f"layer_{QEFFBaseModel._start}/inputs_embeds"] = dynamic_axes.pop("input_ids") + example_inputs["inputs_embeds"] = inputs_embeds + dynamic_axes["inputs_embeds"] = dynamic_axes.pop("input_ids") + # Create input_names from example_inputs + # example_inputs[f"layer_{QEFFBaseModel._start}/position_ids"] = example_inputs.pop("position_ids") + # dynamic_axes[f"layer_{QEFFBaseModel._start}/position_ids"] = dynamic_axes.pop("position_ids") + + window_size = end_idx - idx + if "compressed_kvs" in example_inputs: + example_inputs["compressed_kvs"] = [ + val for i, val in enumerate(example_inputs["compressed_kvs"]) if i < window_size + ] + # Create 
input_names from example_inputs input_names = [] for param in inspect.signature(self.model.forward).parameters: if param in example_inputs: if param == "past_key_values": for i in range(len(example_inputs["past_key_values"])): + # example_inputs["past_key_values"] = [ + # val for i, val in enumerate(example_inputs["past_key_values"]) if i < window_size] if len(example_inputs["past_key_values"][0]) == 2: input_names.extend([f"past_key.{i}", f"past_value.{i}"]) elif len(example_inputs["past_key_values"][0]) == 4: @@ -319,56 +357,66 @@ def _export( raise ValueError( f"Unknown shape of past_key_values! Expected length of past_key_values for each layer to be either 2 or 4 but got {len(example_inputs['past_key_values'][0])}" ) + elif param == "compressed_kvs": + if len(example_inputs["compressed_kvs"][0]) == 2: + for layer_offset in range(len(example_inputs["compressed_kvs"])): + layer_idx = idx + layer_offset + input_names.extend([f"compressed_kv.{layer_idx}", f"k_pe.{layer_idx}"]) + else: + for i in range(len(example_inputs["compressed_kvs"])): + input_names.extend( + [ + f"compressed_kv.{i}", + ] + ) + input_names.extend( + [ + f"k_pe.{i}", + ] + ) else: input_names.append(param) + dynamic_axes = {k: v for k, v in dynamic_axes.items() if k in input_names} + + import os + import time + + layerwise_dir = export_dir / "onnx_layerwise_tmp" + start_time = time.time() + + # example_inputs["layer_indices_to_run"] = [i] + current_layer_dir = layerwise_dir / f"layer_{idx}_{end_idx}" + current_layer_dir.mkdir(parents=True, exist_ok=True) - try: + layer_onnx_path = str(current_layer_dir / f"{self.model_name}_layer_{idx}_{end_idx}.onnx") + layer_onnx_path_tmp = str(current_layer_dir / f"{self.model_name}_layer_tmp_{idx}_{end_idx}.onnx") + if not os.path.isfile(layer_onnx_path): torch.onnx.export( self.model, (example_inputs,), - str(onnx_path), + layer_onnx_path_tmp, input_names=input_names, - output_names=output_names, + output_names=output_name, dynamic_axes=dynamic_axes, 
opset_version=constants.ONNX_EXPORT_OPSET, **export_kwargs, ) - logger.info("PyTorch export successful") - _ = self._offload_model_weights(offload_pt_weights) - model = onnx.load(onnx_path, load_external_data=False) - - needs_external_tensor_data = any( - transform in self._onnx_transforms for transform in (FP16ClipTransform, SplitTensorsTransform) - ) - transform_kwargs = { - "onnx_base_dir": str(export_dir) if needs_external_tensor_data else None, - "model_name": self.model_name, - } - if onnx_transform_kwargs is not None: - transform_kwargs.update(onnx_transform_kwargs) - - onnx_transforms = OnnxTransformPipeline(transforms=self._onnx_transforms) - model, transformed = onnx_transforms.apply(model, **transform_kwargs) - - # Add metadata to the model - model.metadata_props.append( - onnx.StringStringEntryProto(key="qeff_transforms", value=",".join(self._transform_names())) - ) - logger.info("ONNX transforms applied") - - onnx_path_tmp = onnx_path.with_suffix(onnx_path.suffix + ".tmp") - onnx.save(model, onnx_path_tmp) - onnx_path_tmp.replace(onnx_path) - del model - gc.collect() - logger.info("Transformed ONNX saved") - - except Exception as e: - logger.error(f"ONNX export or transforms failed: {e}") - raise e - - self.onnx_path = onnx_path - return onnx_path + total_end = time.time() + print(f"\nTotal export time: {total_end - start_time:.2f} seconds") + + model = onnx.load(layer_onnx_path_tmp, load_external_data=False) + # print(model.functions) + transform_kwargs = { + "onnx_base_dir": str(current_layer_dir), + "model_name": self.model_name, + "layer_idx": idx, + } + _onnx_transforms = [SplitTensorsTransform, CustomOpTransform, RenameFunctionOutputsTransform] + onnx_transforms = OnnxTransformPipeline(transforms=_onnx_transforms) + model, transformed = onnx_transforms.apply(model, **transform_kwargs) + onnx.save(model, layer_onnx_path_tmp) + self.onnx_path = layer_onnx_path_tmp + return layer_onnx_path_tmp def get_onnx_path( self, @@ -378,6 +426,7 @@ def 
get_onnx_path( offload_pt_weights: Optional[bool] = True, use_onnx_subfunctions: Optional[bool] = False, retain_full_kv: Optional[bool] = False, + mla_absorption: Optional[Dict[str, bool]] = None, qaic_config: Optional[dict] = None, **compiler_options, ): @@ -385,6 +434,7 @@ def get_onnx_path( "offload_pt_weights": offload_pt_weights, "use_onnx_subfunctions": use_onnx_subfunctions, "retain_full_kv": retain_full_kv, + "mla_absorption": mla_absorption, } if prefill_only: @@ -433,11 +483,22 @@ def transform( qaic_config = qaic_config if qaic_config else getattr(self.model, "qaic_config", None) - if getattr(self.model, "config", None) or getattr(self.model.model, "config", None): + model_config = getattr(self.model, "config", None) or getattr(self.model.model, "config", None) + + if model_config: + if "DeepseekV3ForCausalLM" in (getattr(model_config, "architectures", None) or []): + if qaic_config: + if qaic_config.get("blocking_mode", None) == "h": + qaic_config["head_block_size"] = qaic_config.get("head_block_size", num_devices) + num_kv_heads_repeat = qaic_config.get("num_kv_heads_repeat", 1) + self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply( + self.model, num_kv_heads_repeat + ) + if replicate_kv_transformed: + self.hash_params["config"] = self.model.config.to_diff_dict() + blocking_config = build_transformer_blocking_config_for_transform( - getattr(self.model, "config", None) - if getattr(self.model, "config", None) - else getattr(self.model.model, "config", None), + model_config, ctx_len=ctx_len, seq_len=seq_len, bs=bs, @@ -471,6 +532,7 @@ def _compile( offload_pt_weights: Optional[bool] = True, enable_chunking: Optional[bool] = False, retain_full_kv: Optional[bool] = None, + mla_absorption: Optional[Dict[str, bool]] = None, qaic_config: Optional[dict] = None, specialization_module_name: Optional[str] = None, **compiler_options, @@ -498,6 +560,7 @@ def _compile( For QNN Compilation path, when enable_qnn is set to True, any parameter passed 
in compiler_options will be ignored. """ + onnx_path = Path( onnx_path if onnx_path @@ -510,150 +573,152 @@ def _compile( offload_pt_weights, use_onnx_subfunctions, retain_full_kv, + mla_absorption, num_devices=mdp_ts_num_devices, qaic_config=qaic_config, **compiler_options, ) ) - compile_dir = Path(compile_dir or onnx_path.parent) - qpc_path = compile_dir / "qpc" - if not onnx_path.is_file(): - raise FileNotFoundError(f"ONNX file not found at: {onnx_path}") - - if enable_qnn: - if compiler_options: - logger.warning( - f"Extra arguments to QNN compilation are supported only via qnn_config file. Ignoring {compiler_options}" - ) - - self.qpc_path = qnn_compile( - onnx_path=onnx_path, - qpc_base_path=compile_dir, - specializations=specializations, - custom_io=custom_io, - device_group=list(range(mdp_ts_num_devices)), - num_cores=compiler_options.get("aic_num_cores", constants.DEFAULT_AIC_NUM_CORES), - mxfp6=compiler_options.get("mxfp6_matmul", constants.DEFAULT_AIC_MXPF6_MATMUL), - mxint8=mxint8_kv_cache, - qnn_config=qnn_config, - ) - - return self.qpc_path - - command = ( - constants.COMPILER - + [ - f"-aic-hw-version={compiler_options.pop('aic_hw_version', compiler_options.pop('aic-hw-version', constants.DEFAULT_AIC_HW_VERSION))}" - ] - + [f"-m={onnx_path}"] - ) - - # MDP partition config: prioritize dump over load - mdp_dump_json_path = compiler_options.pop("mdp_dump_partition_config", None) - mdp_ts_json_path = compiler_options.pop("mdp_load_partition_config", None) - mdp_ts_json = None - - if mdp_dump_json_path: - if mdp_ts_json_path: - logger.warning( - "Loading and Dumping partition is not supported at the same time. Prioritizing dump config over load config!" 
- ) - command.append(f"-mdp-dump-partition-config={mdp_dump_json_path}") - elif mdp_ts_json_path: - command.append(f"-mdp-load-partition-config={mdp_ts_json_path}") - mdp_ts_json = load_json(str(mdp_ts_json_path)) - elif mdp_ts_num_devices > 1: - # Generate mdp config only if neither dump nor load is provided and num_devices > 1 - mdp_ts_json = generate_mdp_partition_config( - mdp_ts_num_devices, compiler_options.get("aic_num_cores", constants.DEFAULT_AIC_NUM_CORES) - ) - mdp_ts_json_path = compile_dir / f"mdp_ts_{mdp_ts_num_devices}.json" - create_json(str(mdp_ts_json_path), mdp_ts_json) - command.append(f"-mdp-load-partition-config={mdp_ts_json_path}") - - for key, value in compiler_options.items(): - option = "-" + key.replace("_", "-") - if isinstance(value, bool): - if value: - command.append(option) - continue - command.append(f"{option}={value}") - - if use_onnx_subfunctions: - logger.info("Using ONNX subfunctions for compilation.") - command.append("-sub-functions") - - compile_hash_params = { - "command": command, - "specializations": specializations, - "custom_io": custom_io, - "mdp_ts_num_devices": mdp_ts_num_devices, - "mdp_ts_json": mdp_ts_json, - "num_speculative_tokens": num_speculative_tokens, - "prefill_only": prefill_only, - } - compile_hash = hash_dict_params(compile_hash_params) - - compile_dir = qpc_path.with_name(qpc_path.name + "-" + compile_hash) - qpc_path = compile_dir / "qpc" - qpc_path.mkdir(parents=True, exist_ok=True) - - if qpc_path.is_dir(): - if (qpc_path / "programqpc.bin").is_file(): - self.qpc_path = qpc_path - return qpc_path - # Probably compilation failure last time, delete directory to start over - shutil.rmtree(qpc_path) - - # Write the generated MDP partition config file (not if user provided it) - - # Write specializations.json file - if specializations is not None: - specializations_json = compile_dir / "specializations.json" - specializations_data = { - "specializations": to_named_specializations(specializations, 
module_name=specialization_module_name) - } - create_json(str(specializations_json), specializations_data) - command.append(f"-network-specialization-config={specializations_json}") - - # Write custom_io.yaml file - model_in_bfloat16 = hasattr(self, "config") and (self.config.torch_dtype == torch.bfloat16) - pkv_in_bfloat16 = (custom_io is not None) and any( - "past_" in key and "bfloat16" in value for key, value in custom_io.items() - ) - if custom_io is not None: - custom_io_yaml = compile_dir / "custom_io.yaml" - with open(custom_io_yaml, "w") as fp: - for io_name, dtype in custom_io.items(): - fp.write(f" - IOName: {io_name}\n Precision: {dtype}\n\n") - if model_in_bfloat16 and pkv_in_bfloat16: - logger.warning( - "Model and Past KV types are both bfloat16. Custom IO list file will be ignored during compile." - ) - else: - command.append(f"-custom-IO-list-file={custom_io_yaml}") - - command.append(f"-aic-binary-dir={qpc_path}") - logger.info(f"Running compiler: {' '.join(command)}") - - try: - subprocess.run(command, capture_output=True, check=True) - except subprocess.CalledProcessError as e: - raise RuntimeError( - "\n".join( - [ - "Compilation failed!", - f"Compiler command: {e.cmd}", - f"Compiler exitcode: {e.returncode}", - "Compiler stderr:", - e.stderr.decode(), - ] - ) - ) - # Dump JSON file with hashed parameters - hashed_compile_params_path = compile_dir / "hashed_compile_params.json" - create_json(hashed_compile_params_path, compile_hash_params) - logger.info("Hashed parameters exported successfully.") - - self.qpc_path = qpc_path - return qpc_path + return onnx_path + # compile_dir = Path(compile_dir or onnx_path.parent) + # qpc_path = compile_dir / "qpc" + # if not onnx_path.is_file(): + # raise FileNotFoundError(f"ONNX file not found at: {onnx_path}") + + # if enable_qnn: + # if compiler_options: + # logger.warning( + # f"Extra arguments to QNN compilation are supported only via qnn_config file. 
Ignoring {compiler_options}" + # ) + + # self.qpc_path = qnn_compile( + # onnx_path=onnx_path, + # qpc_base_path=compile_dir, + # specializations=specializations, + # custom_io=custom_io, + # device_group=list(range(mdp_ts_num_devices)), + # num_cores=compiler_options.get("aic_num_cores", constants.DEFAULT_AIC_NUM_CORES), + # mxfp6=compiler_options.get("mxfp6_matmul", constants.DEFAULT_AIC_MXPF6_MATMUL), + # mxint8=mxint8_kv_cache, + # qnn_config=qnn_config, + # ) + + # return self.qpc_path + + # command = ( + # constants.COMPILER + # + [ + # f"-aic-hw-version={compiler_options.pop('aic_hw_version', compiler_options.pop('aic-hw-version', constants.DEFAULT_AIC_HW_VERSION))}" + # ] + # + [f"-m={onnx_path}"] + # ) + + # # MDP partition config: prioritize dump over load + # mdp_dump_json_path = compiler_options.pop("mdp_dump_partition_config", None) + # mdp_ts_json_path = compiler_options.pop("mdp_load_partition_config", None) + # mdp_ts_json = None + + # if mdp_dump_json_path: + # if mdp_ts_json_path: + # logger.warning( + # "Loading and Dumping partition is not supported at the same time. Prioritizing dump config over load config!" 
+ # ) + # command.append(f"-mdp-dump-partition-config={mdp_dump_json_path}") + # elif mdp_ts_json_path: + # command.append(f"-mdp-load-partition-config={mdp_ts_json_path}") + # mdp_ts_json = load_json(str(mdp_ts_json_path)) + # elif mdp_ts_num_devices > 1: + # # Generate mdp config only if neither dump nor load is provided and num_devices > 1 + # mdp_ts_json = generate_mdp_partition_config( + # mdp_ts_num_devices, compiler_options.get("aic_num_cores", constants.DEFAULT_AIC_NUM_CORES) + # ) + # mdp_ts_json_path = compile_dir / f"mdp_ts_{mdp_ts_num_devices}.json" + # create_json(str(mdp_ts_json_path), mdp_ts_json) + # command.append(f"-mdp-load-partition-config={mdp_ts_json_path}") + + # for key, value in compiler_options.items(): + # option = "-" + key.replace("_", "-") + # if isinstance(value, bool): + # if value: + # command.append(option) + # continue + # command.append(f"{option}={value}") + + # if use_onnx_subfunctions: + # logger.info("Using ONNX subfunctions for compilation.") + # command.append("-sub-functions") + + # compile_hash_params = { + # "command": command, + # "specializations": specializations, + # "custom_io": custom_io, + # "mdp_ts_num_devices": mdp_ts_num_devices, + # "mdp_ts_json": mdp_ts_json, + # "num_speculative_tokens": num_speculative_tokens, + # "prefill_only": prefill_only, + # } + # compile_hash = hash_dict_params(compile_hash_params) + + # compile_dir = qpc_path.with_name(qpc_path.name + "-" + compile_hash) + # qpc_path = compile_dir / "qpc" + # qpc_path.mkdir(parents=True, exist_ok=True) + + # if qpc_path.is_dir(): + # if (qpc_path / "programqpc.bin").is_file(): + # self.qpc_path = qpc_path + # return qpc_path + # # Probably compilation failure last time, delete directory to start over + # shutil.rmtree(qpc_path) + + # # Write the generated MDP partition config file (not if user provided it) + + # # Write specializations.json file + # if specializations is not None: + # specializations_json = compile_dir / "specializations.json" + # 
specializations_data = { + # "specializations": to_named_specializations(specializations, module_name=specialization_module_name) + # } + # create_json(str(specializations_json), specializations_data) + # command.append(f"-network-specialization-config={specializations_json}") + + # # Write custom_io.yaml file + # model_in_bfloat16 = hasattr(self, "config") and (self.config.torch_dtype == torch.bfloat16) + # pkv_in_bfloat16 = (custom_io is not None) and any( + # "past_" in key and "bfloat16" in value for key, value in custom_io.items() + # ) + # if custom_io is not None: + # custom_io_yaml = compile_dir / "custom_io.yaml" + # with open(custom_io_yaml, "w") as fp: + # for io_name, dtype in custom_io.items(): + # fp.write(f" - IOName: {io_name}\n Precision: {dtype}\n\n") + # if model_in_bfloat16 and pkv_in_bfloat16: + # logger.warning( + # "Model and Past KV types are both bfloat16. Custom IO list file will be ignored during compile." + # ) + # else: + # command.append(f"-custom-IO-list-file={custom_io_yaml}") + + # command.append(f"-aic-binary-dir={qpc_path}") + # logger.info(f"Running compiler: {' '.join(command)}") + + # try: + # subprocess.run(command, capture_output=True, check=True) + # except subprocess.CalledProcessError as e: + # raise RuntimeError( + # "\n".join( + # [ + # "Compilation failed!", + # f"Compiler command: {e.cmd}", + # f"Compiler exitcode: {e.returncode}", + # "Compiler stderr:", + # e.stderr.decode(), + # ] + # ) + # ) + # # Dump JSON file with hashed parameters + # hashed_compile_params_path = compile_dir / "hashed_compile_params.json" + # create_json(hashed_compile_params_path, compile_hash_params) + # logger.info("Hashed parameters exported successfully.") + + # self.qpc_path = qpc_path + # return qpc_path diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py index 2ba53829a4..91c6ef3e27 100644 --- a/QEfficient/base/onnx_transforms.py +++ b/QEfficient/base/onnx_transforms.py @@ -129,17 +129,80 @@ def 
class RemovePrefix(BaseOnnxTransform):
    """Strip ``scope/`` style prefixes from graph input and output names.

    Only the text after the last ``/`` of each graph input/output name is kept.
    Node edges that refer to a renamed graph output, or whose own tail matches
    one of the (renamed) graph inputs, are rewritten so the graph stays wired.
    """

    @classmethod
    def apply(cls, model: ModelProto) -> bool:
        """Rename prefixed graph inputs/outputs in place.

        :param model: ONNX model whose ``graph`` is modified in place.
        :returns: ``True`` if at least one name was changed, else ``False``.
        """
        graph = model.graph
        changed = False

        def tail_of(name: str) -> str:
            # Everything after the last "/", or the whole name if there is none.
            return name.rpartition("/")[2]

        # Graph inputs: rename and remember the final names.
        final_input_names = set()
        for graph_input in graph.input:
            stripped = tail_of(graph_input.name)
            if stripped != graph_input.name:
                graph_input.name = stripped
                changed = True
            final_input_names.add(stripped)

        # Graph outputs: rename and keep old -> new so node edges can follow.
        renamed_outputs = {}
        for graph_output in graph.output:
            previous = graph_output.name
            stripped = tail_of(previous)
            if stripped != previous:
                graph_output.name = stripped
                renamed_outputs[previous] = stripped
                changed = True

        def remap_input(edge: str) -> str:
            # Priority: renamed model output, then an exact graph-input match,
            # then a prefixed reference whose tail is a graph input.
            if edge in renamed_outputs:
                return renamed_outputs[edge]
            if edge in final_input_names:
                return edge
            if "/" in edge:
                candidate = edge.rsplit("/", 1)[1]
                if candidate in final_input_names:
                    return candidate
            return edge

        for node in graph.node:
            # Producer side of renamed model outputs.
            for position, produced in enumerate(node.output):
                if produced in renamed_outputs:
                    node.output[position] = renamed_outputs[produced]
                    changed = True

            # Consumer side: compute all replacements first, then write back
            # only the entries that actually differ.
            replacements = [remap_input(edge) for edge in node.input]
            for position, replacement in enumerate(replacements):
                if node.input[position] != replacement:
                    node.input[position] = replacement
                    changed = True

        return changed
# Blocking strategies that have an MLA implementation. Only these modes can be
# dispatched by generic_blocked_mla_attention_interface.
_STRATEGIES_MLA: Dict[BlockingMode, Callable] = {
    BlockingMode.KV: blocked_kv_mla_attention_forward,
    BlockingMode.H: blocked_h_mla_attention_forward,
}


def generic_blocked_mla_attention_interface(
    module,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    mla_absorption: Dict[str, Any],
    blocking_config: AttentionBlockingConfig,
    query: Optional[torch.Tensor] = None,
    q_a_proj_out: Optional[torch.Tensor] = None,
    fusedqk: Optional[torch.Tensor] = None,
    q_nope: Optional[torch.Tensor] = None,
    q_pe: Optional[torch.Tensor] = None,
    kva: Optional[torch.Tensor] = None,
    k_pe: Optional[torch.Tensor] = None,
    per_head_q_up: Optional[torch.Tensor] = None,
    per_head_k_up: Optional[torch.Tensor] = None,
    per_head_v_up: Optional[torch.Tensor] = None,
    per_head_k_up_normal: Optional[torch.Tensor] = None,
    layer_idx: Optional[int] = None,
    compressed_kvs: Optional[torch.Tensor] = None,
    comp_ctx_lengths: Optional[torch.LongTensor] = None,
    batch_index: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_seen_tokens: Optional[int] = None,
    non_blocked_forward: Callable = None,
    score_mod: Optional[Callable] = None,
    position_bias: Optional[torch.Tensor] = None,
    sinks: Optional[torch.Tensor] = None,
    sliding_window: Optional[int] = None,
    **kwargs,
):
    """Run MLA attention through the strategy registered for ``blocking_config.mode``.

    The strategy is looked up in ``_STRATEGIES_MLA`` and called with the tensor
    arguments, a ``cache_kwargs`` dict built from ``position_ids`` and
    ``batch_index``, and the block counts/sizes taken from ``blocking_config``.

    ``comp_ctx_lengths``, ``past_seen_tokens``, ``non_blocked_forward``,
    ``sliding_window`` and ``**kwargs`` are accepted but not forwarded.
    NOTE(review): confirm that dropping ``sliding_window`` here is intended.

    :returns: ``(attn_output, attn_weights)`` exactly as returned by the strategy.
    :raises ValueError: if ``blocking_config.mode`` has no MLA strategy.
    """
    mla_blocking_strategy = _STRATEGIES_MLA.get(blocking_config.mode)
    if mla_blocking_strategy is None:
        # Previously this fell through to calling ``None`` and surfaced as an
        # opaque "'NoneType' object is not callable".
        supported = ", ".join(str(mode) for mode in _STRATEGIES_MLA)
        raise ValueError(
            f"Unsupported MLA blocking mode {blocking_config.mode!r}; supported modes: {supported}"
        )

    cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index}
    attn_output, attn_weights = mla_blocking_strategy(
        module=module,
        query=query,
        q_a_proj_out=q_a_proj_out,
        fusedqk=fusedqk,
        q_nope=q_nope,
        q_pe=q_pe,
        kva=kva,
        k_pe=k_pe,
        per_head_q_up=per_head_q_up,
        per_head_k_up=per_head_k_up,
        per_head_v_up=per_head_v_up,
        per_head_k_up_normal=per_head_k_up_normal,
        attention_mask=attention_mask,
        scaling=scaling,
        cache_kwargs=cache_kwargs,
        layer_idx=layer_idx,
        compressed_kvs=compressed_kvs,
        mla_absorption=mla_absorption,
        num_kv_blocks=blocking_config.num_kv_blocks,
        num_q_blocks=blocking_config.num_q_blocks,
        head_block_size=blocking_config.head_block_size,
        num_batch_blocks=blocking_config.num_batch_blocks,
        score_mod=score_mod,
        position_bias=position_bias,
        sinks=sinks,
    )

    return attn_output, attn_weights
def _replicate_first_head(block: torch.Tensor, target_heads: int) -> torch.Tensor:
    """Pad dim 1 of ``block`` up to ``target_heads`` with copies of head 0."""
    missing = target_heads - block.shape[1]
    if missing <= 0:
        return block
    # Keep the head axis (``:1``) so the expand is valid for any batch size.
    filler = block[:, :1, :, :].expand(-1, missing, -1, -1)
    return torch.cat((block, filler), dim=1)


def blocked_kv_mla_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    per_head_k_up_normal: torch.Tensor,
    per_head_v_up: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    num_kv_blocks: int,
    cache_kwargs: Dict[str, Any],
    layer_idx: int,
    compressed_kvs: Optional[torch.Tensor],
    mla_absorption: Dict[str, Any],
    *,
    use_causal_mask: bool = False,
    sliding_window: Optional[int] = None,
    skip_kv: bool = False,
    position_bias: Optional[torch.Tensor] = None,
    sinks: Optional[torch.Tensor] = None,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """KV-blocked MLA attention using a running (online) softmax.

    The cached context length is split into ``num_kv_blocks`` chunks. For each
    chunk the compressed KV and ``k_pe`` slices are read from ``compressed_kvs``,
    scores against ``query`` are computed and masked with a causal mask built
    from ``cache_kwargs["position_ids"]``, and ``update_running_softmax``
    accumulates the result. The accumulator is finally multiplied by
    ``per_head_v_up``.

    ``attention_mask``, ``use_causal_mask``, ``sliding_window``,
    ``position_bias``, ``sinks`` and ``**kwargs`` are accepted but unused.

    :param query: indexed as ``(batch, heads, q_len, dim)``.
    :param mla_absorption: optional dict; its ``"absorption"`` flag selects
        whether keys are ``[ckv | k_pe]`` directly or ``[ckv @ k_up | k_pe]``.
    :returns: ``(attn_output, None)``; ``attn_output`` has dims 1 and 2 swapped
        relative to the accumulator.
    """
    batch_size, num_heads, seq_len, _ = query.shape

    # Resolve config-dependent values once. The config is preferred when
    # present; otherwise fall back to attributes on the module itself.
    config = getattr(module, "config", None)
    if config is not None:
        kv_lora_rank = config.kv_lora_rank
        mask_dtype = config.torch_dtype
    else:
        kv_lora_rank = module.kv_lora_rank
        mask_dtype = query.dtype

    output = torch.zeros(batch_size, num_heads, seq_len, kv_lora_rank, device=query.device, dtype=query.dtype)
    masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=mask_dtype, device=query.device)

    # Running maximum and denominator of the online softmax.
    current_max = torch.full(
        (batch_size, num_heads, seq_len),
        float(MIN_MASKED_ATTENTION_VALUE),
        device=query.device,
        dtype=query.dtype,
    )
    current_denominator = torch.zeros(batch_size, num_heads, seq_len, device=query.device, dtype=query.dtype)

    # NOTE(review): the ``skip_kv`` argument is unconditionally overridden, as in
    # the original code. Confirm whether callers should be able to disable it.
    skip_kv = True

    ctx_len = compressed_kvs.layers[layer_idx].ckv.shape[2]
    kv_block_size = -(-ctx_len // num_kv_blocks)  # ceil division

    position_ids = cache_kwargs.get("position_ids")
    current_position = position_ids.max(dim=-1).values

    # Loop-invariant: decide the key layout once.
    absorption = mla_absorption.get("absorption", False) if mla_absorption is not None else False

    for block_idx in range(num_kv_blocks):
        start_index = block_idx * kv_block_size
        # The last block absorbs whatever remains of the context.
        end_index = ctx_len if block_idx == num_kv_blocks - 1 else start_index + kv_block_size

        skip_future = None
        if skip_kv:
            skip_future = (torch.tensor(start_index, device=query.device) > current_position).all()
            # Eager mode only: stop early once every remaining block is in the future.
            if not torch.onnx.is_in_onnx_export() and not torch.jit.is_tracing():
                if skip_future.item():
                    break

        compressed_kv_block = compressed_kvs.read_only_blocked_ckv(start_index, end_index, layer_idx, cache_kwargs)
        k_pe_block = compressed_kvs.read_only_blocked_k_pe(start_index, end_index, layer_idx, cache_kwargs)

        causal_mask_block = _create_causal_mask(
            position_ids=position_ids,
            target_length=end_index,
            start_index=start_index,
        )

        # Bring both cache slices up to the query head count once.
        compressed_kv_block = _replicate_first_head(compressed_kv_block, num_heads)
        k_pe_block = _replicate_first_head(k_pe_block, num_heads)

        if absorption:
            keys = torch.cat((compressed_kv_block, k_pe_block), dim=-1)
        else:
            knope = torch.matmul(compressed_kv_block, per_head_k_up_normal)
            keys = torch.cat((knope, k_pe_block), dim=-1)

        attn_weights_block = torch.matmul(query, keys.transpose(2, 3)) * scaling
        attn_weights_block = torch.where(causal_mask_block, masked_tensor, attn_weights_block)
        current_max, current_denominator, output = update_running_softmax(
            current_max,
            attn_weights_block,
            current_denominator,
            output,
            compressed_kv_block,
            skip_kv,
            skip_future,
        )

    attn_output = torch.matmul(output, per_head_v_up)
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_weights = None

    return attn_output, attn_weights
blocks and processes each block. + """ + batch_size, num_heads, q_len, _ = q_pe.shape + if head_block_size <= 0: + head_block_size = num_heads + num_head_blocks = math.ceil(num_heads / head_block_size) + + if hasattr(module, "config"): + mask_dtype = module.config.torch_dtype + else: + mask_dtype = q_pe.dtype + masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=mask_dtype, device=q_pe.device) + + if mla_absorption is not None: + absorption = mla_absorption.get("absorption", False) + online = mla_absorption.get("online", False) + else: + absorption = False + + h_output_blocks = [] + h_attn_blocks = [] + # Process each head block independently + for head_block_idx in range(num_head_blocks): + h_start = head_block_idx * head_block_size + h_end = min(h_start + head_block_size, num_heads) + + if absorption: + if online: + qup_kupT = torch.matmul(per_head_q_up[:, h_start:h_end, :, :], per_head_k_up[:, h_start:h_end, :, :]) + dq_qup_kupT = torch.matmul(q_a_proj_out, qup_kupT) + else: + dq_qup_kupT = torch.matmul(q_a_proj_out, fusedqk[:, h_start:h_end, :, :]) + qkupTrope_nope = torch.cat((dq_qup_kupT, q_pe[:, h_start:h_end, :, :]), dim=-1) + krope_nope = torch.cat((kva, k_pe), dim=-1) + attn_weights = torch.matmul(qkupTrope_nope, krope_nope.transpose(2, 3)) * scaling + else: + knope = torch.matmul(kva, per_head_k_up_normal[:, h_start:h_end, :, :]) + krope_nope = torch.cat((knope, k_pe), dim=-1) + qrope_nope = torch.cat((q_nope[:, h_start:h_end, :, :], q_pe[:, h_start:h_end, :, :]), dim=-1) + attn_weights = torch.matmul(qrope_nope, krope_nope.transpose(2, 3)) * scaling + + if attention_mask is not None: + attn_weights = torch.where(attention_mask, masked_tensor, attn_weights) + attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype) + attn_output = torch.matmul(attn_weights, kva) + attn_output = torch.matmul(attn_output, per_head_v_up[:, h_start:h_end, :, :]) + h_output_blocks.append(attn_output) + 
h_attn_blocks.append(attn_weights) + + attn_output = torch.cat(h_output_blocks, dim=1).transpose(1, 2).contiguous() + attn_weights = torch.cat(h_attn_blocks, dim=1) + return attn_output, attn_weights diff --git a/QEfficient/blocking/blocking_configurator.py b/QEfficient/blocking/blocking_configurator.py index e741a956e1..deed73a7bf 100644 --- a/QEfficient/blocking/blocking_configurator.py +++ b/QEfficient/blocking/blocking_configurator.py @@ -18,7 +18,7 @@ from QEfficient.blocking.attention_blocking import AttentionBlockingConfig, BlockingMode from QEfficient.utils import get_attr_or_key, require_value -from QEfficient.utils.constants import VTCM_SIZE_THRESHOLD +from QEfficient.utils.constants import DEFAULT_NUM_HEADS, FP16_BYTES, KV_LORA_RANK, ROPE_DIM, VTCM_SIZE_THRESHOLD def _infer_head_dim(model_config: Any, num_heads: int) -> int: @@ -90,6 +90,92 @@ def block_candidates_generator(max_length: int) -> List[int]: return block_list +def matmul1_bytes(q_len: int, kv_block_size: int, num_heads: int = DEFAULT_NUM_HEADS) -> int: + """Bytes for [1,num_heads,q,kv] x [1,1,kv,512] -> [1,num_heads,q,512] in fp16.""" + elems_a = num_heads * q_len * kv_block_size + elems_b = kv_block_size * KV_LORA_RANK + elems_out = num_heads * q_len * KV_LORA_RANK + return FP16_BYTES * (elems_a + elems_b + elems_out) + + +def matmul2_bytes(q_len: int, kv_block_size: int, num_heads: int = DEFAULT_NUM_HEADS) -> int: + """Bytes for [1,num_heads,q,576] x [1,1,576,kv] -> [1,num_heads,q,kv] in fp16.""" + elems_a = num_heads * q_len * (KV_LORA_RANK + ROPE_DIM) + elems_b = 576 * kv_block_size + elems_out = num_heads * q_len * kv_block_size + return FP16_BYTES * (elems_a + elems_b + elems_out) + + +def max_kv_block_size( + q_len: int, + budget_bytes: int = VTCM_SIZE_THRESHOLD, + num_heads: int = DEFAULT_NUM_HEADS, +) -> int: + """Return the largest integer kv_block_size that satisfies both matmul budgets. + + Returns 0 if no positive kv_block_size can satisfy the constraints. 
def get_num_kv_blocks_for_mla(q_len, num_heads, ctx_len):
    """Return the number of KV blocks to use for MLA under the fp16 memory budget.

    The largest ``kv`` length that fits the budget is obtained from
    ``max_kv_block_size``; both matmuls must stay strictly below
    ``VTCM_SIZE_THRESHOLD`` bytes (input_a + input_b + output each):
        1) [1, num_heads, q_len, 576] x [1, 1, 576, kv] -> [1, num_heads, q_len, kv]
        2) [1, num_heads, q_len, kv] x [1, 1, kv, 512] -> [1, num_heads, q_len, 512]

    The block size is then the largest entry of ``block_candidates_generator(ctx_len)``
    that does not exceed ``kv`` (or ``ctx_len`` itself when the whole context fits), and
    the result is ``ctx_len // kv_block_size``.

    Args:
        q_len: Query length used for the budget computation.
        num_heads: Number of attention heads used for the budget computation.
        ctx_len: Context length to be split into KV blocks.

    Returns:
        Number of KV blocks (``ctx_len // kv_block_size``).

    Raises:
        AssertionError: If either matmul is not under the budget for the computed ``kv``.
    """
    budget_bytes = VTCM_SIZE_THRESHOLD
    kv = max_kv_block_size(q_len, budget_bytes, num_heads)
    b1 = matmul1_bytes(q_len, kv, num_heads)
    b2 = matmul2_bytes(q_len, kv, num_heads)

    # Internal sanity checks on the value returned by max_kv_block_size.
    assert b1 < budget_bytes, "matmul1 is not under the budget"
    assert b2 < budget_bytes, "matmul2 is not under the budget"

    kv_block_size_list = block_candidates_generator(ctx_len)
    kv_block_size = ctx_len
    if kv < ctx_len:
        # Take the largest candidate that fits. `<=` also covers kv landing exactly on a
        # candidate, and kv lying above the last candidate but below ctx_len; both cases
        # previously fell through to a single over-budget block of ctx_len.
        fitting = [size for size in kv_block_size_list if 0 < size <= kv]
        if fitting:
            kv_block_size = max(fitting)
        # NOTE(review): when no candidate fits, the original single-block fallback
        # (kv_block_size == ctx_len) is kept — confirm this is the intended behaviour.
    return ctx_len // kv_block_size
chunk_inputs["input_ids"] + outputs = self._session.run(chunk_inputs) if self._write_io_dir is not None: diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index e8d9e004cf..799717bf83 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -343,6 +343,156 @@ def update3D( return k_out, v_out +class QEffDynamicCompressedKVRopeLayer: + def __init__(self, ckv, k_pe): + self.ckv = ckv + self.k_pe = k_pe + + def update_ckv(self, compressed_kv, cache_kwargs): + position_ids = cache_kwargs.get("position_ids") + + self.ckv = CtxScatterFunc.apply(self.ckv, position_ids, compressed_kv) + + ckv_out = self.ckv + ctx_len = ckv_out.shape[-2] + ctx_indices = torch.arange(ctx_len)[None, ...] + gather_limit = position_ids.max(1, keepdim=True).values + invalid_mask = ctx_indices > gather_limit + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = torch.iinfo(torch.int32).max + else: + invalid_idx_value = 0 + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + + ckv_out = CtxGatherFunc.apply(ckv_out, ctx_indices, ctx_len) + ckv_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), ckv_out) + return ckv_out + + def update_k_pe(self, k_pe_cache, cache_kwargs): + position_ids = cache_kwargs.get("position_ids") + + self.k_pe = CtxScatterFunc.apply(self.k_pe, position_ids, k_pe_cache) + + k_pe_out = self.k_pe + ctx_len = k_pe_out.shape[-2] + ctx_indices = torch.arange(ctx_len)[None, ...] 
+ gather_limit = position_ids.max(1, keepdim=True).values + invalid_mask = ctx_indices > gather_limit + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = torch.iinfo(torch.int32).max + else: + invalid_idx_value = 0 + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + + k_pe_out = CtxGatherFunc.apply(k_pe_out, ctx_indices, ctx_len) + k_pe_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), k_pe_out) + return k_pe_out + + def read_only_blocked_ckv(self, start_index, end_index, cache_kwargs): + # Gather + ckv_out = self.ckv + position_ids = cache_kwargs.get("position_ids") + batch, num_kv_heads, _, _ = ckv_out.shape + ctx_indices = torch.arange(start=start_index, end=end_index)[None, None, ...] + gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = torch.iinfo(torch.int32).max + else: + invalid_idx_value = 0 + + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + + ctx_indices = ctx_indices.expand(batch, num_kv_heads, ctx_indices.shape[-1]) + ckv_out = CtxGatherFuncBlockedKV.apply(ckv_out, ctx_indices) + + ckv_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), ckv_out) + return ckv_out + + def read_only_blocked_k_pe(self, start_index, end_index, cache_kwargs): + # Gather + k_pe_out = self.k_pe + position_ids = cache_kwargs.get("position_ids") + batch, num_kv_heads, _, _ = k_pe_out.shape + ctx_indices = torch.arange(start=start_index, end=end_index)[None, None, ...] 
class QEffDynamicCompressedKVRopeCache:
    """
    Ordered collection of per-layer compressed-KV / rotary-key caches.

    ``layers`` holds one ``QEffDynamicCompressedKVRopeLayer`` per registered layer.
    Every cache operation is forwarded to the layer selected by ``layer_idx``.
    """

    def __init__(self):
        self.layers = []

    def add_new(self, ckv, k_pe, layer_idx):
        """Append a layer built from ``ckv`` and ``k_pe``; ``layer_idx`` is accepted but not used."""
        new_layer = QEffDynamicCompressedKVRopeLayer(ckv, k_pe)
        self.layers.append(new_layer)

    def update_ckv(self, ckv, layer_idx, cache_kwargs):
        """Forward to ``update_ckv`` of the layer at ``layer_idx``."""
        target = self.layers[layer_idx]
        return target.update_ckv(ckv, cache_kwargs)

    def update_k_pe(self, k_pe, layer_idx, cache_kwargs):
        """Forward to ``update_k_pe`` of the layer at ``layer_idx``."""
        target = self.layers[layer_idx]
        return target.update_k_pe(k_pe, cache_kwargs)

    def read_only_blocked_ckv(self, start_index, end_index, layer_idx, cache_kwargs):
        """Forward to ``read_only_blocked_ckv`` of the layer at ``layer_idx``."""
        target = self.layers[layer_idx]
        return target.read_only_blocked_ckv(start_index, end_index, cache_kwargs)

    def read_only_blocked_k_pe(self, start_index, end_index, layer_idx, cache_kwargs):
        """Forward to ``read_only_blocked_k_pe`` of the layer at ``layer_idx``."""
        target = self.layers[layer_idx]
        return target.read_only_blocked_k_pe(start_index, end_index, cache_kwargs)

    def write_only_ckv(self, ckv, layer_idx, cache_kwargs):
        """Forward to ``write_only_ckv`` of the layer at ``layer_idx``."""
        target = self.layers[layer_idx]
        return target.write_only_ckv(ckv, cache_kwargs)

    def write_only_k_pe(self, k_pe, layer_idx, cache_kwargs):
        """Forward to ``write_only_k_pe`` of the layer at ``layer_idx``."""
        target = self.layers[layer_idx]
        return target.write_only_k_pe(k_pe, cache_kwargs)

    @classmethod
    def from_legacy_cache(cls, past_key_values):
        """Build a cache from a sequence of ``(ckv, k_pe)`` pairs; ``None`` yields an empty cache."""
        cache = cls()
        if past_key_values is None:
            return cache
        for layer_idx, (ckv, k_pe) in enumerate(past_key_values):
            cache.add_new(ckv, k_pe, layer_idx)
        return cache

    def to_legacy_cache(self):
        """Return the cache as a tuple of ``(ckv, k_pe)`` pairs, one per layer, in layer order."""
        return tuple((layer.ckv, layer.k_pe) for layer in self.layers)
class DeepseekV3Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate a DeepSeek
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the DeepSeek-V3.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 129280):
            Vocabulary size of the DeepSeek model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`DeepseekV3Model`]
        hidden_size (`int`, *optional*, defaults to 7168):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 18432):
            Dimension of the MLP representations.
        moe_intermediate_size (`int`, *optional*, defaults to 2048):
            Dimension of the MoE representations.
        num_hidden_layers (`int`, *optional*, defaults to 61):
            Number of hidden layers in the Transformer decoder.
        num_nextn_predict_layers (`int`, *optional*, defaults to 1):
            Number of nextn predict layers in the DeepSeekV3 Model.
        num_attention_heads (`int`, *optional*, defaults to 128):
            Number of attention heads for each attention layer in the Transformer decoder.
        n_shared_experts (`int`, *optional*, defaults to 1):
            Number of shared experts, None means dense model.
        n_routed_experts (`int`, *optional*, defaults to 256):
            Number of routed experts, None means dense model.
        routed_scaling_factor (`float`, *optional*, defaults to 2.5):
            Scaling factor for routed experts.
        topk_method (`str`, *optional*, defaults to `"noaux_tc"`):
            Topk method used in routed gate.
        n_group (`int`, *optional*, defaults to 8):
            Number of groups for routed experts.
        topk_group (`int`, *optional*, defaults to 4):
            Number of selected groups for each token (for each token, ensuring the selected experts are only within `topk_group` groups).
        num_experts_per_tok (`int`, *optional*, defaults to 8):
            Number of selected experts, None means dense model.
        moe_layer_freq (`int`, *optional*, defaults to 1):
            The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
        first_k_dense_replace (`int`, *optional*, defaults to 3):
            Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
                                                            \--k dense layers--/
        norm_topk_prob (`bool`, *optional*, defaults to `True`):
            Whether to normalize the weights of the routed experts.
        scoring_func (`str`, *optional*, defaults to `"sigmoid"`):
            Method of computing expert weights.
        aux_loss_alpha (`float`, *optional*, defaults to 0.001):
            Auxiliary loss weight coefficient.
        seq_aux (`bool`, *optional*, defaults to `True`):
            Whether to compute the auxiliary loss for each individual sample.
        num_key_value_heads (`int`, *optional*, defaults to 128):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If `None` is passed, it falls back to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 4096):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 0):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 1):
            End of stream token id.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
            document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
            necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
            issue](https://github.com/pytorch/pytorch/issues/76232).
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
            `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.

    The following arguments are stored on the configuration unchanged. Their descriptions are inferred from the
    argument names and should be confirmed against the modeling code:

        ep_size (`int`, *optional*, defaults to 1):
            Presumably the expert-parallel size.
        kv_lora_rank (`int`, *optional*, defaults to 512):
            Presumably the rank of the compressed (low-rank) key/value projection.
        q_lora_rank (`int`, *optional*, defaults to 1536):
            Presumably the rank of the compressed (low-rank) query projection.
        qk_rope_head_dim (`int`, *optional*, defaults to 64):
            Presumably the per-head dimension of the rotary part of queries and keys.
        qk_nope_head_dim (`int`, *optional*, defaults to 128):
            Presumably the per-head dimension of the non-rotary part of queries and keys.
        v_head_dim (`int`, *optional*, defaults to 128):
            Presumably the per-head dimension of the values.

    ```python
    >>> from QEfficient.transformers.models.deepseek_v3.configuration_deepseek import DeepseekV3Config

    >>> # Initializing a Deepseek-V3 style configuration
    >>> configuration = DeepseekV3Config()

    >>> # Accessing a configuration value
    >>> configuration.num_hidden_layers
    61
    ```"""

    model_type = "deepseek_v3"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=129280,
        hidden_size=7168,
        intermediate_size=18432,
        moe_intermediate_size=2048,
        num_hidden_layers=61,
        num_nextn_predict_layers=1,
        num_attention_heads=128,
        num_key_value_heads=128,
        n_shared_experts=1,
        n_routed_experts=256,
        ep_size=1,
        routed_scaling_factor=2.5,
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=64,
        v_head_dim=128,
        qk_nope_head_dim=128,
        topk_method="noaux_tc",
        n_group=8,
        topk_group=4,
        num_experts_per_tok=8,
        moe_layer_freq=1,
        first_k_dense_replace=3,
        norm_topk_prob=True,
        scoring_func="sigmoid",
        aux_loss_alpha=0.001,
        seq_aux=True,
        hidden_act="silu",
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=0,
        eos_token_id=1,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        # Model-specific attributes are assigned before PretrainedConfig.__init__ runs.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_nextn_predict_layers = num_nextn_predict_layers
        self.num_attention_heads = num_attention_heads
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
        self.ep_size = ep_size
        self.routed_scaling_factor = routed_scaling_factor
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.topk_method = topk_method
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts_per_tok = num_experts_per_tok
        self.moe_layer_freq = moe_layer_freq
        self.first_k_dense_replace = first_k_dense_replace
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func
        self.aux_loss_alpha = aux_loss_alpha
        self.seq_aux = seq_aux
        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        # Token ids, embedding tying and any remaining kwargs are handled by the base class.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
class DeepseekV3RotaryEmbedding(nn.Module):
    """
    Rotary position embedding that keeps cos / sin tables as non-persistent buffers.

    The tables are built once in ``__init__`` for ``max_position_embeddings`` positions,
    but the cached length is then reset to ``None``, so the first ``forward`` call always
    rebuilds them for the requested ``seq_len``, device and dtype. Later calls rebuild only
    when ``seq_len`` exceeds the cached length.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # One inverse frequency per pair of dimensions: base ** -(2i / dim).
        exponents = torch.arange(0, self.dim, 2).float().to(device) / self.dim
        self.register_buffer("inv_freq", 1.0 / (self.base**exponents), persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )
        # Clearing the cached length forces a rebuild on the first forward() call.
        self.max_seq_len_cached = None

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild ``cos_cached`` / ``sin_cached`` for ``seq_len`` positions and record that length."""
        self.max_seq_len_cached = seq_len
        positions = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        angles = torch.outer(positions, self.inv_freq.to(positions.device))

        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return the first ``seq_len`` rows of the cos and sin tables, cast to ``x.dtype``."""
        # x: [bs, num_attention_heads, seq_len, head_size]
        cached_len = self.max_seq_len_cached
        if cached_len is None or seq_len > cached_len:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        cos = self.cos_cached[:seq_len].to(dtype=x.dtype)
        sin = self.sin_cached[:seq_len].to(dtype=x.dtype)
        return cos, sin
def orig_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Apply rotary position embedding to the query and key tensors.

    Before the rotation, the last axis of ``q`` and ``k`` is regrouped from interleaved
    pairs into two halves (all even-indexed elements first, then all odd-indexed ones).

    Args:
        q (`torch.Tensor`): Query tensor; must be 4-D.
        k (`torch.Tensor`): Key tensor; must be 4-D.
        cos (`torch.Tensor`): Cosine table, indexed by `position_ids`.
        sin (`torch.Tensor`): Sine table, indexed by `position_ids`.
        position_ids (`torch.Tensor`):
            Position indices of the tokens in `q` and `k`, used to select rows of `cos` and `sin`.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis along which the selected `cos` / `sin` rows are unsqueezed so they broadcast against
            `q` and `k`.

    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors.
    """

    def _pairs_to_halves(states):
        # (..., d) -> (..., d // 2, 2) -> swap the last two axes -> flatten back to (..., d)
        batch, heads, seq, head_dim = states.shape
        paired = states.view(batch, heads, seq, head_dim // 2, 2)
        return paired.transpose(4, 3).reshape(batch, heads, seq, head_dim)

    cos_rows = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_rows = sin[position_ids].unsqueeze(unsqueeze_dim)

    q_halves = _pairs_to_halves(q)
    k_halves = _pairs_to_halves(k)

    q_embed = (q_halves * cos_rows) + (rotate_half(q_halves) * sin_rows)
    k_embed = (k_halves * cos_rows) + (rotate_half(k_halves) * sin_rows)
    return q_embed, k_embed
self.qk_nope_head_dim).transpose(0, 1) + per_head_k_up = ( + self.k_up.squeeze(0).view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1).transpose(1, 2) + ) + per_head_v_up = self.v_up.squeeze(0).view(-1, self.num_heads, self.v_head_dim).transpose(0, 1) + self.per_head_v_up = torch.nn.Parameter(per_head_v_up.unsqueeze(0).detach().clone()) + self.per_head_q_up = torch.nn.Parameter(per_head_q_up.unsqueeze(0).detach().clone()) + self.per_head_k_up = torch.nn.Parameter(per_head_k_up.unsqueeze(0).detach().clone()) + per_head_k_up_normal = self.per_head_k_up.transpose(2, 3) + self.per_head_k_up_normal = torch.nn.Parameter(per_head_k_up_normal.detach().clone()) + + fusedqk = torch.bmm(per_head_q_up, per_head_k_up).reshape( + -1, self.num_heads, self.q_lora_rank, self.kv_lora_rank + ) + self.fusedqk = torch.nn.Parameter(fusedqk.detach().clone()) + + def fused_forward_h_blocking( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + compressed_kvs: Optional[torch.Tensor] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + mla_absorption: Optional[Dict[str, bool]] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank + self.qk_rope_head_dim).transpose(1, 2) + + kva = compressed_kv[:, :, :, : self.kv_lora_rank] + k_pe = compressed_kv[:, :, :, self.kv_lora_rank :] + + q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) + q_pe = torch.matmul(q_a_proj_out, self.q_rope) + q_pe = q_pe.view(bsz, q_len, self.num_heads, 
self.qk_rope_head_dim).transpose(1, 2) + q_nope = torch.bmm(q_a_proj_out, self.q_up) + q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) + + kva = self.kv_a_layernorm(kva) + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + window_cache_layer_idx = self.layer_idx - getattr(QEffDeepseekV3Model, "_start", 0) + + if compressed_kvs is not None: + kva = compressed_kvs.update_ckv(kva, window_cache_layer_idx, cache_kwargs) + + cos, sin = self.rotary_emb(kva, seq_len=32 * 1024) + q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + if compressed_kvs is not None: + k_pe = compressed_kvs.update_k_pe(k_pe, window_cache_layer_idx, cache_kwargs) + + blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) + + attn_output, attn_weights = generic_blocked_mla_attention_interface( + module=self, + q_a_proj_out=q_a_proj_out, + fusedqk=self.fusedqk, + q_nope=q_nope, + q_pe=q_pe, + kva=kva, + k_pe=k_pe, + per_head_q_up=self.per_head_q_up, + per_head_k_up=self.per_head_k_up, + per_head_v_up=self.per_head_v_up, + per_head_k_up_normal=self.per_head_k_up_normal, + attention_mask=attention_mask, + scaling=self.softmax_scale, + mla_absorption=mla_absorption, + blocking_config=blocking_config, + position_ids=position_ids, + **kwargs, + ) + + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights, compressed_kvs + + def fused_forward_kv_blocking( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + compressed_kvs: Optional[torch.Tensor] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, 
+ mla_absorption: Optional[Dict[str, bool]] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank + self.qk_rope_head_dim).transpose(1, 2) + + kva = compressed_kv[:, :, :, : self.kv_lora_rank] + k_pe = compressed_kv[:, :, :, self.kv_lora_rank :] + + q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) + q_pe = torch.matmul(q_a_proj_out, self.q_rope) + q_pe = q_pe.view(bsz, q_len, self.num_heads, self.qk_rope_head_dim).transpose(1, 2) + + kva = self.kv_a_layernorm(kva) + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + window_cache_layer_idx = self.layer_idx - getattr(QEffDeepseekV3Model, "_start", 0) + + ## Write Only + if compressed_kvs is not None: + compressed_kvs.write_only_ckv(kva, window_cache_layer_idx, cache_kwargs) + + cos, sin = self.rotary_emb(hidden_states, seq_len=32 * 1024) + q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + if compressed_kvs is not None: + compressed_kvs.write_only_k_pe(k_pe, window_cache_layer_idx, cache_kwargs) + + if mla_absorption is not None: + absorption = mla_absorption.get("absorption", False) + online = mla_absorption.get("online", False) + else: + absorption = False + + if absorption: + if online: + qup_kupT = torch.matmul(self.per_head_q_up, self.per_head_k_up) + dq_qup_kupT = torch.matmul(q_a_proj_out, qup_kupT) + else: + dq_qup_kupT = torch.matmul(q_a_proj_out, self.fusedqk) + qkupTrope_nope = torch.cat((dq_qup_kupT, q_pe), dim=-1) + query = qkupTrope_nope + else: + q_nope = torch.bmm(q_a_proj_out, self.q_up) + q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) + qnope_rope = torch.cat((q_nope, q_pe), dim=-1) + query = qnope_rope + + blocking_config = getattr(self, "attn_blocking_config", 
AttentionBlockingConfig()) + + attn_output, attn_weights = generic_blocked_mla_attention_interface( + module=self, + query=query, + per_head_k_up_normal=self.per_head_k_up_normal, + per_head_v_up=self.per_head_v_up, + attention_mask=attention_mask, + scaling=self.softmax_scale, + layer_idx=self.layer_idx, + compressed_kvs=compressed_kvs, + mla_absorption=mla_absorption, + blocking_config=blocking_config, + position_ids=position_ids, + **kwargs, + ) + + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + return attn_output, None, compressed_kvs + + def fused_forward_orig( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + compressed_kvs: Optional[torch.Tensor] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + mla_absorption: Optional[Dict[str, bool]] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + # ---- KV compression ---- + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank + self.qk_rope_head_dim).transpose(1, 2) + + kva = compressed_kv[:, :, :, : self.kv_lora_rank] + k_pe = compressed_kv[:, :, :, self.kv_lora_rank :] + + # ---- Q projections ---- + q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states)) + + q_pe = torch.bmm(q_a_proj_out, self.q_rope) + q_pe = q_pe.view(bsz, q_len, self.num_heads, self.qk_rope_head_dim).transpose(1, 2) + + kva = self.kv_a_layernorm(kva) + + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + window_cache_layer_idx = self.layer_idx - getattr(QEffDeepseekV3Model, 
"_start", 0) + + if compressed_kvs is not None: + kva = compressed_kvs.update_ckv(kva, window_cache_layer_idx, cache_kwargs) + + # ---- MLA absorption flags ---- + if mla_absorption is not None: + absorption = mla_absorption.get("absorption", False) + online = mla_absorption.get("online", False) + else: + absorption = False + + head_block_size = kva.shape[1] + p = self.num_heads // head_block_size + seq_kv = kva.shape[2] + + # ---- Rotary ---- + cos, sin = self.rotary_emb(q_pe, seq_len=32 * 1024) # Doesn't need q_pe as head_dim is initialized + q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + if compressed_kvs is not None: + k_pe = compressed_kvs.update_k_pe(k_pe, window_cache_layer_idx, cache_kwargs) + + kva_expanded = ( + kva.unsqueeze(2).expand(-1, -1, p, -1, -1).reshape(bsz, self.num_heads, seq_kv, self.kv_lora_rank) + ) + + k_pe_expanded = ( + k_pe.unsqueeze(2).expand(-1, -1, p, -1, -1).reshape(bsz, self.num_heads, seq_kv, self.qk_rope_head_dim) + ) + + v_up_per_head = self.v_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.v_head_dim).permute(1, 0, 2) + value_states = torch.matmul(kva_expanded, v_up_per_head) + + if absorption: + if online: + out = torch.matmul(self.per_head_q_up, self.per_head_k_up) + q_nope_compressed = torch.matmul(q_a_proj_out.unsqueeze(1), out) + else: + q_nope_compressed = torch.matmul( + q_a_proj_out.unsqueeze(1), + self.fusedqk, + ) + query_states = torch.cat((q_nope_compressed, q_pe), dim=-1) + key_states = torch.cat((kva_expanded, k_pe_expanded), dim=-1) + else: + q_nope = torch.bmm(q_a_proj_out, self.q_up) + q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2) + query_states = torch.cat((q_nope, q_pe), dim=-1) + + k_up_per_head = ( + self.k_up.squeeze(0).view(self.kv_lora_rank, self.num_heads, self.qk_nope_head_dim).permute(1, 0, 2) + ) + k_nope = torch.matmul(kva_expanded, k_up_per_head) + key_states = torch.cat((k_nope, k_pe_expanded), dim=-1) + + 
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale + + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, + torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=attn_weights.dtype), + attn_weights, + ) + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float16).to(q_pe.dtype) + ## Do v_proj here + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) + + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, compressed_kvs + + def fused_forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + compressed_kvs: Optional[torch.Tensor] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + mla_absorption: Optional[Dict[str, bool]] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) + if getattr(blocking_config, "mode", None) == "h": + return self.fused_forward_h_blocking( + hidden_states, + position_embeddings, + attention_mask, + position_ids, + past_key_value, + compressed_kvs, + batch_index, + output_attentions, + use_cache, + cache_position, + mla_absorption, + **kwargs, + ) + elif getattr(blocking_config, "mode", None) == "kv": + return self.fused_forward_kv_blocking( + hidden_states, + position_embeddings, + attention_mask, + position_ids, + past_key_value, + compressed_kvs, + batch_index, + output_attentions, + use_cache, + cache_position, + mla_absorption, + **kwargs, + ) + else: + return 
self.fused_forward_orig( + hidden_states, + position_embeddings, + attention_mask, + position_ids, + past_key_value, + compressed_kvs, + batch_index, + output_attentions, + use_cache, + cache_position, + mla_absorption, + **kwargs, + ) + + def forward_full_kv( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + + q_nope = q[:, :, :, : self.qk_nope_head_dim] + q_pe = q[:, :, :, self.qk_nope_head_dim :] + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + + kva = compressed_kv[:, :, : self.kv_lora_rank] + k_pe = compressed_kv[:, :, self.kv_lora_rank :] + + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(kva)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope = kv[:, :, :, : self.qk_nope_head_dim] + value_states = kv[:, :, :, self.qk_nope_head_dim :] + + cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) + q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = torch.cat((q_nope, q_pe), -1) + k_pe_new = k_pe.expand(-1, self.num_heads, -1, -1) + key_states = torch.cat((k_nope, k_pe_new), -1) + window_cache_layer_idx = self.layer_idx - getattr(QEffDeepseekV3Model, "_start", 0) + 
+ if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + key_states, value_states = past_key_value.update( + key_states, value_states, window_cache_layer_idx, cache_kwargs + ) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale + + if attention_mask is not None: # no matter the length, we just slice it + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float16), attn_weights + ) + + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float16).to(query_states.dtype) + attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads * self.v_head_dim) + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights, past_key_value + + def forward_full_kv_h_blocking( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + + q_nope = q[:, :, :, : self.qk_nope_head_dim] + q_pe = q[:, :, :, self.qk_nope_head_dim :] + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv = 
compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank + self.qk_rope_head_dim).transpose(1, 2) + + kva = compressed_kv[:, :, :, : self.kv_lora_rank] + k_pe = compressed_kv[:, :, :, self.kv_lora_rank :] + + kv = ( + self.kv_b_proj(self.kv_a_layernorm(kva)) + .view( + bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) # TODO : split this matmul #with k_up and v_up + .transpose(1, 2) + ) + + k_nope = kv[:, :, :, : self.qk_nope_head_dim] + value_states = kv[:, :, :, self.qk_nope_head_dim :] + + cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) + q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = torch.cat((q_nope, q_pe), -1) + k_pe_new = k_pe.expand(-1, self.num_heads, -1, -1) + key_states = torch.cat((k_nope, k_pe_new), -1) + + blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) + + attn_output, attn_weights = generic_blocked_attention_interface( + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + scaling=self.softmax_scale, + layer_idx=self.layer_idx, + past_key_value=past_key_value, + blocking_config=blocking_config, + batch_index=batch_index, + position_ids=position_ids, + ) + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights, past_key_value + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + blocking_config = getattr(self, "attn_blocking_config", 
class QEffDeepseekV3MoE(nn.Module):
    """DeepseekV3 MoE block that evaluates the routed experts with batched matmuls.

    ``__qeff_init__`` stacks every expert's projection weights into three
    parameters (``all_gate_proj``, ``all_up_proj``, ``all_down_proj``) so that
    ``moe`` can gather the weights of the selected experts by index and run
    them with ``torch.bmm`` instead of looping over experts.
    """

    def _stack_expert_weights(self, proj_name: str) -> torch.nn.Parameter:
        """Stack the transposed, decompressed ``proj_name`` weight of every expert.

        NOTE(review): assumes ``compressor.decompress_module`` returns the
        dense 2-D weight of the projection -- confirm against the compressor.
        """
        stacked = [
            getattr(exp, proj_name).compressor.decompress_module(getattr(exp, proj_name)).T.unsqueeze(0)
            for exp in self.experts
        ]
        return torch.nn.Parameter(torch.cat(stacked, dim=0))

    def __qeff_init__(
        self,
    ):
        """Build the stacked expert weights and cache the shared activation."""
        self.all_gate_proj = self._stack_expert_weights("gate_proj")
        self.all_up_proj = self._stack_expert_weights("up_proj")
        self.all_down_proj = self._stack_expert_weights("down_proj")
        # All experts are taken to share the activation of the first one.
        self.act_fn = self.experts[0].act_fn

    def moe(
        self,
        hidden_states: torch.Tensor,
        topk_indices: torch.Tensor,
        topk_weights: torch.Tensor,
    ):
        """Run the selected experts for every token and combine their outputs.

        Args:
            hidden_states: Token activations; any leading dimensions are
                flattened so the tensor is treated as ``(tokens, hidden)``.
            topk_indices: Expert ids per token; flattened to select one weight
                matrix per (token, slot) pair.
            topk_weights: Routing weights, broadcast against the per-slot
                expert outputs of shape ``(tokens, top_k, hidden)``.

        Returns:
            Tensor of shape ``(tokens, hidden)`` in the dtype of
            ``hidden_states``: the weighted sum of the expert outputs.
        """
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        num_tokens = hidden_states.shape[0]
        top_k = self.gate.top_k
        hidden_size = self.config.hidden_size

        # Gather per-(token, slot) expert weights with a single flattened index.
        expert_ids = topk_indices.flatten()
        gate_proj = self.all_gate_proj[expert_ids]
        up_proj = self.all_up_proj[expert_ids]
        down_proj = self.all_down_proj[expert_ids]

        # Repeat each token once per selected expert: (tokens * top_k, 1, hidden).
        expert_in = hidden_states.unsqueeze(1).expand(-1, top_k, -1).contiguous().view(-1, 1, hidden_size)

        gate_out = torch.bmm(expert_in, gate_proj)
        up_out = torch.bmm(expert_in, up_proj)
        hidden = self.act_fn(gate_out) * up_out
        expert_output = torch.bmm(hidden, down_proj)

        experts_out = expert_output.view(num_tokens, top_k, hidden_size)
        experts_out = experts_out * topk_weights.unsqueeze(-1)

        # Sum over the top_k slot axis.
        final_hidden_states = torch.einsum("abc->ac", experts_out)

        return final_hidden_states.type(hidden_states.dtype)

    def forward(self, hidden_states):
        """Route tokens through the gate, run the experts and add the shared experts."""
        residuals = hidden_states
        orig_shape = hidden_states.shape
        topk_indices, topk_weights = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
        hidden_states = hidden_states + self.shared_experts(residuals)
        return hidden_states
current_hidden_states + + return final_hidden_states.type(hidden_states.dtype) + + def orig_moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor): + r""" + CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused + to not have to do a loop here (deepseek has 256 experts soooo yeah). + """ + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) + expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts)) + expert_mask = expert_mask.permute(2, 0, 1) + for expert_idx in range(len(self.experts)): + expert = self.experts[expert_idx] + mask = expert_mask[expert_idx] + token_indices, weight_indices = torch.where(mask) + + if token_indices.numel() > 0: + expert_weights = topk_weights[token_indices, weight_indices] + expert_input = hidden_states[token_indices] + expert_output = expert(expert_input) + weighted_output = expert_output * expert_weights.unsqueeze(-1) + final_hidden_states.index_add_(0, token_indices, weighted_output) + + # in original deepseek, the output of the experts are gathered once we leave this module + # thus the moe module is itelsf an IsolatedParallel module + # and all expert are "local" meaning we shard but we don't gather + return final_hidden_states.type(hidden_states.dtype) + + def forward(self, hidden_states): + """ + Forward pass of MoE block. 
+ """ + residuals = hidden_states + orig_shape = hidden_states.shape + topk_indices, topk_weights = self.gate(hidden_states) + # orig_out = self.orig_moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) + + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + mask = torch.zeros(hidden_states.shape[0], self.config.n_routed_experts) + mask.scatter_(1, topk_indices, topk_weights) + if os.environ.get("NUM_FFN_BLOCKS", None) is not None and os.environ.get("FFN_W_BLOCK_SIZE", None) is not None: + hidden_states = self.moe_blocked_weights_forward( + hidden_states, topk_weights, mask, self.config.n_routed_experts + ).view(*orig_shape) + elif os.environ.get("NUM_FFN_BLOCKS", None) is not None: + hidden_states = self.moe_blocked_forward( + hidden_states, topk_weights, mask, self.config.n_routed_experts + ).view(*orig_shape) + else: + hidden_states = self.moe(hidden_states, topk_weights, mask, self.config.n_routed_experts).view(*orig_shape) + + hidden_states = hidden_states + self.shared_experts(residuals) + return hidden_states + + +class QEffDeepseekV3DecoderLayer(nn.Module): + """Adapted DeepseekV3DecoderLayer with batch_index and proper position_ids handling.""" + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + compressed_kvs: Optional[torch.Tensor] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + mla_absorption: Optional[Dict[str, bool]] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + orig_hidden_states = self.input_layernorm(hidden_states) + if mla_absorption is not None: + cache_compressed = 
class QEffDeepseekV3Model(nn.Module):
    """Adapted DeepseekV3Model with batch_index and QEff rotary embedding.

    The class attributes ``_start`` / ``_end`` define a half-open window of
    decoder layers: ``forward`` only runs layers whose index satisfies
    ``_start <= index < _end``. ``_total_layers`` is the layer count of the
    full model; when it is ``None`` the length of ``self.layers`` is used. The
    final norm is applied only when ``_end`` equals that total.
    """

    _start = 0
    _end = 0
    _total_layers = None

    def __qeff_init__(self):
        """Replace ``self.rotary_emb`` with a YaRN rotary embedding built from the config."""
        rope_scaling = self.config.rope_scaling
        scaling_factor = rope_scaling["factor"]
        optional_keys = (
            "original_max_position_embeddings",
            "beta_fast",
            "beta_slow",
            "mscale",
            "mscale_all_dim",
        )
        # Only forward the YaRN settings that the config actually provides.
        kwargs = {key: rope_scaling[key] for key in optional_keys if key in rope_scaling}
        self.rotary_emb = DeepseekV3YarnRotaryEmbedding(
            self.config.qk_rope_head_dim,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            scaling_factor=scaling_factor,
            base=self.config.rope_theta,
            **kwargs,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        compressed_kvs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        batch_index: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        mla_absorption: Optional[Dict[str, bool]] = None,
        layer_indices_to_run: Optional[List[int]] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder layers inside the current ``[_start, _end)`` window.

        Args:
            input_ids / inputs_embeds: Exactly one of the two must be given.
            attention_mask: Accepted for interface compatibility; the causal
                mask passed to the layers is rebuilt from ``position_ids``.
            compressed_kvs: Legacy compressed-KV cache; converted with
                ``QEffDynamicCompressedKVRopeCache.from_legacy_cache`` when
                ``mla_absorption["cache_compressed"]`` is truthy.
            past_key_values: Regular KV cache (``Cache`` or legacy list).
            mla_absorption: Flags forwarded to the layers; only
                ``"cache_compressed"`` is read here.
            layer_indices_to_run: Optional extra filter on layer indices.

        Returns:
            ``BaseModelOutputWithPast`` or, when ``return_dict`` is false, a
            tuple of the non-``None`` values among hidden states, cache,
            all hidden states and attentions.

        Raises:
            ValueError: If both or neither of ``input_ids`` / ``inputs_embeds``
                are given, or if no cache is available to derive the context
                length for the causal mask.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and not isinstance(past_key_values, Cache) and past_key_values is not None:
            past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)

        cache_compressed = mla_absorption.get("cache_compressed", False) if mla_absorption is not None else False

        if cache_compressed:
            compressed_kvs = QEffDynamicCompressedKVRopeCache.from_legacy_cache(compressed_kvs)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        start = QEffDeepseekV3Model._start
        end = QEffDeepseekV3Model._end

        # Context length for the causal mask comes from whichever cache is in
        # use; previously it was always read from the compressed cache, which
        # failed on the regular KV-cache path.
        if cache_compressed or hasattr(compressed_kvs, "layers"):
            ctx_len = compressed_kvs.layers[0].ckv.shape[-2]
        elif isinstance(past_key_values, Cache):
            ctx_len = past_key_values.get_seq_length()
        elif past_key_values is not None:
            # Legacy layout: per-layer (key, value); dim 2 of the key is taken
            # as the sequence axis -- NOTE(review): confirm against callers.
            ctx_len = past_key_values[0][0].shape[2]
        else:
            raise ValueError(
                "Cannot derive the context length for the causal mask: "
                "provide compressed_kvs (with cache_compressed enabled) or past_key_values"
            )

        causal_mask = _create_causal_mask(position_ids=position_ids, target_length=ctx_len)
        hidden_states = inputs_embeds
        position_embeddings = None

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for layer_idx, decoder_layer in enumerate(self.layers):
            if layer_idx < start or layer_idx >= end:
                continue
            if layer_indices_to_run is not None and layer_idx not in layer_indices_to_run:
                continue
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                compressed_kvs=compressed_kvs,
                past_key_value=past_key_values,
                batch_index=batch_index,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                mla_absorption=mla_absorption,
                **kwargs,
            )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # The class attribute exists with value None, so a getattr default
        # would never be used; fall back explicitly.
        total_layers = QEffDeepseekV3Model._total_layers
        if total_layers is None:
            total_layers = len(self.layers)
        if end == total_layers:
            hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # No cache is produced when caching is off or no layer ran.
        next_cache = None
        if use_cache and next_decoder_cache is not None:
            next_cache = next_decoder_cache.to_legacy_cache()

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
else self.config.use_return_dict + mla_absorption = getattr(self, "mla_absorption", None) + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + compressed_kvs=compressed_kvs, + past_key_values=past_key_values, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + mla_absorption=mla_absorption, + layer_indices_to_run=layer_indices_to_run, + **kwargs, + ) + + hidden_states = outputs[0] + total_layers = getattr(QEffDeepseekV3Model, "_total_layers", len(self.model.layers)) + if QEffDeepseekV3Model._end < total_layers: + logits = hidden_states + else: + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = hidden_states[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + logits = self.lm_head(hidden_states).float() + + loss = None + if labels is not None: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = nn.CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1).to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def get_dummy_pkv_cache(self, config, batch_size, seq_len): + mla_absorption = getattr(self, "mla_absorption", None) + if mla_absorption is not None: + cache_compressed = mla_absorption.get("cache_compressed", False) + else: + cache_compressed = False + + dummy_cache = [[] for _ in range(config.num_hidden_layers)] + if cache_compressed: + for layer in 
self.model.layers: + if layer is not None: + num_heads = layer.self_attn.kv_a_proj_with_mqa.weight.shape[0] // ( + self.model.config.kv_lora_rank + config.qk_rope_head_dim + ) + cache_shape_1 = (batch_size, num_heads, seq_len, config.kv_lora_rank) + cache_shape_2 = (batch_size, num_heads, seq_len, config.qk_rope_head_dim) + else: + cache_shape_1 = ( + batch_size, + config.num_attention_heads, + seq_len, + config.qk_nope_head_dim + config.qk_rope_head_dim, + ) + cache_shape_2 = (batch_size, config.num_attention_heads, seq_len, config.v_head_dim) + + for i in range(config.num_hidden_layers): + dummy_cache[i].append(torch.zeros(cache_shape_1, dtype=config.torch_dtype)) + dummy_cache[i].append(torch.zeros(cache_shape_2, dtype=config.torch_dtype)) + + return dummy_cache diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 10dc5ddd90..e6561178a4 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -9,7 +9,7 @@ import warnings from pathlib import Path from time import perf_counter -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union import numpy as np import torch @@ -51,8 +51,10 @@ KVCacheTransform, PoolingTransform, PrefillOnlyChunkedTransform, + PrefillOnlyExternalModuleMapperTransform, PrefillOnlyTransform, RevertPrefillKeepAttentionTransform, + RevertPrefillOnlyExternalModuleMapperTransform, RevertPrefillOnlyTransform, SamplerTransform, SpDTransform, @@ -2770,12 +2772,14 @@ def prefill( retain_full_kv: Optional[bool] = False, ): if enable: + self.model, tf = PrefillOnlyExternalModuleMapperTransform.apply(self.model) if enable_chunking: self.model, tf = PrefillOnlyChunkedTransform.apply(self.model) else: self.model, tf = PrefillOnlyTransform.apply(self.model) else: + self.model, tf = RevertPrefillOnlyExternalModuleMapperTransform.apply(self.model) if retain_full_kv: self.model, tf = 
RevertPrefillKeepAttentionTransform.apply(self.model) else: @@ -2788,12 +2792,14 @@ def __update_prefill_transform( retain_full_kv: Optional[bool] = False, ): if enable: + self.model, tf = PrefillOnlyExternalModuleMapperTransform.apply(self.model) if enable_chunking: self.model, tf = PrefillOnlyChunkedTransform.apply(self.model) else: self.model, tf = PrefillOnlyTransform.apply(self.model) else: + self.model, tf = RevertPrefillOnlyExternalModuleMapperTransform.apply(self.model) if retain_full_kv: self.model, tf = RevertPrefillKeepAttentionTransform.apply(self.model) else: @@ -2868,6 +2874,9 @@ def __init__( self.ccl_enabled = False if qaic_config: self.ccl_enabled = qaic_config.get("ccl_enabled", False) + if mla_absorption := qaic_config.get("mla_absorption", None): + self.hash_params["mla_absorption"] = mla_absorption + setattr(self.model, "mla_absorption", mla_absorption) self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None self.hash_params["max_seq_len_cached"] = max_seq_len_cached @@ -3074,10 +3083,13 @@ def export( block_size = -(-seq_len // max_blocks) seq_len = block_size * max_blocks fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + kv_cache_shape = get_padding_shape_from_config( self.model.config, fbs if self.continuous_batching else bs, seq_len ) enable_chunking = kwargs.get("enable_chunking", False) + + # TODO: move this to a DA Serving utility class if self.model.config.model_type in SPECIALIZED_DISAGG_SERVING_MODEL_ARCH: if prefill_only: if not enable_chunking and self.continuous_batching: @@ -3179,6 +3191,41 @@ def export( dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i] output_names.append(f"past_{kv}.{i}_RetainedState") + if "DeepseekV3ForCausalLM" in (getattr(self.model.config, "architectures", None) or []): + if self.model.qaic_config is not None and self.model.qaic_config.get("mla_absorption", None) is not None: + mla_absorption = self.model.qaic_config["mla_absorption"] + cache_compressed = 
mla_absorption.get("cache_compressed", False) + else: + cache_compressed = False + pkv_cache = self.model.get_dummy_pkv_cache( + self.model.config, fbs if self.continuous_batching else bs, seq_len + ) + if cache_compressed: + example_inputs = {k: v for k, v in example_inputs.items() if "past" not in k} + dynamic_axes = {k: v for k, v in dynamic_axes.items() if "past" not in k} + output_names = [v for v in output_names if "past" not in v] + example_inputs["compressed_kvs"] = [[] for _ in range(self.num_layers)] + for i in range(self.num_layers): + example_inputs["compressed_kvs"][i].append( + torch.zeros(pkv_cache[0][0].shape, dtype=self.model.config.torch_dtype) + ) + example_inputs["compressed_kvs"][i].append( + torch.zeros(pkv_cache[0][1].shape, dtype=self.model.config.torch_dtype) + ) + dynamic_axes[f"compressed_kv.{i}"] = {0: "batch_size", 2: "ctx_len"} + dynamic_axes[f"k_pe.{i}"] = {0: "batch_size", 2: "ctx_len"} + output_names.append(f"compressed_kv.{i}_RetainedState") + output_names.append(f"k_pe.{i}_RetainedState") + else: + example_inputs["past_key_values"] = [[] for _ in range(self.num_layers)] + for i in range(self.num_layers): + example_inputs["past_key_values"][i].append( + torch.zeros(pkv_cache[0][0].shape, dtype=self.model.config.torch_dtype) + ) + example_inputs["past_key_values"][i].append( + torch.zeros(pkv_cache[0][1].shape, dtype=self.model.config.torch_dtype) + ) + if self.continuous_batching: example_inputs["batch_index"] = torch.arange(bs).view(bs, 1) dynamic_axes["batch_index"] = {0: "batch_size"} @@ -3355,9 +3402,11 @@ def compile( offload_pt_weights: Optional[bool] = True, enable_chunking: Optional[bool] = False, retain_full_kv: Optional[bool] = None, + mla_absorption: Optional[Dict[str, bool]] = None, **compiler_options, ) -> str: """ + Compile the exported ONNX model using the Cloud AI 100 Platform SDK compiler. This method generates a ``qpc`` package. 
If the model has not been exported yet, @@ -3398,6 +3447,11 @@ def compile( the decode stage. If None, compiles for both stages. Default is None. use_onnx_subfunctions: bool, optional whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False + mla_absorption: Dict[str, bool], optional + Configuration dictionary for multi-head latent Attention (MLA) absorption behavior. + - "cache_compressed" (bool): If True, compresses kvs are cached to save memory. + - "absorption" (bool): If True, enables absorption of attention matrices for efficiency. + - "online" (bool): If True, applies MLA absorption on device during inference **compiler_options : dict Additional compiler options for QAIC or QNN compilers. @@ -3435,6 +3489,13 @@ def compile( If `prefill_seq_len` is less than `num_speculative_tokens + 1` for TLM models. """ + if self.model.qaic_config is not None and self.model.qaic_config.get("mla_absorption", None) is not None: + mla_absorption = self.model.qaic_config["mla_absorption"] + cache_compressed = mla_absorption.get("cache_compressed", False) + else: + cache_compressed = False + if mla_absorption is not None and not cache_compressed: + logger.warning("mla_absorption will be ignored as cache_compressed is set to False") if (kv_cache_batch_size or full_batch_size) and not self.continuous_batching: logger.warning( "`kv_cache_batch_size` or `full_batch_size` is being passed" @@ -3558,15 +3619,22 @@ def compile( if kw_spec := compiler_options.pop("specializations", None): specializations = kw_spec - # --- Compilation --- - custom_io = {} + target_dtype = getattr(self.model.config, "torch_dtype", torch.float32) kv_cache_dtype = "mxint8" if mxint8_kv_cache else CUSTOM_IO_DTYPE_MAP[target_dtype] + # --- Compilation --- + custom_io = {} + if not cache_compressed: + for suffix in ["", "_RetainedState"]: + for i in range(self.num_layers): + for kv in ["key", "value"]: 
+ custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype + else: + for suffix in ["", "_RetainedState"]: + for i in range(self.num_layers): + custom_io[f"compressed_kv.{i}{suffix}"] = kv_cache_dtype + custom_io[f"k_pe.{i}{suffix}"] = kv_cache_dtype - for suffix in ["", "_RetainedState"]: - for i in range(self.num_layers): - for kv in ["key", "value"]: - custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype qpc_path = self._compile( onnx_path=onnx_path, compile_dir=compile_dir, @@ -3585,6 +3653,7 @@ def compile( offload_pt_weights=offload_pt_weights, enable_chunking=enable_chunking, retain_full_kv=retain_full_kv, + mla_absorption=mla_absorption, **compiler_options, ) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index a5e16489fe..5ff06e6443 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -9,6 +9,7 @@ from types import MethodType from typing import Callable, Optional, Tuple, Union +import torch from torch import nn from transformers.models.codegen.modeling_codegen import ( CodeGenAttention, @@ -248,6 +249,15 @@ from QEfficient.transformers.models.deberta_v2.modeling_deberta_v2 import ( QEffDisentangledSelfAttention, ) +from QEfficient.transformers.models.deepseek_v3.modeling_deepseek import ( + QEffDeepseekV3Attention, + QEffDeepseekV3CustomRMSNormAIC, + QEffDeepseekV3DecoderLayer, + QEffDeepseekV3ForCausalLM, + QEffDeepseekV3Model, + QEffDeepseekV3MoE, + QEffPrefillOnlyDeepseekV3MoE, +) from QEfficient.transformers.models.falcon.modeling_falcon import ( QEffFalconAttention, QEffFalconDecoderLayer, @@ -499,6 +509,7 @@ from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry from QEfficient.transformers.sampler.sampler import sampler_forward from QEfficient.transformers.spd.spd_transform_forward import tlm_forward +from QEfficient.utils.logging_utils import logger SPD_TARGET = 
class ReplicateKVHeadTransform:
    """
    Replicates KV heads in attention modules to match the number of KV heads in the target model.
    This transform is used when the source model has fewer KV heads than required in target model.
    """

    @staticmethod
    def _duplicate_weights_for_linear_layer(
        layer: nn.Module, orig_kv_heads: int, repeat: int, dim: int, hidden_size: int
    ):
        """Repeat every KV head of ``layer`` ``repeat`` times, in place.

        The weight is viewed as ``(orig_kv_heads, dim, hidden_size)``, each head
        is repeated consecutively along the head axis, and the result is
        flattened back to ``(orig_kv_heads * repeat * dim, hidden_size)``.  The
        bias, when present, is expanded the same way.
        """
        # Total head count after replication; for MLA orig_kv_heads is 1, which
        # reduces to ``repeat``.
        new_kv_heads = orig_kv_heads * repeat

        layer.weight.data = torch.repeat_interleave(
            layer.weight.data.view(orig_kv_heads, dim, hidden_size), repeat, 0
        ).view(new_kv_heads * dim, hidden_size)

        if layer.bias is not None:
            layer.bias.data = torch.repeat_interleave(layer.bias.data.view(orig_kv_heads, dim), repeat, 0).view(
                new_kv_heads * dim
            )

    @staticmethod
    def _get_text_model(model):
        """
        Determine and return the appropriate text_model from a given model object.
        """
        # Check for VLMs
        if hasattr(model, "language_model"):
            if hasattr(model.language_model, "model"):
                return model.language_model.model
            return model.language_model
        # Check for CausalLMs
        if hasattr(model, "model"):
            return model.model

        raise AttributeError("No suitable text model found in the provided model.")

    @classmethod
    def apply(cls, model: nn.Module, num_kv_heads_repeat: int = 1) -> Tuple[nn.Module, bool]:
        """
        Replicates KV heads in attention modules based on provided multiplier.

        Args:
            model: The model to apply the transform to.
            num_kv_heads_repeat: The number of times to repeat the KV heads.

        Returns:
            ``(model, transformed)``; ``transformed`` is True only when
            ``num_kv_heads_repeat`` is greater than 1.
        """
        transformed = False
        if num_kv_heads_repeat is not None and num_kv_heads_repeat > 1:
            text_model = cls._get_text_model(model)

            orig_kv_heads = 1  # for mla #text_model.config.num_key_value_heads
            new_kv_heads = num_kv_heads_repeat * orig_kv_heads
            text_model.config.orig_kv_heads = orig_kv_heads
            text_model.config.num_key_value_heads = new_kv_heads

            hidden_size = text_model.config.hidden_size

            logger.warning(f"Original KV heads: {orig_kv_heads}")
            logger.warning(f"Modified KV heads: {new_kv_heads}")
            transformed = True
            for block in text_model.layers:
                attn = getattr(block, "cross_attn", getattr(block, "self_attn", None))
                if attn is None:
                    # Unpopulated or attention-less entry: nothing to replicate.
                    continue
                attn.num_key_value_heads = new_kv_heads
                head_dim = attn.kv_lora_rank + attn.qk_rope_head_dim

                cls._duplicate_weights_for_linear_layer(
                    attn.kv_a_proj_with_mqa, orig_kv_heads, num_kv_heads_repeat, head_dim, hidden_size
                )
        return model, transformed
@@ -886,6 +967,7 @@ class VlmNoKVOffloadTransform(ModuleMappingTransform): class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): + _match_class_replace_method = {} _match_string_replace_method = { "InternVLChatModel": { "forward": QEffInternVLModel.forward, @@ -941,9 +1023,56 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): "RMSNorm": { "forward": QEFFGrok1CustomRMSNormAIC.forward, }, + "DeepseekV3ForCausalLM": { + "forward": QEffDeepseekV3ForCausalLM.forward, + "get_submodules_for_export": QEffDeepseekV3ForCausalLM.get_submodules_for_export, + "get_dummy_pkv_cache": QEffDeepseekV3ForCausalLM.get_dummy_pkv_cache, + }, + "DeepseekV3Model": {"forward": QEffDeepseekV3Model.forward, "__qeff_init__": QEffDeepseekV3Model.__qeff_init__}, + "DeepseekV3DecoderLayer": { + "forward": QEffDeepseekV3DecoderLayer.forward, + }, + "DeepseekV3MoE": { + "forward": QEffDeepseekV3MoE.forward, + "moe": QEffDeepseekV3MoE.moe, + "__qeff_init__": QEffDeepseekV3MoE.__qeff_init__, + }, + "DeepseekV3Attention": { + "forward": QEffDeepseekV3Attention.forward, + "forward_full_kv": QEffDeepseekV3Attention.forward_full_kv, + "forward_full_kv_h_blocking": QEffDeepseekV3Attention.forward_full_kv_h_blocking, + "fused_forward": QEffDeepseekV3Attention.fused_forward, + "fused_forward_h_blocking": QEffDeepseekV3Attention.fused_forward_h_blocking, + "fused_forward_kv_blocking": QEffDeepseekV3Attention.fused_forward_kv_blocking, + "fused_forward_orig": QEffDeepseekV3Attention.fused_forward_orig, + "__qeff_init__": QEffDeepseekV3Attention.__qeff_init__, + }, + "DeepseekV3RMSNorm": { + "forward": QEffDeepseekV3CustomRMSNormAIC.forward, + }, } + +class PrefillOnlyExternalModuleMapperTransform(ExternalModuleMapperTransform): _match_class_replace_method = {} + _match_string_replace_method = { + "DeepseekV3MoE": { + "forward": QEffPrefillOnlyDeepseekV3MoE.forward, + "moe": QEffPrefillOnlyDeepseekV3MoE.moe, + "__qeff_init__": 
QEffPrefillOnlyDeepseekV3MoE.__qeff_init__, + }, + } + + +class RevertPrefillOnlyExternalModuleMapperTransform(ExternalModuleMapperTransform): + _match_class_replace_method = {} + _match_string_replace_method = { + "DeepseekV3MoE": { + "forward": QEffDeepseekV3MoE.forward, + "moe": QEffDeepseekV3MoE.moe, + "__qeff_init__": QEffDeepseekV3MoE.__qeff_init__, + }, + } class T5ModelTransform(ModuleMappingTransform): @@ -1022,7 +1151,9 @@ def apply(cls, model: nn.Module, attn_blocking_config) -> Tuple[nn.Module, bool] if type(module) in cls._skip_classes: warnings.warn(f"Blocking is not yet supported for {type(module)}.") continue - if type(module) in supported_attention_classes: + if type(module) in supported_attention_classes or "DeepseekV3ForCausalLM" in ( + getattr(model.config, "architectures", None) or [] + ): module.attn_blocking_config = attn_blocking_config transformed = True elif module.__class__.__name__.endswith("Attention") and type(module) not in supported_attention_classes: diff --git a/QEfficient/utils/__init__.py b/QEfficient/utils/__init__.py index cfe17ac452..473d095381 100755 --- a/QEfficient/utils/__init__.py +++ b/QEfficient/utils/__init__.py @@ -38,7 +38,16 @@ require_value, to_named_specializations, ) +from QEfficient.utils.compile_layerwise import ( # noqa: F401 + run_compile_layerwise as compile_layerwise, +) from QEfficient.utils.hash_utils import ( # noqa: F401 create_export_hash, hash_dict_params, ) +from QEfficient.utils.inference_pipeline import ( # noqa: F401 + inference_pipeline, +) +from QEfficient.utils.layerwise_pipeline import ( # noqa: F401 + layerwise_pipeline, +) diff --git a/QEfficient/utils/compile_layerwise.py b/QEfficient/utils/compile_layerwise.py new file mode 100644 index 0000000000..81375472ae --- /dev/null +++ b/QEfficient/utils/compile_layerwise.py @@ -0,0 +1,238 @@ +import argparse +import os +import re +import signal +import subprocess +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from 
def _discover_onnx_jobs(base_onnx_dir: str):
    """Collect one compile job per layer-window ONNX file under ``base_onnx_dir``.

    ``<base_onnx_dir>/onnx_layerwise_tmp`` is scanned when it exists, otherwise
    ``base_onnx_dir`` itself.  Only sub-directories named ``layer_<start>_<end>``
    with ``end > start`` are considered, and inside them only regular files
    named ``DeepseekV3ForCausalLM_layer_tmp_*.onnx``.

    Jobs are returned ordered by numeric window and then by file name, so the
    submission order is the same on every run and file system.

    Returns:
        A list of ``(onnx_path, layer_dir, (start, end), layer_indices, device_group)``
        tuples, where ``layer_indices`` holds the string form of every index in
        ``range(start, end)`` and ``device_group`` is always ``"0"``.

    Raises:
        RuntimeError: If the directory does not exist or no ONNX file matches.
    """
    # agent: defer discovery to runtime and require explicit export path.
    base_dir_path = Path(base_onnx_dir)
    layerwise_dir = base_dir_path / "onnx_layerwise_tmp"
    if layerwise_dir.is_dir():
        scan_dir = layerwise_dir
    elif base_dir_path.is_dir():
        scan_dir = base_dir_path
    else:
        raise RuntimeError(f"BASE_ONNX_DIR does not exist: {base_onnx_dir}")

    layer_dir_pat = re.compile(r"^layer_(\d+)_(\d+)$")
    windows = []
    for layer_dir in scan_dir.iterdir():
        if not layer_dir.is_dir():
            continue
        m = layer_dir_pat.match(layer_dir.name)
        if not m:
            continue
        layer_start, layer_end = int(m.group(1)), int(m.group(2))
        if layer_end <= layer_start:
            # Empty or inverted window: nothing to compile.
            continue
        windows.append((layer_start, layer_end, layer_dir))

    # Numeric sort so layer_2_3 precedes layer_10_11 (a plain path sort would not).
    windows.sort(key=lambda w: (w[0], w[1], w[2].name))

    onnx_jobs = []
    for layer_start, layer_end, layer_dir in windows:
        layer_indices = [str(i) for i in range(layer_start, layer_end)]
        layer_window = (layer_start, layer_end)

        for f in sorted(layer_dir.iterdir()):
            if (
                f.is_file()
                and f.name.startswith("DeepseekV3ForCausalLM_layer_tmp_")
                and f.suffix == ".onnx"
            ):
                # device_group fixed to single device "0"
                onnx_jobs.append((f, layer_dir, layer_window, layer_indices, "0"))

    if not onnx_jobs:
        raise RuntimeError(f"No valid ONNX files found under: {scan_dir}")

    return onnx_jobs
def _terminate_process_group(proc):
    """Stop the process group led by ``proc`` and reap the child.

    Safe to call with ``None`` or with a process that has already exited.
    SIGTERM is sent first; SIGKILL follows if the child has not exited after a
    short grace period.
    """
    if proc is None or proc.poll() is not None:
        return
    try:
        os.killpg(proc.pid, signal.SIGTERM)
    except ProcessLookupError:
        # The group vanished between poll() and killpg().
        return
    try:
        proc.wait(timeout=10)
    except subprocess.TimeoutExpired:
        try:
            os.killpg(proc.pid, signal.SIGKILL)
        except ProcessLookupError:
            pass
        proc.wait()


def compile_one(job):
    """Compile one layer-window ONNX file into a QPC by running the compile CLI.

    Args:
        job: ``(onnx_path, layer_dir, layer_window, layer_indices, device_group)``
            as produced by ``_discover_onnx_jobs``.

    Returns:
        ``(layer_tag, status, elapsed_seconds)`` where ``status`` is ``"OK"``,
        ``"FAILED(rc=<code>)"``, ``"TIMEOUT"`` or ``"ERROR(<exception>)"``.

    Notes:
        Output of every attempt is appended to ``qpc_<layer_tag>.log`` next to
        the ONNX file.  Only non-zero exit codes are retried; timeouts and
        unexpected exceptions end the job immediately.
    """
    import sys

    onnx_path, layer_dir, layer_window, layer_indices, device_group = job

    layer_tag = onnx_path.stem.replace("DeepseekV3ForCausalLM_layer_tmp_", "")

    qpc_dir = layer_dir / f"qpc_{layer_tag}"
    log_file = layer_dir / f"qpc_{layer_tag}.log"
    qpc_dir.mkdir(parents=True, exist_ok=True)

    custom_io_yaml = layer_dir / "custom_io_fp16.yaml"
    if not custom_io_yaml.exists():
        write_custom_io_yaml(custom_io_yaml, layer_indices)

    cmd = [
        # Use the running interpreter so the subprocess sees the same
        # environment (and installed QEfficient) as this process.
        sys.executable or "python",
        "-m",
        "QEfficient.cloud.compile",
        "--onnx_path",
        str(onnx_path),
        "--qpc-path",
        str(qpc_dir),
        "--batch_size",
        "1",
        "--prompt_len",
        "1",
        "--ctx_len",
        "128",
        "--mxfp6",
        # NOTE(review): passed without a leading "--", unlike its neighbours;
        # confirm against the QEfficient.cloud.compile argument parser.
        "mxint8_kv_cache",
        "--num_cores",
        "16",
        "--device_group",
        device_group,
        "--mos",
        "1",
        "--aic_enable_depth_first",
        f"-custom-IO-list-file={custom_io_yaml}",
    ]

    total_start = time.time()
    last_status = "FAILED"

    for attempt in range(1, MAX_RETRIES + 1):
        print(
            f"[START ] layer {layer_window[0]}_{layer_window[1]} "
            f"device {device_group} (attempt {attempt}/{MAX_RETRIES})"
        )

        proc = None
        try:
            with open(log_file, "a") as lf:
                lf.write(f"\n===== ATTEMPT {attempt} =====\n")
                # Flush so the header lands before the child's own output,
                # which is written straight to the shared file descriptor.
                lf.flush()
                proc = subprocess.Popen(
                    cmd,
                    stdout=lf,
                    stderr=subprocess.STDOUT,
                    start_new_session=True,
                )
                proc.wait(timeout=TIMEOUT)

            if proc.returncode == 0:
                last_status = "OK"
                break
            last_status = f"FAILED(rc={proc.returncode})"

        except subprocess.TimeoutExpired:
            last_status = "TIMEOUT"
            _terminate_process_group(proc)
            break  # do not retry timeouts

        except KeyboardInterrupt:
            _terminate_process_group(proc)
            raise

        except Exception as e:
            last_status = f"ERROR({e})"
            _terminate_process_group(proc)
            break

        # Back off only when another attempt will actually follow.
        if attempt < MAX_RETRIES:
            time.sleep(RETRY_SLEEP)

    total_elapsed = time.time() - total_start

    print(f"[DONE  ] layer {layer_window[0]}_{layer_window[1]} {last_status} | {total_elapsed:.1f}s")

    return layer_tag, last_status, total_elapsed
__name__ == "__main__": + # agent: CLI now takes exported path instead of embedded machine-local constant. + parser = argparse.ArgumentParser(description="Compile layerwise ONNX windows into QPC artifacts.") + parser.add_argument("--base-onnx-dir", required=True, help="Export root containing onnx_layerwise_tmp/") + args = parser.parse_args() + run_compile_layerwise(args.base_onnx_dir) diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index cc0b87b604..339e4f4dac 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -136,8 +136,12 @@ def get_models_dir(): LLAMA4_ATTENTION_CHUNK_SIZE = 8192 LLAMA4_MAX_POSITION_EMBEDDINGS = 65536 -# Gemma3 Constant -GEMMA3_MAX_POSITION_EMBEDDINGS = 32768 +# DeepSeek Kimi-k2 Constant +MAX_POSITION_EMBEDDINGS = 32768 +FP16_BYTES = 2 +DEFAULT_NUM_HEADS = 64 +KV_LORA_RANK = 512 +ROPE_DIM = 64 # Wav2Vec2 Constant WAV2VEC2_MAX_SEQ_LEN = 480000 # 30 seconds of audio at 16 kHz sampling rate (16,000 samples/sec × 30 sec) diff --git a/QEfficient/utils/export_utils.py b/QEfficient/utils/export_utils.py index 5c4ee8054c..901484e724 100644 --- a/QEfficient/utils/export_utils.py +++ b/QEfficient/utils/export_utils.py @@ -166,7 +166,8 @@ def _setup_onnx_subfunctions(qeff_model, args, kwargs): if "output_names" in kwargs: kwargs["output_names"] = [ re.sub("_RetainedState", "_InternalRetainedState", name) - if name.endswith("_RetainedState") and ("key" in name or "value" in name) + if name.endswith("_RetainedState") + and ("key" in name or "value" in name or "compressed_kv" in name or "k_pe" in name) else name for name in kwargs["output_names"] ] diff --git a/QEfficient/utils/inference_pipeline.py b/QEfficient/utils/inference_pipeline.py new file mode 100644 index 0000000000..f7f6b64848 --- /dev/null +++ b/QEfficient/utils/inference_pipeline.py @@ -0,0 +1,263 @@ +from __future__ import annotations + +import argparse +import re +import time +from concurrent.futures import ThreadPoolExecutor +from 
def pick_hidden_input_name(session: QAICInferenceSession) -> Optional[str]:
    """Return the session input that should receive the previous layer's output.

    Resolution order:
      1. an input named exactly ``inputs_embeds`` or ``input_embeds``;
      2. the first input whose name contains one of those two strings
         (covers prefixed names);
      3. the first input that is not a position, cache, token-id or
         batch-index input.

    Position inputs are recognised by the substring ``position``, the same rule
    ``pick_pos_input_name`` uses, so a prefixed position input can never be
    mistaken for the hidden-state input.

    Returns:
        The chosen input name, or ``None`` when no input qualifies.
    """
    preferred_names = ("inputs_embeds", "input_embeds")
    for preferred in preferred_names:
        if preferred in session.input_names:
            return preferred
    for name in session.input_names:
        if any(preferred in name for preferred in preferred_names):
            return name

    # Substrings marking inputs that carry something other than hidden states.
    non_hidden_markers = ("position", "compressed_kv", "k_pe", "input_ids", "batch_index")
    for name in session.input_names:
        if any(marker in name for marker in non_hidden_markers):
            continue
        return name
    return None
def resolve_base_path(base_path: str | Path) -> Path:
    """Locate the model directory that holds ``onnx_layerwise_tmp``.

    ``base_path`` is returned as-is when it directly contains that folder.
    Otherwise its immediate sub-directories are searched and the single one
    containing the folder is returned.

    Raises:
        FileNotFoundError: No sub-directory contains ``onnx_layerwise_tmp``.
        RuntimeError: More than one sub-directory qualifies.
    """
    marker = "onnx_layerwise_tmp"
    root = Path(base_path)
    if root.joinpath(marker).is_dir():
        return root

    matches = []
    for entry in root.iterdir():
        if entry.is_dir() and entry.joinpath(marker).is_dir():
            matches.append(entry)
    matches.sort()

    if not matches:
        raise FileNotFoundError(f"No onnx_layerwise_tmp under: {root}")
    if len(matches) > 1:
        raise RuntimeError(
            f"Multiple candidate model directories under {root}. Pass one of: {[str(p) for p in matches]}"
        )
    return matches[0]
def _run_layer_stack(sessions: List[SessionInfo], token_id: int, pos: int, phase: str) -> np.ndarray:
    """Push one token through every layer session in order and return the last output.

    The first session receives the token id; each later session receives the
    previous session's main output through its hidden-state input. Sessions
    that expose a position input also receive ``pos``. ``phase`` is only used
    to word the error raised when a hidden state is missing.
    """
    hidden = None
    for i, info in enumerate(sessions):
        session = info["session"]
        run_inputs: Dict[str, np.ndarray] = {}
        if i == 0:
            run_inputs[info["token_input"]] = np.array([[token_id]], dtype=np.int64)
        else:
            if hidden is None:
                raise RuntimeError(f"Missing hidden state while {phase} intermediate layer")
            if info["hidden_input"] is None:
                raise RuntimeError(f"Layer {i} has no hidden-state input. inputs={session.input_names}")
            run_inputs[info["hidden_input"]] = hidden

        if info["pos_input"] is not None:
            run_inputs[info["pos_input"]] = np.array([[pos]], dtype=np.int64)

        hidden = session.run(run_inputs)[info["out_name"]]
    return hidden


def inference_pipeline(
    base_path: str | Path,
    model_name: str = "moonshotai/Kimi-K2.5",
    prompt: str = "Help me with this",
    max_len: int = 1000,
    device_start: Optional[int] = None,
    max_workers: Optional[int] = None,
) -> List[int]:
    """Greedy-decode ``prompt`` by chaining one inference session per layer shard.

    Args:
        base_path: Model directory, or a parent of it, as accepted by
            ``resolve_base_path``.
        model_name: Tokenizer identifier passed to ``AutoTokenizer.from_pretrained``.
        prompt: Text to tokenize and feed token by token.
        max_len: Upper bound on prompt tokens plus generated tokens.
        device_start: Forwarded to ``load_sessions_threaded``.
        max_workers: Forwarded to ``load_sessions_threaded``.

    Returns:
        The generated token ids (prompt tokens excluded).

    Raises:
        RuntimeError: no sessions were loaded, the first session has no token
            input, a later session has no hidden-state input, or the prompt
            produced no logits (for example an empty prompt).
    """
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        padding_side="right",
        trust_remote_code=True,
    )
    prompt_ids = tokenizer(prompt, return_tensors="np", add_special_tokens=True)["input_ids"][0].tolist()
    all_ids = list(prompt_ids)

    resolved_base = resolve_base_path(base_path)
    qpc_paths = discover_qpc_paths(resolved_base / "onnx_layerwise_tmp")
    print(f"[LOAD] Found {len(qpc_paths)} layer sessions")

    start = time.time()
    sessions = load_sessions_threaded(qpc_paths, device_start, max_workers)

    print(f"[LOAD] Total load time: {time.time() - start:.2f}s")
    if not sessions:
        raise RuntimeError("No sessions loaded")
    if sessions[0]["token_input"] is None:
        raise RuntimeError(f"First layer has no token input. inputs={sessions[0]['session'].input_names}")

    # Prefill: the prompt is fed one token at a time; only the logits of the
    # last prompt token are needed to start decoding.
    logits = None
    for pos, token_id in enumerate(prompt_ids):
        logits = _run_layer_stack(sessions, token_id, pos, "executing")

    if logits is None:
        raise RuntimeError("Prompt produced no logits")

    start = time.time()
    print("[RUN] Starting inference pipeline")
    generated_ids: List[int] = []
    while len(all_ids) < max_len:
        next_token_id = int(np.argmax(logits, axis=-1)[0, 0])
        generated_ids.append(next_token_id)
        all_ids.append(next_token_id)

        if tokenizer.eos_token_id is not None and next_token_id == tokenizer.eos_token_id:
            break
        # The logits for this token would never be consumed once the length
        # budget is reached, so skip a full pass over every layer session.
        if len(all_ids) >= max_len:
            break

        logits = _run_layer_stack(sessions, next_token_id, len(all_ids) - 1, "decoding")

    print(f"[RUN] Total inference time: {time.time() - start:.2f}s")
    print("Generated token ids:")
    print(generated_ids)
    print("Generated text:")
    print(tokenizer.decode(generated_ids, skip_special_tokens=True))
    return generated_ids
If set, layer i uses device_start + i.", + ) + parser.add_argument("--max-workers", type=int, default=None, help="Thread pool size for load") + args = parser.parse_args() + + inference_pipeline( + base_path=args.base_path, + model_name=args.model_name, + prompt=args.prompt, + max_len=args.max_len, + device_start=args.device_start, + max_workers=args.max_workers, + ) + + +if __name__ == "__main__": + main() diff --git a/QEfficient/utils/layerwise_pipeline.py b/QEfficient/utils/layerwise_pipeline.py new file mode 100644 index 0000000000..4b59b25466 --- /dev/null +++ b/QEfficient/utils/layerwise_pipeline.py @@ -0,0 +1,546 @@ +#!/usr/bin/env python3 +import argparse +import os +import re +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import List, Tuple + +import onnx +import onnx_ir +from onnx import external_data_helper + +from QEfficient.base.onnx_transforms import CustomOpTransform, RemovePrefix + +# ============================================================ +# PREFIX/DELETION CONFIG (defaults preserved) +# ============================================================ +SAVE_WORKERS = 8 +DELETE_WORKERS = 8 +DELETE_SUFFIXES = ("onnx.data",) +_delete_pool = ThreadPoolExecutor(max_workers=DELETE_WORKERS) + + +# agent change start: generalized layer-window discovery +def _discover_layer_windows(exported_path: str, start_layer: int = 0) -> List[Tuple[int, int]]: + base_path = f"{exported_path}/onnx_layerwise_tmp" + if not os.path.isdir(base_path): + raise FileNotFoundError(f"Missing layerwise directory: {base_path}") + + windows: List[Tuple[int, int]] = [] + pat = re.compile(r"^layer_(\d+)_(\d+)$") + for entry in os.scandir(base_path): + if not entry.is_dir(): + continue + m = pat.match(entry.name) + if not m: + continue + layer_start, layer_end = int(m.group(1)), int(m.group(2)) + if layer_end <= layer_start: + continue + if layer_start < start_layer: + continue + windows.append((layer_start, layer_end)) + + 
windows.sort(key=lambda x: x[0]) + if not windows: + raise RuntimeError(f"No layer windows found in {base_path}. Expected directories like layer__.") + return windows + + +def _window_paths(exported_path: str, layer_start: int, layer_end: int) -> Tuple[str, str, str]: + base_dir = f"{exported_path}/onnx_layerwise_tmp/layer_{layer_start}_{layer_end}" + onnx_tmp = f"{base_dir}/DeepseekV3ForCausalLM_layer_tmp_{layer_start}_{layer_end}.onnx" + split_graph = f"{base_dir}/split_graph.onnx" + return base_dir, onnx_tmp, split_graph + + +# agent change end: generalized layer-window discovery + + +# ============================================================ +# STAGE 1: SPLITTING +# ============================================================ +def split_layer_graph( + shard_idx: int, + total_shards: int, + exported_path: str, + layer_start: int, + layer_end: int, +) -> bool: + base_dir, onnx_path, out_path = _window_paths(exported_path, layer_start, layer_end) + + if not os.path.exists(onnx_path): + print(f"[SKIP] ONNX not found: {onnx_path}") + return False + + model = onnx.load(onnx_path, load_external_data=False) + + decoder_input = None + decoder_output = None + for node in model.graph.node: + if "DecoderLayer" in node.name: + decoder_input = list(node.input) + decoder_output = list(node.output) + break + + if decoder_input is None or decoder_output is None: + raise RuntimeError(f"DecoderLayer not found in layer window {layer_start}_{layer_end}") + + model_ir = onnx_ir.load(onnx_path) + + # agent change start: generalized shard io selection (works for 1-layer and multi-layer windows) + graph_inputs = [v.name for v in model.graph.input] + graph_outputs = [v.name for v in model.graph.output] + + if layer_start == 0: + preferred_inputs = ["input_ids", "position_ids"] + else: + preferred_inputs = ["inputs_embeds", "position_ids"] + + cache_inputs = sorted([n for n in graph_inputs if n.startswith("compressed_kv.") or n.startswith("k_pe.")]) + input_names = [n for n in 
def run_split_pipeline(exported_path: str, num_layers: int = 61, start_layer: int = 0) -> None:
    """Run stage 1: write a split graph for every discovered layer window.

    Windows are discovered on disk, starting at ``start_layer``, and handed
    one by one to ``split_layer_graph`` together with their position in the
    shard sequence. ``num_layers`` is accepted for interface compatibility and
    is not read by this function.
    """
    shards = _discover_layer_windows(exported_path, start_layer=start_layer)
    shard_count = len(shards)
    print(
        f"[START] split pipeline | exported_path={exported_path}, "
        f"start_layer={start_layer}, discovered_shards={shard_count}"
    )

    position = 0
    for first_layer, end_layer in shards:
        print(f"[PROCESS] Layer window {first_layer}_{end_layer}")
        split_layer_graph(position, shard_count, exported_path, first_layer, end_layer)
        position += 1

    print("[DONE] split pipeline complete")
def run_prefix_pipeline(
    exported_path: str,
    num_layers: int = 61,
    chunk_size: int = 8,
    final_data_dir: str = "final_data",
) -> None:
    """Run stage 2: write a prefixed model for every discovered layer window.

    Windows are processed in chunks of ``chunk_size``; within a chunk each
    window is handled by ``run_saving_prefix`` on a thread pool of
    ``SAVE_WORKERS`` threads. ``num_layers`` is accepted for interface
    compatibility and is not read by this function.

    Raises:
        ValueError: ``chunk_size`` is smaller than 1.
        Exception: whatever a worker raised, re-raised as soon as its future
            is collected.
    """
    if chunk_size < 1:
        # A zero step makes range() raise and a negative one silently
        # processes nothing; fail loudly in both cases.
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    windows = _discover_layer_windows(exported_path, start_layer=0)
    print(
        f"[START] prefix+deletion pipeline | exported_path={exported_path}, "
        f"discovered_shards={len(windows)}, chunk_size={chunk_size}"
    )

    for chunk_start in range(0, len(windows), chunk_size):
        chunk_end = min(chunk_start + chunk_size, len(windows))
        chunk_windows = windows[chunk_start:chunk_end]

        # "\n" (not "\\n"): a blank line between chunks was intended, not a
        # literal backslash-n in the log output.
        print(f"\n[Chunk] {chunk_start} -> {chunk_end - 1} (window count)")
        t0 = time.time()

        with ThreadPoolExecutor(max_workers=SAVE_WORKERS) as pool:
            futures = [
                pool.submit(run_saving_prefix, layer_start, layer_end, exported_path, final_data_dir)
                for (layer_start, layer_end) in chunk_windows
            ]
            for f in as_completed(futures):
                # result() re-raises any exception from the worker thread.
                f.result()

        print(f"[Chunk] saved in {time.time() - t0:.2f}s")

        # NOTE: deleting each chunk's source files is currently disabled. To
        # re-enable it, pass the result of
        # collect_chunk_deletable_files(exported_path, chunk_windows) to
        # async_delete_files().

    print("[DONE] prefix+deletion pipeline complete")
return False + else: + return False + + if node1.op_type != node2.op_type: + return False + if len(node1.attribute) != len(node2.attribute): + return False + for j in range(len(node1.attribute)): + if node1.attribute[j] != node2.attribute[j]: + return False + + if len(node1.output) != len(node2.output): + return False + for j in range(len(node1.output)): + if node1.output[j] in func1.output: + idx = list(func1.output).index(node1.output[j]) + if node2.output[j] not in func2.output or list(func2.output).index(node2.output[j]) != idx: + return False + else: + if node1.output[j] != node2.output[j]: + return False + + return True + + +def merge_models(m1, m2, io_map): + def is_decoder(name: str) -> bool: + return "DecoderLayer" in name + + def copy_with_name(func: onnx.FunctionProto, new_name: str) -> onnx.FunctionProto: + f = onnx.FunctionProto() + f.CopyFrom(func) + f.name = new_name + return f + + def update_node_calls(graph: onnx.GraphProto, old_name: str, new_name: str): + if old_name == new_name: + return + for node in graph.node: + if node.op_type == old_name: + node.op_type = new_name + + graph = onnx.compose.merge_graphs(m1.graph, m2.graph, io_map) + model = onnx.helper.make_model_gen_version( + graph, + producer_name="QEfficient", + producer_version="1.21", + ir_version=10, + opset_imports=m1.opset_import, + ) + + props = {} + for p in m1.metadata_props: + props[p.key] = p.value + for p in m2.metadata_props: + if p.key in props and props[p.key] != p.value: + raise ValueError( + "Can't merge models with different values for the same model metadata property." + f" Found: property = {p.key}, with values {props[p.key]} and {p.value}." 
def run_merge_pipeline(exported_path: str, num_layers: int = 61, final_data_dir: str = "final_data") -> str:
    """Run stage 3: chain the prefixed shard models into one ONNX model.

    Shards are merged back to front. The last two ``pref_<start>.onnx``
    models are merged first, then each earlier shard is merged in front of
    the previous ``merged_<right>-<last>.onnx`` result. On the final merge
    step ``CustomOpTransform`` is applied before saving; afterwards the final
    file is reloaded, ``RemovePrefix`` is applied and it is saved again.

    With a single discovered shard nothing is merged and the path of its
    ``pref_<start>.onnx`` is returned unchanged (neither transform is applied
    on that path). ``num_layers`` is accepted for interface compatibility and
    is not read by this function.

    Returns:
        Path of the final model file.

    Raises:
        ValueError: no shard was discovered.
        FileNotFoundError: an expected input model file is missing.
        RuntimeError: the left model of a merge has no node whose name
            contains ``DecoderLayer``.
    """
    windows = _discover_layer_windows(exported_path, start_layer=0)
    if len(windows) < 1:
        raise ValueError("Need at least one discovered shard to merge")

    base_dir = f"{exported_path}/{final_data_dir}"
    start = time.time()
    print(
        f"[START] merge pipeline | exported_path={exported_path}, "
        f"discovered_shards={len(windows)}, final_data_dir={final_data_dir}"
    )

    shard_starts = [layer_start for (layer_start, _) in windows]
    first_start = shard_starts[0]
    last_start = shard_starts[-1]

    if len(shard_starts) == 1:
        only_model = f"{base_dir}/pref_{first_start}.onnx"
        if not os.path.exists(only_model):
            raise FileNotFoundError(f"Missing input model: {only_model}")
        print(f"[DONE] merge pipeline skipped (single layer): {only_model}")
        return only_model

    # Adjacent (left, right) shard pairs, visited from the tail of the model
    # towards its head so that every step prepends one shard to the result.
    adjacent_pairs = list(zip(shard_starts[:-1], shard_starts[1:]))
    final_step = len(adjacent_pairs) - 1

    for idx, (left, right) in enumerate(reversed(adjacent_pairs)):
        m1_path = f"{base_dir}/pref_{left}.onnx"
        # The first step joins two raw shards; later steps join a raw shard
        # with the merge result produced by the previous step.
        m2_path = f"{base_dir}/pref_{right}.onnx" if idx == 0 else f"{base_dir}/merged_{right}-{last_start}.onnx"

        if not os.path.exists(m1_path):
            raise FileNotFoundError(f"Missing input model: {m1_path}")
        if not os.path.exists(m2_path):
            raise FileNotFoundError(f"Missing input model: {m2_path}")

        print(f"[MERGE] {left}-{last_start}")
        m1_pref = onnx.load(m1_path, load_external_data=False)
        m2_pref = onnx.load(m2_path, load_external_data=False)

        decoder_nodes = [n for n in m1_pref.graph.node if "DecoderLayer" in n.name]
        if not decoder_nodes:
            raise RuntimeError(f"DecoderLayer node not found in {m1_path}")
        # NOTE(review): the previous code picked decoder_nodes[1] when more
        # than one decoder node was present and then unconditionally
        # overwrote that choice with decoder_nodes[0]. The dead branch was
        # removed and the effective behaviour (first node) kept. Confirm
        # which node's output should feed the next shard for multi-layer
        # windows.
        decoder_output = list(decoder_nodes[0].output)

        merged_model = merge_models(
            m1_pref,
            m2_pref,
            io_map=[
                (f"{decoder_output[2]}", f"layer_{right}/inputs_embeds"),
                (f"layer_{left}/position_ids", f"layer_{right}/position_ids"),
            ],
        )

        if idx == final_step:
            CustomOpTransform.apply(merged_model)

        out_path = f"{base_dir}/merged_{left}-{last_start}.onnx"
        onnx.save(merged_model, out_path)
        print(f"[SAVED] {out_path}")

    final_path = f"{base_dir}/merged_{first_start}-{last_start}.onnx"
    model = onnx.load(final_path, load_external_data=False)
    RemovePrefix.apply(model)
    onnx.save(model, final_path)
    print(f"[DONE] merge pipeline complete in {time.time() - start:.2f}s")
    return final_path
# Use disaggregated serving with the Kimi-K2 model for best performance
 - The Kimi-K2 model has a 384/8 ratio of total_experts/experts_per_tok.
 - In the prefill-only model we currently always read every expert exactly once.
 - In the decode-only model we treat weights as activations, meaning only the chosen experts are read.

# Multi-head Latent Attention (MLA)
Kimi-K2 uses Multi-head Latent Attention (MLA), which is implemented with a dual cache (for compressed_kv and k_pe).

# Absorption
MLA has 3 configurations, depending on the order in which the different matrices are evaluated. To select one, pass the mla_absorption config like this:
- No absorption : mla_absorption = {"cache_compressed": True, "absorption": False, "online": False}
- Offline absorption : mla_absorption = {"cache_compressed": True, "absorption": True, "online": False}
- Online absorption : mla_absorption = {"cache_compressed": True, "absorption": True, "online": True}

mla_absorption has 3 keys:
- cache_compressed: True/False -> enable to cache the compressed KVs and save memory.
- absorption: True/False -> only takes effect when the compressed cache is used; if True, enables absorption of the attention matrices for efficiency.
- online: True/False -> only takes effect when absorption is True; enables on-device absorption.

# Blocking
We have also implemented KV head replication, head blocking and KV blocking, which can be enabled like this:
- For no blocking : qaic_config = {"mla_absorption" : mla_absorption}
- For no blocking with KV head replication : qaic_config = {"mla_absorption" : mla_absorption, "num_kv_heads_repeat": TS}
- For KV blocking : qaic_config = {"mla_absorption" : mla_absorption, "enable_blocking": True, "blocking_mode": "kv"}
- For head blocking : qaic_config = {"mla_absorption" : mla_absorption, "enable_blocking": True, "blocking_mode": "h", "num_kv_heads_repeat": TS}. With head blocking, head_block_size is set internally to num_devices/num_kv_heads_repeat.

- Currently the decode-only model gives the best performance with head blocking and the compressed cache.
- Continuous batching is not enabled yet.
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM + +# parameters to be configured +prompt = "Once upon a time," +num_hidden_layers = 2 +TS = 4 +mla_absorption = {"cache_compressed": True, "absorption": False, "online": False} +# qaic_config = {"mla_absorption": mla_absorption} # for No Blocking +# qaic_config = {"mla_absorption": mla_absorption, "num_kv_heads_repeat": TS} # No blocking with kv head replication +# qaic_config = {"mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "kv"} # for KV blocking +qaic_config = { + "mla_absorption": mla_absorption, + "enable_blocking": True, + "blocking_mode": "h", + "num_kv_heads_repeat": TS, +} +# for h blocking, it internally sets head_block_size equal to num_devices/num_kv_heads_repeat + +model_name = "moonshotai/Kimi-K2-Thinking" +model = AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype=torch.float32, num_hidden_layers=num_hidden_layers, trust_remote_code=True +) +tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + +qeff_model = QEFFAutoModelForCausalLM(model, qaic_config=qaic_config) + +prefill_seq_len = 1 +ctx_len = 16 * 1024 + +qpc_path = qeff_model.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + mxfp6_matmul=True, + mxint8_kv_cache=False, + num_devices=TS, + num_cores=16, + use_onnx_subfunctions=True, + qaic_config=qaic_config, +) + +qeff_model.generate(prompts=["Once upon a time,"], tokenizer=tokenizer) diff --git a/examples/text_generation/run_kimik2.py b/examples/text_generation/run_kimik2.py new file mode 100644 index 0000000000..b6f8d821f3 --- /dev/null +++ b/examples/text_generation/run_kimik2.py @@ -0,0 +1,117 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm 
Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import numpy as np +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM + +# parameters to be configured +prompt = "Once upon a time," +num_hidden_layers = 2 +TS = 4 +mla_absorption = {"cache_compressed": False, "absorption": False, "online": False} +# qaic_config = {"mla_absorption": mla_absorption} # for No Blocking +# qaic_config = {"mla_absorption": mla_absorption, "num_kv_heads_repeat": TS} # No blocking with kv head replication +# qaic_config = {"mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "kv"} # for KV blocking +qaic_config = { + "mla_absorption": mla_absorption, + "enable_blocking": True, + "blocking_mode": "h", + "num_kv_heads_repeat": TS, +} +# for h blocking, it internally sets head_block_size equal to num_devices/num_kv_heads_repeat + +model_name = "moonshotai/Kimi-K2-Thinking" +model = AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype=torch.float32, num_hidden_layers=num_hidden_layers, trust_remote_code=True +) +tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + +PREFILL_SEQ_LEN = 32 +CTX_LEN = 8192 +generation_len = 10 +generated_ids = [] + +inputs = tokenizer(prompt, return_tensors="pt", padding=True) +padded_len = inputs["input_ids"].shape[1] +num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float +padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len + +# with torch.no_grad(): +# out = model(**inputs) +# predictions = torch.argmax(out.logits, dim=-1) + +qeff_model = QEFFAutoModelForCausalLM(model, qaic_config=qaic_config) +qeff_model.transform(ctx_len=CTX_LEN, seq_len=PREFILL_SEQ_LEN, bs=1, num_devices=TS, qaic_config=qaic_config) + +inputs = tokenizer(prompt, return_tensors="np", 
padding="max_length", max_length=padded_len) +inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) +inputs.pop("token_type_ids", None) +inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} + +pad_shape_k = ( + 1, + model.config.num_attention_heads, + CTX_LEN, + model.config.qk_nope_head_dim + model.config.qk_rope_head_dim, +) +pad_shape_v = (1, model.config.num_attention_heads, CTX_LEN, model.config.v_head_dim) + +num_heads = model.model.layers[0].self_attn.kv_a_proj_with_mqa.weight.shape[0] // ( + model.config.kv_lora_rank + model.config.qk_rope_head_dim +) +pad_shape_ckv = (1, num_heads, CTX_LEN, model.config.kv_lora_rank) +pad_shape_k_pe = (1, num_heads, CTX_LEN, model.config.qk_rope_head_dim) + +past_key_values = [] +compressed_kvs = [] + +for i in range(model.config.num_hidden_layers): + past_key = torch.zeros((pad_shape_k), dtype=torch.float32) + past_value = torch.zeros((pad_shape_v), dtype=torch.float32) + pkv = (past_key, past_value) + past_key_values.append(pkv) + + ckv = torch.zeros((pad_shape_ckv), dtype=torch.float32) + k_pe = torch.zeros((pad_shape_k_pe), dtype=torch.float32) + x = (ckv, k_pe) + compressed_kvs.append(x) + +cache_compressed = mla_absorption.get("cache_compressed", False) +if cache_compressed: + inputs["compressed_kvs"] = compressed_kvs +else: + inputs["past_key_values"] = past_key_values + +prefill_qeff_out = qeff_model.model(**inputs) + +position_ids = inputs["position_ids"] +qeff_out = prefill_qeff_out +qeff_generated_ids = [] +for _ in range(1, generation_len): + next_token_id = qeff_out["logits"][:, -1, :].argmax(-1).reshape(-1, 1) + qeff_generated_ids.append(next_token_id) + position_ids = position_ids.max(1, keepdim=True).values + 1 + decode_inputs = { + "input_ids": next_token_id, + "position_ids": position_ids, + } + if cache_compressed: + decode_inputs["compressed_kvs"] = qeff_out["past_key_values"] + else: + decode_inputs["past_key_values"] = qeff_out["past_key_values"] + + 
qeff_out = qeff_model.model(**decode_inputs) + +qeff_generated_ids.append(qeff_out["logits"][:, -1, :].argmax(-1).reshape(-1, 1)) +qeff_generated_ids = np.concatenate(qeff_generated_ids, axis=1) +predicted_string = tokenizer.batch_decode(qeff_generated_ids, skip_special_tokens=True) +print("QEFF Transformed Model Outputs (Torch CPU): \n") +print("Prompt:", repr(prompt)) +print("Completion:", repr(predicted_string)) diff --git a/pyproject.toml b/pyproject.toml index 003143d8f3..2c50db5b5f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "fsspec==2023.6.0", "sentencepiece==0.2.0", "onnx==1.18.0", + "onnx_ir", "onnxruntime==1.22", "numpy==1.26.4", "protobuf==6.31.0", @@ -41,6 +42,8 @@ dependencies = [ "ftfy==6.3.1", "imageio==2.37.2", "imageio-ffmpeg==0.6.0", + "tiktoken==0.12.0", + "compressed-tensors==0.14.0", "torch==2.7.0; platform_machine=='aarch64'", # Specifying torch cpu package URL per python version, update the list once pytorch releases whl for python>3.11 "torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp38-cp38-linux_x86_64.whl ; python_version=='3.8' and platform_machine=='x86_64'", diff --git a/run.py b/run.py new file mode 100644 index 0000000000..f950d62e51 --- /dev/null +++ b/run.py @@ -0,0 +1,227 @@ +import copy +import functools +import json +import tempfile +from pathlib import Path + +import torch +import transformers +from transformers import AutoConfig, AutoTokenizer +from transformers.dynamic_module_utils import get_class_from_dynamic_module + +import QEfficient +from QEfficient import QEFFAutoModelForCausalLM + +MODEL_PATH = Path( + "/home/huggingface_hub/models--moonshotai--Kimi-K2.5/snapshots/54383e83fa343a1331754112fb9e3410c55efa2f" +) + +TS = 1 +enable_mla = True +mla_absorption = {"cache_compressed": True, "absorption": True, "online": False} +prefill_seq_len = 1 +ctx_len = 128 +qaic_config = {"mla_absorption": mla_absorption, "num_kv_heads_repeat": TS} # No blocking with kv head 
# replication  (tail of the ``qaic_config`` comment that wraps over from the previous line)

# Half-open layer window [EXPORT_START, EXPORT_END) used to seed the window plan built in main().
EXPORT_START = 1
EXPORT_END = 3
# "multiple_qpc" makes main() call compile_layerwise + inference_pipeline;
# any other value makes it call layerwise_pipeline.
LAYERWISE_MODE = "multiple_qpc"

# Path component that _resolve_export_root() treats as the start of the temporary layerwise tree.
_LAYERWISE_TMP_DIR = "onnx_layerwise_tmp"


def _ensure_pretrained_window_attrs():
    """Make sure ``PreTrainedModel`` carries the layer-window class attributes.

    ``_start`` and ``_end`` default to 0 and ``_total_layers`` to ``None``.
    Attributes that already exist are left untouched, so a window that was
    assigned earlier (for example by ``main``) is never overwritten.
    """
    base_cls = transformers.modeling_utils.PreTrainedModel
    for attr_name, default in (("_start", 0), ("_end", 0), ("_total_layers", None)):
        if not hasattr(base_cls, attr_name):
            setattr(base_cls, attr_name, default)


def _null_outside_window_layers(model, start=None, end=None):
    """Replace every decoder layer outside ``[start, end)`` with ``None``, in place.

    Args:
        model: Object expected to expose ``model.layers`` (an indexable,
            sized container). If that attribute chain is missing the call is
            a no-op.
        start: First layer index to keep. When ``None`` it is read from
            ``PreTrainedModel._start`` (falling back to 0).
        end: One past the last layer index to keep. When ``None`` it is read
            from ``PreTrainedModel._end`` (falling back to 0).
    """
    layers = getattr(getattr(model, "model", None), "layers", None)
    if layers is None:
        return

    if start is None:
        start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0))
    if end is None:
        end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0))

    print(f"{start} to {end}")
    # Index-based loop on purpose: the container is mutated slot by slot.
    for idx in range(len(layers)):
        if not (start <= idx < end):
            layers[idx] = None


def _install_window_patch(model_cls):
    """Wrap ``model_cls.__init__`` so new instances drop out-of-window layers.

    The wrapper runs the original ``__init__`` and then calls
    ``_null_outside_window_layers`` on the instance. A class-level flag
    (``_window_patch_installed``) makes repeated calls a no-op, so the
    constructor is wrapped at most once.
    """
    if getattr(model_cls, "_window_patch_installed", False):
        return

    original_init = model_cls.__init__

    @functools.wraps(original_init)
    def patched_init(self, *args, **kwargs):
        original_init(self, *args, **kwargs)
        _null_outside_window_layers(self)

    model_cls.__init__ = patched_init
    model_cls._window_patch_installed = True


def load_text_only_kimi(model_path: Path, num_hidden_layers: int):
    """Load the text stack of a Kimi K2.5 checkpoint for the current layer window.

    Only the embedding, final norm, LM head and the decoder layers inside
    ``[PreTrainedModel._start, PreTrainedModel._end)`` are selected from the
    checkpoint index. The selected shards are symlinked into a temporary
    directory next to a rewritten ``config.json`` / safetensors index, and the
    model is loaded from there.

    Args:
        model_path: Directory holding the original checkpoint, including
            ``model.safetensors.index.json``.
        num_hidden_layers: Not used by this function; the window is taken from
            the ``PreTrainedModel`` class attributes. Kept so existing callers
            keep working.
            NOTE(review): callers pass ``end - start`` here - confirm whether it
            was meant to be validated against the window.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode.

    Raises:
        RuntimeError: If no weight matches the allowed prefixes, or if loading
            reports missing, unexpected or mismatched keys.
    """
    _ensure_pretrained_window_attrs()
    kimi_config = AutoConfig.from_pretrained(str(model_path), trust_remote_code=True)

    # Kimi K2.5 is multimodal, so we load only the text stack config.
    text_config = copy.deepcopy(kimi_config.text_config)

    deepseek_cls = get_class_from_dynamic_module("modeling_deepseek.DeepseekV3ForCausalLM", str(model_path))
    _install_window_patch(deepseek_cls)

    checkpoint_index = json.loads((model_path / "model.safetensors.index.json").read_text(encoding="utf-8"))
    weight_map = checkpoint_index["weight_map"]

    base_cls = transformers.modeling_utils.PreTrainedModel
    layer_start = int(getattr(base_cls, "_start", 0))
    layer_end = int(getattr(base_cls, "_end", 0))
    # A tuple lets str.startswith test all prefixes in a single call.
    allowed_prefixes = (
        "language_model.model.embed_tokens.",
        "language_model.model.norm.",
        "language_model.lm_head.",
        *(f"language_model.model.layers.{layer_idx}." for layer_idx in range(layer_start, layer_end)),
    )

    # Filter once; the shard list is derived from the filtered map.
    filtered_weight_map = {
        checkpoint_key: shard_name
        for checkpoint_key, shard_name in weight_map.items()
        if checkpoint_key.startswith(allowed_prefixes)
    }
    if not filtered_weight_map:
        raise RuntimeError("No text-only weights were selected from the Kimi K2.5 checkpoint.")
    required_shards = sorted(set(filtered_weight_map.values()))

    with tempfile.TemporaryDirectory() as tmpdir:
        temp_model_path = Path(tmpdir)
        (temp_model_path / "config.json").write_text(text_config.to_json_string(use_diff=False), encoding="utf-8")
        (temp_model_path / "model.safetensors.index.json").write_text(
            json.dumps(
                {
                    "metadata": {
                        "total_size": sum((model_path / shard_name).stat().st_size for shard_name in required_shards)
                    },
                    "weight_map": filtered_weight_map,
                }
            ),
            encoding="utf-8",
        )
        for shard_name in required_shards:
            # A relative symlink target is resolved from the link's own directory,
            # so the target is made absolute to stay valid for a relative model_path.
            (temp_model_path / shard_name).symlink_to((model_path / shard_name).resolve())

        # We are loading a task checkpoint into the base text model, so disable the
        # base/task prefix heuristic and let `key_mapping` strip `language_model.`.
        original_base_model_prefix = deepseek_cls.base_model_prefix
        deepseek_cls.base_model_prefix = ""
        try:
            model, loading_info = deepseek_cls.from_pretrained(
                str(temp_model_path),
                config=text_config,
                local_files_only=True,
                key_mapping={r"^language_model\.": ""},
                output_loading_info=True,
            )
        finally:
            # Always restore the class attribute, even when loading fails.
            deepseek_cls.base_model_prefix = original_base_model_prefix

    unexpected_keys = loading_info["unexpected_keys"]
    missing_keys = loading_info["missing_keys"]
    mismatched_keys = loading_info["mismatched_keys"]
    if unexpected_keys or missing_keys or mismatched_keys:
        raise RuntimeError(
            "Failed to load the text-only Kimi K2.5 checkpoint slice cleanly. "
            f"missing={missing_keys}, unexpected={unexpected_keys}, mismatched={mismatched_keys}"
        )

    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=True)
    return model, tokenizer


def _build_layer_windows(total_layers: int, start: int, end: int):
    """Split ``[0, total_layers)`` into consecutive half-open layer windows.

    If ``start > 0`` the first window is ``(0, start)``. The remaining layers
    are covered by windows of width ``end - start`` beginning at ``start``;
    the last one is clipped to ``total_layers``.

    Args:
        total_layers: Number of layers to cover.
        start: Start of the seed window.
        end: End (exclusive) of the seed window.

    Returns:
        List of ``(window_start, window_end)`` tuples in ascending order.

    Raises:
        ValueError: Unless ``0 <= start < end <= total_layers``.
    """
    if not (0 <= start < end <= total_layers):
        raise ValueError(
            f"Invalid export window start={start}, end={end} for total_layers={total_layers}. "
            "Expected: 0 <= start < end <= total_layers."
        )

    step = end - start
    windows = [(0, start)] if start > 0 else []
    windows.extend(
        (window_start, min(window_start + step, total_layers)) for window_start in range(start, total_layers, step)
    )
    return windows


def _resolve_export_root(onnx_path: Path) -> Path:
    """Return the directory that contains the layerwise export tree.

    When ``onnx_layerwise_tmp`` is one of the path components, everything
    before its first occurrence is returned; otherwise the parent directory of
    ``onnx_path`` is returned.
    """
    parts = onnx_path.parts
    try:
        marker_idx = parts.index(_LAYERWISE_TMP_DIR)
    except ValueError:
        return onnx_path.parent
    return Path(*parts[:marker_idx])


def main():
    """Export and compile the model window by window, then run the layerwise pipeline.

    For each window from ``_build_layer_windows`` the window bounds are
    published on the classes that read them, the matching checkpoint slice is
    loaded and ``compile`` is called. The first path returned by ``compile``
    is used to locate the export root handed to the final pipeline step.

    Raises:
        ValueError: If ``num_hidden_layers`` is missing from the text config,
            or the configured window is invalid.
        RuntimeError: If no window produced a path.
    """
    import gc

    _ensure_pretrained_window_attrs()
    text_config = AutoConfig.from_pretrained(str(MODEL_PATH), trust_remote_code=True).text_config
    total_layers = getattr(text_config, "num_hidden_layers", None)
    if total_layers is None:
        raise ValueError("Could not resolve `num_hidden_layers` from text_config.")
    windows = _build_layer_windows(total_layers=total_layers, start=EXPORT_START, end=EXPORT_END)

    # Every class that reads the current window from class-level attributes.
    window_holders = (
        transformers.modeling_utils.PreTrainedModel,
        QEfficient.transformers.models.deepseek_v3.modeling_deepseek.QEffDeepseekV3Model,
        QEfficient.base.modeling_qeff.QEFFBaseModel,
    )

    first_onnx_path = None
    for start, end in windows:
        for holder in window_holders:
            holder._start = start
            holder._end = end
            holder._total_layers = total_layers

        model, _tokenizer = load_text_only_kimi(MODEL_PATH, num_hidden_layers=end - start)
        # NOTE(review): `num_kv_heads_repeat` is the literal 1 here while
        # qaic_config carries `num_kv_heads_repeat: TS` - confirm they are meant
        # to stay in sync if TS is ever changed.
        qeff_model = QEFFAutoModelForCausalLM(
            model, num_kv_heads_repeat=1, qaic_config=qaic_config, torch_dtype=torch.float16
        )
        onnx_path = qeff_model.compile(
            prefill_seq_len=prefill_seq_len,
            ctx_len=ctx_len,
            mxfp6_matmul=True,
            mxint8_kv_cache=False,
            num_devices=TS,
            num_cores=16,
            qaic_config=qaic_config,
            use_onnx_subfunctions=True,
        )
        if first_onnx_path is None:
            first_onnx_path = Path(onnx_path)

        # Drop this window's model before the next one is loaded so two
        # windows are never held in memory at the same time.
        del qeff_model, model
        gc.collect()

    if first_onnx_path is None:
        raise RuntimeError("No ONNX path produced during compilation.")
    export_root = _resolve_export_root(first_onnx_path)

    if LAYERWISE_MODE == "multiple_qpc":
        QEfficient.utils.compile_layerwise(str(export_root))
        QEfficient.utils.inference_pipeline(str(export_root))
    else:
        QEfficient.utils.layerwise_pipeline(str(export_root))


if __name__ == "__main__":
    main()
a/tests/transformers/models/audio_models/test_speech_seq2seq_models.py b/tests/transformers/models/audio_models/test_speech_seq2seq_models.py index 6509d02fe7..0c6fb29087 100644 --- a/tests/transformers/models/audio_models/test_speech_seq2seq_models.py +++ b/tests/transformers/models/audio_models/test_speech_seq2seq_models.py @@ -374,7 +374,6 @@ def check_seq2seq_pytorch_vs_kv_vs_ort_vs_ai100( @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models) def test_full_seq2seq_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - torch.manual_seed(42) check_seq2seq_pytorch_vs_kv_vs_ort_vs_ai100( model_name=model_name, compare_results=True, manual_cleanup=manual_cleanup, num_devices=4 @@ -385,7 +384,6 @@ def test_full_seq2seq_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models) def test_few_seq2seq_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - torch.manual_seed(42) check_seq2seq_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=4, manual_cleanup=manual_cleanup) diff --git a/tests/transformers/models/causal_lm_models/check_causal_models.py b/tests/transformers/models/causal_lm_models/check_causal_models.py index cc2d074a08..f878acbe73 100644 --- a/tests/transformers/models/causal_lm_models/check_causal_models.py +++ b/tests/transformers/models/causal_lm_models/check_causal_models.py @@ -57,7 +57,6 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( retain_full_kv: Optional[bool] = None, compare_results: bool = False, ): - torch.manual_seed(42) replace_transformers_quantizers() model_hf = load_hf_causal_lm_model(model_name, num_hidden_layers=n_layer, config=config) diff --git a/tests/transformers/models/causal_lm_models/test_causal_lm_blocking_hqkv.py b/tests/transformers/models/causal_lm_models/test_causal_lm_blocking_hqkv.py index 4bf067e7c4..0568939cd2 100644 --- a/tests/transformers/models/causal_lm_models/test_causal_lm_blocking_hqkv.py 
+++ b/tests/transformers/models/causal_lm_models/test_causal_lm_blocking_hqkv.py @@ -31,7 +31,6 @@ @pytest.mark.on_qaic @pytest.mark.parametrize("model_name", test_models_blockedKV[:1]) def test_full_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - HEAD_BLOCK_SIZE = 8 NUM_KV_BLOCKS = 2 NUM_Q_BLOCKS = 2 @@ -77,7 +76,6 @@ def test_full_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manu @pytest.mark.on_qaic @pytest.mark.parametrize("model_name", test_models_blockedKV[:1]) def test_few_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - HEAD_BLOCK_SIZE = 8 NUM_KV_BLOCKS = 2 NUM_Q_BLOCKS = 2 @@ -123,7 +121,6 @@ def test_few_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manua @pytest.mark.on_qaic @pytest.mark.parametrize("model_name", test_models_blockedKV[:1]) def test_dummy_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - HEAD_BLOCK_SIZE = 8 NUM_KV_BLOCKS = 2 NUM_Q_BLOCKS = 2 @@ -178,7 +175,6 @@ def test_dummy_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, man @pytest.mark.on_qaic @pytest.mark.parametrize("model_name", test_models_blockedKV[:1]) def test_full_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_cleanup): - HEAD_BLOCK_SIZE = 8 NUM_KV_BLOCKS = 2 NUM_Q_BLOCKS = 2 @@ -244,7 +240,6 @@ def test_full_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, m @pytest.mark.on_qaic @pytest.mark.parametrize("model_name", test_models_blockedKV[:1]) def test_few_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_cleanup): - HEAD_BLOCK_SIZE = 8 NUM_KV_BLOCKS = 2 NUM_Q_BLOCKS = 2 @@ -310,7 +305,6 @@ def test_few_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, ma @pytest.mark.on_qaic @pytest.mark.parametrize("model_name", test_models_blockedKV[:1]) def test_dummy_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_cleanup): - HEAD_BLOCK_SIZE = 8 
NUM_KV_BLOCKS = 2 NUM_Q_BLOCKS = 2 diff --git a/tests/transformers/models/causal_lm_models/test_causal_lm_models.py b/tests/transformers/models/causal_lm_models/test_causal_lm_models.py index 8dbb0915b8..8c61cdc98d 100644 --- a/tests/transformers/models/causal_lm_models/test_causal_lm_models.py +++ b/tests/transformers/models/causal_lm_models/test_causal_lm_models.py @@ -33,7 +33,6 @@ @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_causal) def test_full_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - if model_name in ModelConfig.FULL_MODEL_TESTS_TO_SKIP: pytest.skip(f"Skipping full model test for {model_name} due to resource constraints.") check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( @@ -55,7 +54,6 @@ def test_few_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup) @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_causal) def test_dummy_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - custom_config = model_config_dict[model_name] hf_config = AutoConfig.from_pretrained( model_name, @@ -89,7 +87,6 @@ def test_full_causal_lm_pytorch_vs_ort_vs_ai100_cb(model_name, manual_cleanup): @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_causal) def test_few_causal_lm_pytorch_vs_ort_vs_ai100_cb(model_name, manual_cleanup): - n_layer = get_custom_n_layers(model_name) check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( model_name=model_name, @@ -104,7 +101,6 @@ def test_few_causal_lm_pytorch_vs_ort_vs_ai100_cb(model_name, manual_cleanup): @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_causal) def test_dummy_causal_lm_pytorch_vs_ort_vs_ai100_cb(model_name, manual_cleanup): - custom_config = model_config_dict[model_name] hf_config = AutoConfig.from_pretrained( model_name, diff --git a/tests/transformers/models/causal_lm_models/test_causal_lm_pl1.py b/tests/transformers/models/causal_lm_models/test_causal_lm_pl1.py index 
b6641d7951..f5f2384e67 100644 --- a/tests/transformers/models/causal_lm_models/test_causal_lm_pl1.py +++ b/tests/transformers/models/causal_lm_models/test_causal_lm_pl1.py @@ -32,7 +32,6 @@ @pytest.mark.parametrize("model_name", test_models_pl1) @pytest.mark.parametrize("retain_full_kv", [True, False]) def test_full_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1(model_name, retain_full_kv, manual_cleanup): - if model_name == "gpt2" and retain_full_kv: pytest.skip("Skipping test for gpt2 with retain_full_kv=True as it is not supported.") @@ -52,7 +51,6 @@ def test_full_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1(model_name, retain_ful @pytest.mark.parametrize("model_name", test_models_pl1) @pytest.mark.parametrize("retain_full_kv", [True, False]) def test_few_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1(model_name, retain_full_kv, manual_cleanup): - if model_name == "gpt2" and retain_full_kv: pytest.skip("Skipping test for gpt2 with retain_full_kv=True as it is not supported.") torch.manual_seed(42) @@ -71,7 +69,6 @@ def test_few_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1(model_name, retain_full @pytest.mark.parametrize("model_name", test_models_pl1) @pytest.mark.parametrize("retain_full_kv", [True, False]) def test_dummy_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1(model_name, retain_full_kv, manual_cleanup): - if model_name == "gpt2" and retain_full_kv: pytest.skip("Skipping test for gpt2 with retain_full_kv=True as it is not supported.") @@ -97,7 +94,6 @@ def test_dummy_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1(model_name, retain_fu @pytest.mark.parametrize("model_name", test_models_pl1) @pytest.mark.parametrize("retain_full_kv", [True, False]) def test_full_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1_CB(model_name, retain_full_kv, manual_cleanup): - if model_name == "gpt2" and retain_full_kv: pytest.skip("Skipping test for gpt2 with retain_full_kv=True as it is not supported.") torch.manual_seed(42) @@ -117,7 +113,6 @@ def 
test_full_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1_CB(model_name, retain_ @pytest.mark.parametrize("model_name", test_models_pl1) @pytest.mark.parametrize("retain_full_kv", [True, False]) def test_few_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1_CB(model_name, retain_full_kv, manual_cleanup): - if model_name == "gpt2" and retain_full_kv: pytest.skip("Skipping test for gpt2 with retain_full_kv=True as it is not supported.") torch.manual_seed(42) @@ -137,7 +132,6 @@ def test_few_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1_CB(model_name, retain_f @pytest.mark.parametrize("model_name", test_models_pl1) @pytest.mark.parametrize("retain_full_kv", [True, False]) def test_dummy_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1_CB(model_name, retain_full_kv, manual_cleanup): - if model_name == "gpt2" and retain_full_kv: pytest.skip("Skipping test for gpt2 with retain_full_kv=True as it is not supported.") diff --git a/tests/transformers/models/causal_lm_models/test_causal_tlm_models.py b/tests/transformers/models/causal_lm_models/test_causal_tlm_models.py index 0b488a5037..9d02acbd29 100644 --- a/tests/transformers/models/causal_lm_models/test_causal_tlm_models.py +++ b/tests/transformers/models/causal_lm_models/test_causal_tlm_models.py @@ -32,7 +32,6 @@ @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_spd) def test_full_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( model_name=model_name, num_speculative_tokens=Constants.NUM_SPECULATIVE_TOKENS, @@ -46,7 +45,6 @@ def test_full_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanu @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_spd) def test_few_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - n_layer = get_custom_n_layers(model_name) check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( model_name=model_name, @@ -61,7 +59,6 @@ def 
test_few_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_spd) def test_dummy_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - custom_config = model_config_dict[model_name] hf_config = AutoConfig.from_pretrained( model_name, @@ -81,7 +78,6 @@ def test_dummy_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_clean @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_spd) def test_full_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_cleanup): - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( model_name=model_name, num_speculative_tokens=Constants.NUM_SPECULATIVE_TOKENS, @@ -96,7 +92,6 @@ def test_full_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_cle @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_spd) def test_few_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_cleanup): - n_layer = get_custom_n_layers(model_name) check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( model_name=model_name, @@ -112,7 +107,6 @@ def test_few_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_clea @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_spd) def test_dummy_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_cleanup): - custom_config = model_config_dict[model_name] hf_config = AutoConfig.from_pretrained( model_name, diff --git a/tests/transformers/models/causal_lm_models/test_fp16_causal_lm.py b/tests/transformers/models/causal_lm_models/test_fp16_causal_lm.py index 2ff366ece2..af8c3b70f0 100644 --- a/tests/transformers/models/causal_lm_models/test_fp16_causal_lm.py +++ b/tests/transformers/models/causal_lm_models/test_fp16_causal_lm.py @@ -127,7 +127,6 @@ def check_causal_lm_pytorch_vs_kv_vs_ai100( @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models) def 
test_full_fp16_causal_lm_pytorch_vs_kv_vs_ai100(model_name, manual_cleanup): - torch.manual_seed(42) check_causal_lm_pytorch_vs_kv_vs_ai100( model_name=model_name, torch_dtype=torch.float16, manual_cleanup=manual_cleanup @@ -139,7 +138,6 @@ def test_full_fp16_causal_lm_pytorch_vs_kv_vs_ai100(model_name, manual_cleanup): @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models) def test_few_fp16_causal_lm_pytorch_vs_kv_vs_ai100(model_name, manual_cleanup): - torch.manual_seed(42) n_layer = get_custom_n_layers(model_name) check_causal_lm_pytorch_vs_kv_vs_ai100( @@ -152,7 +150,6 @@ def test_few_fp16_causal_lm_pytorch_vs_kv_vs_ai100(model_name, manual_cleanup): @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models) def test_dummy_fp16_causal_lm_pytorch_vs_kv_vs_ai100(model_name, manual_cleanup): - torch.manual_seed(42) custom_config = model_config_dict[model_name] hf_config = AutoConfig.from_pretrained( diff --git a/tests/transformers/models/image_text_to_text/test_custom_dtype.py b/tests/transformers/models/image_text_to_text/test_custom_dtype.py index 95f62f1ac9..f291c5d12c 100644 --- a/tests/transformers/models/image_text_to_text/test_custom_dtype.py +++ b/tests/transformers/models/image_text_to_text/test_custom_dtype.py @@ -41,7 +41,6 @@ def test_full_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100_custom_dtype( model_name, kv_offload, torch_dtype, manual_cleanup ): - if model_name in ModelConfig.SKIPPED_MODELS: pytest.skip("Test skipped for this model due to some issues.") if model_name in ModelConfig.DUAL_QPC_MODELS and not kv_offload: @@ -65,7 +64,6 @@ def test_full_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100_custom_dtype( def test_few_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100_custom_dtype( model_name, kv_offload, torch_dtype, manual_cleanup ): - if model_name in ModelConfig.SKIPPED_MODELS: pytest.skip("Test skipped for this model due to some issues.") if model_name in ModelConfig.DUAL_QPC_MODELS and not 
kv_offload: diff --git a/tests/transformers/subfunction/test_causal_lm_blocking_subfunction.py b/tests/transformers/subfunction/test_causal_lm_blocking_subfunction.py index 5c58508385..b3f42e1b0c 100644 --- a/tests/transformers/subfunction/test_causal_lm_blocking_subfunction.py +++ b/tests/transformers/subfunction/test_causal_lm_blocking_subfunction.py @@ -64,7 +64,6 @@ def check_blockedKV_onnx_function_count_with_subfunction( @pytest.mark.feature @pytest.mark.parametrize("model_name", test_models_blockedKV) def test_full_blockedKV_onnx_function_count_with_subfunction(model_name, manual_cleanup): - # Keep model small for test runtime, and avoid CB path (not needed for function count). check_blockedKV_onnx_function_count_with_subfunction(model_name, manual_cleanup=manual_cleanup) @@ -73,7 +72,6 @@ def test_full_blockedKV_onnx_function_count_with_subfunction(model_name, manual_ @pytest.mark.feature @pytest.mark.parametrize("model_name", test_models_blockedKV) def test_few_blockedKV_onnx_function_count_with_subfunction(model_name, manual_cleanup): - # Keep model small for test runtime, and avoid CB path (not needed for function count). n_layer = get_custom_n_layers(model_name) @@ -84,7 +82,6 @@ def test_few_blockedKV_onnx_function_count_with_subfunction(model_name, manual_c @pytest.mark.feature @pytest.mark.parametrize("model_name", test_models_blockedKV) def test_dummy_blockedKV_onnx_function_count_with_subfunction(model_name, manual_cleanup): - # Keep model small for test runtime, and avoid CB path (not needed for function count). 
hf_config = AutoConfig.from_pretrained( model_name, diff --git a/tests/transformers/subfunction/test_subfunction_vlm.py b/tests/transformers/subfunction/test_subfunction_vlm.py index baf690e638..39e2c6d0ac 100644 --- a/tests/transformers/subfunction/test_subfunction_vlm.py +++ b/tests/transformers/subfunction/test_subfunction_vlm.py @@ -50,7 +50,6 @@ def check_image_text_to_text_subfunction_core( num_devices: int = 1, config: Optional[AutoConfig] = None, ): - img_size = model_config_dict[model_name]["img_size"] img_url = model_config_dict[model_name]["img_url"] query = model_config_dict[model_name]["text_prompt"] @@ -117,7 +116,6 @@ def check_image_text_to_text_subfunction_core( @pytest.mark.parametrize("model_name", test_mm_models) @pytest.mark.parametrize("kv_offload", [True]) def test_full_image_text_to_text_subfunction(model_name, kv_offload, manual_cleanup): - torch.manual_seed(42) check_image_text_to_text_subfunction_core(model_name, kv_offload=kv_offload, manual_cleanup=manual_cleanup) @@ -127,7 +125,6 @@ def test_full_image_text_to_text_subfunction(model_name, kv_offload, manual_clea @pytest.mark.parametrize("model_name", test_mm_models) @pytest.mark.parametrize("kv_offload", [True]) def test_few_image_text_to_text_subfunction(model_name, kv_offload, manual_cleanup): - torch.manual_seed(42) check_image_text_to_text_subfunction_core( model_name, @@ -142,7 +139,6 @@ def test_few_image_text_to_text_subfunction(model_name, kv_offload, manual_clean @pytest.mark.parametrize("model_name", test_mm_models) @pytest.mark.parametrize("kv_offload", [True]) def test_dummy_image_text_to_text_subfunction(model_name, kv_offload, manual_cleanup): - torch.manual_seed(42) hf_config = AutoConfig.from_pretrained( model_name, trust_remote_code=True, **model_config_dict[model_name].get("additional_params", {})